From 24b68f9ae23e42a52c821f5c5ad12f1074bad7fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 15 Aug 2023 15:04:56 +0200 Subject: [PATCH] Add weigths as parameter --- .vscode/launch.json | 5 ++--- src/BayesNet/AODE.cc | 2 +- src/BayesNet/AODE.h | 2 +- src/BayesNet/AODELd.cc | 4 ++-- src/BayesNet/AODELd.h | 4 ++-- src/BayesNet/BaseClassifier.h | 2 +- src/BayesNet/BayesMetrics.cc | 3 ++- src/BayesNet/Classifier.cc | 9 +++++---- src/BayesNet/Classifier.h | 5 ++--- src/BayesNet/Ensemble.cc | 2 +- src/BayesNet/Ensemble.h | 2 +- src/BayesNet/KDB.cc | 2 +- src/BayesNet/KDB.h | 3 ++- src/BayesNet/Network.cc | 2 +- src/BayesNet/Proposal.cc | 1 + src/BayesNet/SPODE.cc | 2 +- src/BayesNet/SPODE.h | 2 +- src/BayesNet/TAN.cc | 2 +- src/BayesNet/TAN.h | 2 +- src/Platform/Report.h | 2 +- src/Platform/main.cc | 2 +- 21 files changed, 31 insertions(+), 29 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 828604b..e0da5f0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -25,13 +25,12 @@ "program": "${workspaceFolder}/build/src/Platform/main", "args": [ "-m", - "SPODE", + "TANLd", "-p", "/Users/rmontanana/Code/discretizbench/datasets", "--stratified", - "--discretize", "-d", - "letter" + "iris" ], "cwd": "/Users/rmontanana/Code/discretizbench", }, diff --git a/src/BayesNet/AODE.cc b/src/BayesNet/AODE.cc index 7e6a95f..d90c495 100644 --- a/src/BayesNet/AODE.cc +++ b/src/BayesNet/AODE.cc @@ -2,7 +2,7 @@ namespace bayesnet { AODE::AODE() : Ensemble() {} - void AODE::buildModel() + void AODE::buildModel(const torch::Tensor& weights) { models.clear(); for (int i = 0; i < features.size(); ++i) { diff --git a/src/BayesNet/AODE.h b/src/BayesNet/AODE.h index 3d58851..00965f6 100644 --- a/src/BayesNet/AODE.h +++ b/src/BayesNet/AODE.h @@ -5,7 +5,7 @@ namespace bayesnet { class AODE : public Ensemble { protected: - void buildModel() override; + void buildModel(const torch::Tensor& weights) override; public: AODE(); virtual ~AODE() {}; diff --git a/src/BayesNet/AODELd.cc b/src/BayesNet/AODELd.cc index 9f36ed2..cc842be 100644 --- a/src/BayesNet/AODELd.cc +++ b/src/BayesNet/AODELd.cc @@ -19,7 +19,7 @@ namespace bayesnet { return *this; } - void AODELd::buildModel() + void AODELd::buildModel(const torch::Tensor& weights) { models.clear(); for (int i = 0; i < features.size(); ++i) { @@ -27,7 +27,7 @@ namespace bayesnet { } n_models = models.size(); } - void AODELd::trainModel() + void AODELd::trainModel(const torch::Tensor& weights) { for (const auto& model : models) { model->fit(Xf, y, features, className, states); diff --git a/src/BayesNet/AODELd.h b/src/BayesNet/AODELd.h index 14be0c4..aa67247 100644 --- a/src/BayesNet/AODELd.h +++ b/src/BayesNet/AODELd.h @@ -8,8 +8,8 @@ namespace bayesnet { using namespace std; class AODELd : public Ensemble, public Proposal { protected: - void trainModel() override; - void buildModel() override; + void trainModel(const torch::Tensor& weights) override; + void buildModel(const torch::Tensor& weights) override; public: AODELd(); AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector& features_, string className_, map>& states_) override; diff --git a/src/BayesNet/BaseClassifier.h b/src/BayesNet/BaseClassifier.h index ff202e1..527b5c5 100644 --- a/src/BayesNet/BaseClassifier.h +++ b/src/BayesNet/BaseClassifier.h @@ -6,7 +6,7 @@ namespace bayesnet { using namespace std; class BaseClassifier { protected: - virtual void trainModel() = 0; + virtual void trainModel(const torch::Tensor& weights) = 0; public: // X is nxm vector, y is nx1 vector virtual BaseClassifier& fit(vector>& X, vector& y, vector& features, string className, map>& states) = 0; diff --git a/src/BayesNet/BayesMetrics.cc b/src/BayesNet/BayesMetrics.cc index 2f0de11..cb93141 100644 --- a/src/BayesNet/BayesMetrics.cc +++ b/src/BayesNet/BayesMetrics.cc @@ -52,7 +52,8 @@ namespace bayesnet { auto mask = samples.index({ -1, "..." }) == value; auto first_dataset = samples.index({ index_first, mask }); auto second_dataset = samples.index({ index_second, mask }); - auto mi = mutualInformation(first_dataset, second_dataset, weights); + auto weights_dataset = weights.index({ mask }); + auto mi = mutualInformation(first_dataset, second_dataset, weights_dataset); auto pb = margin[value].item(); accumulated += pb * mi; } diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 87bae91..1fab813 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -16,8 +16,10 @@ namespace bayesnet { auto n_classes = states[className].size(); metrics = Metrics(dataset, features, className, n_classes); model.initialize(); - buildModel(); - trainModel(); + // TODO weights can't be ones + const torch::Tensor weights = torch::ones({ m }, torch::kFloat); + buildModel(weights); + trainModel(weights); fitted = true; return *this; } @@ -35,9 +37,8 @@ namespace bayesnet { exit(1); } } - void Classifier::trainModel() + void Classifier::trainModel(const torch::Tensor& weights) { - const torch::Tensor weights = torch::ones({ m }); model.fit(dataset, weights, features, className, states); } // X is nxm where n is the number of features and m the number of samples diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 6d00928..3c18295 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -21,10 +21,9 @@ namespace bayesnet { string className; map> states; Tensor dataset; // (n+1)xm tensor - Tensor weights; void checkFitParameters(); - virtual void buildModel() = 0; - void trainModel() override; + virtual void buildModel(const torch::Tensor& weights) = 0; + void trainModel(const torch::Tensor& weights) override; public: Classifier(Network model); virtual ~Classifier() = default; diff --git a/src/BayesNet/Ensemble.cc b/src/BayesNet/Ensemble.cc index 34c6894..926fa5b 100644 --- a/src/BayesNet/Ensemble.cc +++ b/src/BayesNet/Ensemble.cc @@ -5,7 +5,7 @@ namespace bayesnet { Ensemble::Ensemble() : Classifier(Network()) {} - void Ensemble::trainModel() + void Ensemble::trainModel(const torch::Tensor& weights) { n_models = models.size(); for (auto i = 0; i < n_models; ++i) { diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index f0d750b..95c1da6 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -14,7 +14,7 @@ namespace bayesnet { protected: unsigned n_models; vector> models; - void trainModel() override; + void trainModel(const torch::Tensor& weights) override; vector voting(Tensor& y_pred); public: Ensemble(); diff --git a/src/BayesNet/KDB.cc b/src/BayesNet/KDB.cc index 874e08a..471f3fd 100644 --- a/src/BayesNet/KDB.cc +++ b/src/BayesNet/KDB.cc @@ -4,7 +4,7 @@ namespace bayesnet { using namespace torch; KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {} - void KDB::buildModel() + void KDB::buildModel(const torch::Tensor& weights) { /* 1. For each feature Xi, compute mutual information, I(X;C), diff --git a/src/BayesNet/KDB.h b/src/BayesNet/KDB.h index e7af8c5..b997cdd 100644 --- a/src/BayesNet/KDB.h +++ b/src/BayesNet/KDB.h @@ -1,5 +1,6 @@ #ifndef KDB_H #define KDB_H +#include #include "Classifier.h" #include "bayesnetUtils.h" namespace bayesnet { @@ -11,7 +12,7 @@ namespace bayesnet { float theta; void add_m_edges(int idx, vector& S, Tensor& weights); protected: - void buildModel() override; + void buildModel(const torch::Tensor& weights) override; public: explicit KDB(int k, float theta = 0.03); virtual ~KDB() {}; diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index fbb62cc..b65f570 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -107,7 +107,7 @@ namespace bayesnet { void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector& featureNames, const string& className, const map>& states, const torch::Tensor& weights) { if (weights.size(0) != n_samples) { - throw invalid_argument("Weights must have the same number of elements as samples in Network::fit"); + throw invalid_argument("Weights (" + to_string(weights.size(0)) + ") must have the same number of elements as samples (" + to_string(n_samples) + ") in Network::fit"); } if (n_samples != n_samples_y) { throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")"); diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index d95e701..87767b5 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -65,6 +65,7 @@ namespace bayesnet { //Update new states of the feature/node states[pFeatures[index]] = xStates; } + // TODO weights can't be ones const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat); model.fit(pDataset, weights, pFeatures, pClassName, states); } diff --git a/src/BayesNet/SPODE.cc b/src/BayesNet/SPODE.cc index a90e5ef..83c9231 100644 --- a/src/BayesNet/SPODE.cc +++ b/src/BayesNet/SPODE.cc @@ -4,7 +4,7 @@ namespace bayesnet { SPODE::SPODE(int root) : Classifier(Network()), root(root) {} - void SPODE::buildModel() + void SPODE::buildModel(const torch::Tensor& weights) { // 0. Add all nodes to the model addNodes(); diff --git a/src/BayesNet/SPODE.h b/src/BayesNet/SPODE.h index f9b6af0..0a78830 100644 --- a/src/BayesNet/SPODE.h +++ b/src/BayesNet/SPODE.h @@ -7,7 +7,7 @@ namespace bayesnet { private: int root; protected: - void buildModel() override; + void buildModel(const torch::Tensor& weights) override; public: explicit SPODE(int root); virtual ~SPODE() {}; diff --git a/src/BayesNet/TAN.cc b/src/BayesNet/TAN.cc index 843a5e6..f0728be 100644 --- a/src/BayesNet/TAN.cc +++ b/src/BayesNet/TAN.cc @@ -5,7 +5,7 @@ namespace bayesnet { TAN::TAN() : Classifier(Network()) {} - void TAN::buildModel() + void TAN::buildModel(const torch::Tensor& weights) { // 0. Add all nodes to the model addNodes(); diff --git a/src/BayesNet/TAN.h b/src/BayesNet/TAN.h index 4c1c5f5..91b5109 100644 --- a/src/BayesNet/TAN.h +++ b/src/BayesNet/TAN.h @@ -7,7 +7,7 @@ namespace bayesnet { class TAN : public Classifier { private: protected: - void buildModel() override; + void buildModel(const torch::Tensor& weights) override; public: TAN(); virtual ~TAN() {}; diff --git a/src/Platform/Report.h b/src/Platform/Report.h index 5934b2f..2708d4e 100644 --- a/src/Platform/Report.h +++ b/src/Platform/Report.h @@ -6,7 +6,7 @@ #include "Colors.h" using json = nlohmann::json; -const int MAXL = 121; +const int MAXL = 122; namespace platform { using namespace std; class Report { diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 0618c89..6f9ce1c 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -103,7 +103,7 @@ int main(int argc, char** argv) */ auto env = platform::DotEnv(); auto experiment = platform::Experiment(); - experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("1.0.0"); + experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3"); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy"); for (auto seed : seeds) {