From a3e665eed6a2d6afd6d9957cc4895aa6a0cd281d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 16 Aug 2023 12:46:09 +0200 Subject: [PATCH] make weights double --- src/BayesNet/Classifier.cc | 6 +++--- src/BayesNet/Proposal.cc | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 4d4ab08..ff25657 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -43,7 +43,7 @@ namespace bayesnet { { dataset = X; buildDataset(y); - const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); + const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); return build(features, className, states, weights); } // X is nxm where n is the number of features and m the number of samples @@ -55,13 +55,13 @@ namespace bayesnet { } auto ytmp = torch::tensor(y, kInt32); buildDataset(ytmp); - const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); + const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); return build(features, className, states, weights); } Classifier& Classifier::fit(torch::Tensor& dataset, vector& features, string className, map>& states) { this->dataset = dataset; - const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat); + const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); return build(features, className, states, weights); } Classifier& Classifier::fit(torch::Tensor& dataset, vector& features, string className, map>& states, const torch::Tensor& weights) diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index 87767b5..c410289 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -65,8 +65,7 @@ namespace bayesnet { //Update new states of the feature/node states[pFeatures[index]] = xStates; } - // TODO weights can't be ones - const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat); + const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble); model.fit(pDataset, weights, pFeatures, pClassName, states); } return states;