From ef1bffcac314bdccebf888dff1067125cd68ee47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Mon, 7 Aug 2023 13:50:11 +0200 Subject: [PATCH] Fixed normal classifiers --- .vscode/launch.json | 3 ++- src/BayesNet/AODELd.h | 2 +- src/BayesNet/Classifier.cc | 25 +++++++++++++++++-------- src/BayesNet/Classifier.h | 2 +- src/BayesNet/Ensemble.h | 2 +- 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 7241ae2..8eeff68 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -25,7 +25,8 @@ "program": "${workspaceFolder}/build/src/Platform/main", "args": [ "-m", - "AODELd", + "AODE", + "--discretize", "-p", "/Users/rmontanana/Code/discretizbench/datasets", "--stratified", diff --git a/src/BayesNet/AODELd.h b/src/BayesNet/AODELd.h index c8db41d..74b74b1 100644 --- a/src/BayesNet/AODELd.h +++ b/src/BayesNet/AODELd.h @@ -8,7 +8,7 @@ namespace bayesnet { using namespace std; class AODELd : public Ensemble, public Proposal { private: - void trainModel(); + void trainModel() override; void buildModel() override; public: AODELd(); diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index c84ebe6..c0f1895 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -10,26 +10,35 @@ namespace bayesnet { this->features = features; this->className = className; this->states = states; + m = dataset.size(1); + n = dataset.size(0) - 1; checkFitParameters(); auto n_classes = states[className].size(); metrics = Metrics(dataset, features, className, n_classes); model.initialize(); buildModel(); - m = dataset.size(1); - n = dataset.size(0); trainModel(); fitted = true; return *this; } + + void Classifier::buildDataset(Tensor& ytmp) + { + try { + auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1); + dataset = torch::cat({ dataset, yresized }, 0); + } + catch (const std::exception& e) { + std::cerr << e.what() << '\n'; + cout << "X dimensions: " << dataset.sizes() << "\n"; + cout << "y dimensions: " << ytmp.sizes() << "\n"; + exit(1); + } + } void Classifier::trainModel() { model.fit(dataset, features, className); } - void Classifier::buildDataset(Tensor& ytmp) - { - ytmp = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1); - dataset = torch::cat({ dataset, ytmp }, 0); - } // X is nxm where n is the number of features and m the number of samples Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) { @@ -56,7 +65,7 @@ namespace bayesnet { void Classifier::checkFitParameters() { if (n != features.size()) { - throw invalid_argument("X and features must have the same number of features"); + throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features"); } if (states.find(className) == states.end()) { throw invalid_argument("className not found in states"); diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index d492d81..7e88bd3 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -23,7 +23,7 @@ namespace bayesnet { map> states; void checkFitParameters(); virtual void buildModel() = 0; - void trainModel(); + virtual void trainModel(); public: Classifier(Network model); virtual ~Classifier() = default; diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index 8efa0b7..f36d1ad 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -14,7 +14,7 @@ namespace bayesnet { protected: unsigned n_models; vector> models; - void trainModel(); + void trainModel() override; vector voting(Tensor& y_pred); public: Ensemble();