make weights double

This commit is contained in:
Ricardo Montañana Gómez 2023-08-16 12:46:09 +02:00
parent 918a7b4180
commit a3e665eed6
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 4 additions and 5 deletions

View File

@ -43,7 +43,7 @@ namespace bayesnet {
{ {
dataset = X; dataset = X;
buildDataset(y); buildDataset(y);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights); return build(features, className, states, weights);
} }
// X is nxm where n is the number of features and m the number of samples // X is nxm where n is the number of features and m the number of samples
@ -55,13 +55,13 @@ namespace bayesnet {
} }
auto ytmp = torch::tensor(y, kInt32); auto ytmp = torch::tensor(y, kInt32);
buildDataset(ytmp); buildDataset(ytmp);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights); return build(features, className, states, weights);
} }
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
{ {
this->dataset = dataset; this->dataset = dataset;
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights); return build(features, className, states, weights);
} }
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)

View File

@ -65,8 +65,7 @@ namespace bayesnet {
//Update new states of the feature/node //Update new states of the feature/node
states[pFeatures[index]] = xStates; states[pFeatures[index]] = xStates;
} }
// TODO weights can't be ones const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
model.fit(pDataset, weights, pFeatures, pClassName, states); model.fit(pDataset, weights, pFeatures, pClassName, states);
} }
return states; return states;