From fa612c531e1a8b9957ac4ef4bca475e0699a8e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 15 Aug 2023 15:59:56 +0200 Subject: [PATCH] Complete Adding weights to Models --- src/BayesNet/BaseClassifier.h | 1 + src/BayesNet/Classifier.cc | 19 ++++++++++++------- src/BayesNet/Classifier.h | 3 ++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/BayesNet/BaseClassifier.h b/src/BayesNet/BaseClassifier.h index 527b5c5..5f1cbaa 100644 --- a/src/BayesNet/BaseClassifier.h +++ b/src/BayesNet/BaseClassifier.h @@ -13,6 +13,7 @@ namespace bayesnet { // X is nxm tensor, y is nx1 tensor virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) = 0; virtual BaseClassifier& fit(torch::Tensor& dataset, vector& features, string className, map>& states) = 0; + virtual BaseClassifier& fit(torch::Tensor& dataset, vector& features, string className, map>& states, const torch::Tensor& weights) = 0; virtual ~BaseClassifier() = default; torch::Tensor virtual predict(torch::Tensor& X) = 0; vector virtual predict(vector>& X) = 0; diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 1fab813..154f1df 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -5,7 +5,7 @@ namespace bayesnet { using namespace torch; Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {} - Classifier& Classifier::build(vector& features, string className, map>& states) + Classifier& Classifier::build(vector& features, string className, map>& states, const torch::Tensor& weights) { this->features = features; this->className = className; @@ -16,14 +16,11 @@ namespace bayesnet { auto n_classes = states[className].size(); metrics = Metrics(dataset, features, className, n_classes); model.initialize(); - // TODO weights can't be ones - const torch::Tensor weights = torch::ones({ m }, torch::kFloat); buildModel(weights); trainModel(weights); fitted = true; return *this; } - void Classifier::buildDataset(Tensor& ytmp) { try { @@ -46,7 +43,8 @@ namespace bayesnet { { dataset = X; buildDataset(y); - return build(features, className, states); + const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat); + return build(features, className, states, weights); } // X is nxm where n is the number of features and m the number of samples Classifier& Classifier::fit(vector>& X, vector& y, vector& features, string className, map>& states) @@ -57,12 +55,19 @@ namespace bayesnet { } auto ytmp = torch::tensor(y, kInt32); buildDataset(ytmp); - return build(features, className, states); + const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat); + return build(features, className, states, weights); } Classifier& Classifier::fit(torch::Tensor& dataset, vector& features, string className, map>& states) { this->dataset = dataset; - return build(features, className, states); + const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat); + return build(features, className, states, weights); + } + Classifier& Classifier::fit(torch::Tensor& dataset, vector& features, string className, map>& states, const torch::Tensor& weights) + { + this->dataset = dataset; + return build(features, className, states, weights); } void Classifier::checkFitParameters() { diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 3c18295..0c2940b 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -11,7 +11,7 @@ namespace bayesnet { class Classifier : public BaseClassifier { private: void buildDataset(torch::Tensor& y); - Classifier& build(vector& features, string className, map>& states); + Classifier& build(vector& features, string className, map>& states, const torch::Tensor& weights); protected: bool fitted; int m, n; // m: number of samples, n: number of features @@ -30,6 +30,7 @@ namespace bayesnet { Classifier& fit(vector>& X, vector& y, vector& features, string className, map>& states) override; Classifier& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) override; Classifier& fit(torch::Tensor& dataset, vector& features, string className, map>& states) override; + Classifier& fit(torch::Tensor& dataset, vector& features, string className, map>& states, const torch::Tensor& weights) override; void addNodes(); int getNumberOfNodes() const override; int getNumberOfEdges() const override;