From 07d572a98c813fd274fb9143d674c9dc41c216cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sat, 29 Jul 2023 17:27:43 +0200 Subject: [PATCH] Add Model factory --- src/BayesNet/AODE.h | 1 + src/BayesNet/KDB.h | 1 + src/BayesNet/SPODE.h | 1 + src/BayesNet/TAN.h | 1 + src/Platform/Experiment.cc | 12 ++++++------ src/Platform/Models.cc | 28 ++++++++++++++++++++++++---- src/Platform/Models.h | 18 +++++------------- 7 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/BayesNet/AODE.h b/src/BayesNet/AODE.h index 84386d3..bc859e7 100644 --- a/src/BayesNet/AODE.h +++ b/src/BayesNet/AODE.h @@ -8,6 +8,7 @@ namespace bayesnet { void train() override; public: AODE(); + virtual ~AODE() {}; vector graph(string title = "AODE") override; }; } diff --git a/src/BayesNet/KDB.h b/src/BayesNet/KDB.h index 9683955..b0790da 100644 --- a/src/BayesNet/KDB.h +++ b/src/BayesNet/KDB.h @@ -14,6 +14,7 @@ namespace bayesnet { void train() override; public: KDB(int k, float theta = 0.03); + virtual ~KDB() {}; vector graph(string name = "KDB") override; }; } diff --git a/src/BayesNet/SPODE.h b/src/BayesNet/SPODE.h index 668bbca..05bf3a5 100644 --- a/src/BayesNet/SPODE.h +++ b/src/BayesNet/SPODE.h @@ -9,6 +9,7 @@ namespace bayesnet { void train() override; public: SPODE(int root); + virtual ~SPODE() {}; vector graph(string name = "SPODE") override; }; } diff --git a/src/BayesNet/TAN.h b/src/BayesNet/TAN.h index 11e7421..ce9b10a 100644 --- a/src/BayesNet/TAN.h +++ b/src/BayesNet/TAN.h @@ -10,6 +10,7 @@ namespace bayesnet { void train() override; public: TAN(); + virtual ~TAN() {}; vector graph(string name = "TAN") override; }; } diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 8592019..08f09f8 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -1,5 +1,6 @@ #include "Experiment.h" #include "Datasets.h" +#include "Models.h" namespace platform { using json = nlohmann::json; @@ -91,12 +92,12 @@ namespace platform { } Result Experiment::cross_validation(const string& path, const string& fileName) { + auto datasets = platform::Datasets(path, true, platform::ARFF); auto classifiers = map({ - { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, - { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } + { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, + { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } } ); - auto datasets = platform::Datasets(path, true, platform::ARFF); // Get dataset auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); @@ -119,15 +120,14 @@ namespace platform { Timer train_timer, test_timer; int item = 0; for (auto seed : randomSeeds) { - cout << "(" << seed << ") " << flush; + cout << "(" << seed << ") doing Fold: " << flush; Fold* fold; if (stratified) fold = new StratifiedKFold(nfolds, y, seed); else fold = new KFold(nfolds, y.size(0), seed); - cout << "doing Fold: " << flush; for (int nfold = 0; nfold < nfolds; nfold++) { - bayesnet::BaseClassifier* clf = classifiers[model]; + auto clf = Models::createInstance(model); setModelVersion(clf->getVersion()); train_timer.start(); auto [train, test] = fold->getFold(nfold); diff --git a/src/Platform/Models.cc b/src/Platform/Models.cc index b6aaa66..7bed6c3 100644 --- a/src/Platform/Models.cc +++ b/src/Platform/Models.cc @@ -1,8 +1,28 @@ #include "Models.h" namespace platform { using namespace std; - map Models::classifiers = map({ - { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, - { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } - }); + // map Models::classifiers = map({ + // { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, + // { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } + // }); + // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory + shared_ptr Models::createInstance(const string& name) + { + bayesnet::BaseClassifier* instance = nullptr; + if (name == "AODE") { + instance = new bayesnet::AODE(); + } else if (name == "KDB") { + instance = new bayesnet::KDB(2); + } else if (name == "SPODE") { + instance = new bayesnet::SPODE(2); + } else if (name == "TAN") { + instance = new bayesnet::TAN(); + } else { + throw runtime_error("Model " + name + " not found"); + } + if (instance != nullptr) + return shared_ptr(instance); + else + return nullptr; + } } \ No newline at end of file diff --git a/src/Platform/Models.h b/src/Platform/Models.h index 2851036..379e00f 100644 --- a/src/Platform/Models.h +++ b/src/Platform/Models.h @@ -8,25 +8,17 @@ #include "SPODE.h" namespace platform { class Models { - private: - static map classifiers; public: - static bayesnet::BaseClassifier* get(string name) { return classifiers[name]; } + // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory + static shared_ptr createInstance(const string& name); static vector getNames() { - vector names; - for (auto& [name, classifier] : classifiers) { - names.push_back(name); - } - return names; + return { "aaaaaAODE", "KDB", "SPODE", "TAN" }; } static string toString() { - string names = ""; - for (auto& [name, classifier] : classifiers) { - names += name + ", "; - } - return "{" + names.substr(0, names.size() - 2) + "}"; + return "{aaaaae34223AODE, KDB, SPODE, TAN}"; + //return "{" + names.substr(0, names.size() - 2) + "}"; } }; }