boostAode #5

Merged
rmontanana merged 21 commits from boostAode into main 2023-08-20 09:02:08 +00:00
2 changed files with 4 additions and 5 deletions
Showing only changes of commit a3e665eed6 - Show all commits

View File

@ -43,7 +43,7 @@ namespace bayesnet {
{
dataset = X;
buildDataset(y);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
// X is nxm where n is the number of features and m the number of samples
@ -55,13 +55,13 @@ namespace bayesnet {
}
auto ytmp = torch::tensor(y, kInt32);
buildDataset(ytmp);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
{
this->dataset = dataset;
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)

View File

@ -65,8 +65,7 @@ namespace bayesnet {
//Update new states of the feature/node
states[pFeatures[index]] = xStates;
}
// TODO weights can't be ones
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
model.fit(pDataset, weights, pFeatures, pClassName, states);
}
return states;