make weights double

This commit is contained in:
Ricardo Montañana Gómez 2023-08-16 12:46:09 +02:00
parent 918a7b4180
commit a3e665eed6
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 4 additions and 5 deletions

View File

@ -43,7 +43,7 @@ namespace bayesnet {
{
dataset = X;
buildDataset(y);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
// X is nxm where n is the number of features and m the number of samples
@ -55,13 +55,13 @@ namespace bayesnet {
}
auto ytmp = torch::tensor(y, kInt32);
buildDataset(ytmp);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
{
this->dataset = dataset;
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)

View File

@ -65,8 +65,7 @@ namespace bayesnet {
//Update new states of the feature/node
states[pFeatures[index]] = xStates;
}
// TODO weights can't be ones
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
model.fit(pDataset, weights, pFeatures, pClassName, states);
}
return states;