From cb54f61a694e3ef2f5de357b2c5b2abb36ab3e4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sat, 29 Jul 2023 18:22:15 +0200 Subject: [PATCH] Refactor Models to be a singleton factory Add Registrar of models --- src/BayesNet/SPODE.h | 1 + src/Platform/Experiment.cc | 7 +--- src/Platform/Models.cc | 77 ++++++++++++++++++++++++++++++-------- src/Platform/Models.h | 27 ++++++++----- src/Platform/main.cc | 18 +++++++-- 5 files changed, 95 insertions(+), 35 deletions(-) diff --git a/src/BayesNet/SPODE.h b/src/BayesNet/SPODE.h index 05bf3a5..0f422a7 100644 --- a/src/BayesNet/SPODE.h +++ b/src/BayesNet/SPODE.h @@ -1,6 +1,7 @@ #ifndef SPODE_H #define SPODE_H #include "Classifier.h" + namespace bayesnet { class SPODE : public Classifier { private: diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 08f09f8..ab62ab2 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -93,11 +93,6 @@ namespace platform { Result Experiment::cross_validation(const string& path, const string& fileName) { auto datasets = platform::Datasets(path, true, platform::ARFF); - auto classifiers = map({ - { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, - { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } - } - ); // Get dataset auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); @@ -127,7 +122,7 @@ namespace platform { else fold = new KFold(nfolds, y.size(0), seed); for (int nfold = 0; nfold < nfolds; nfold++) { - auto clf = Models::createInstance(model); + auto clf = Models::instance()->create(model); setModelVersion(clf->getVersion()); train_timer.start(); auto [train, test] = fold->getFold(nfold); diff --git a/src/Platform/Models.cc b/src/Platform/Models.cc index 7bed6c3..aa23a2d 100644 --- a/src/Platform/Models.cc +++ b/src/Platform/Models.cc @@ -1,28 +1,73 @@ #include "Models.h" namespace platform { using namespace std; - // map Models::classifiers = map({ - // { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, - // { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } - // }); // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory - shared_ptr Models::createInstance(const string& name) + // shared_ptr Models::createInstance(const string& name) + // { + // bayesnet::BaseClassifier* instance = nullptr; + // if (name == "AODE") { + // instance = new bayesnet::AODE(); + // } else if (name == "KDB") { + // instance = new bayesnet::KDB(2); + // } else if (name == "SPODE") { + // instance = new bayesnet::SPODE(2); + // } else if (name == "TAN") { + // instance = new bayesnet::TAN(); + // } else { + // throw runtime_error("Model " + name + " not found"); + // } + // if (instance != nullptr) + // return shared_ptr(instance); + // else + // return nullptr; + // } + Models* Models::factory = nullptr;; + Models* Models::instance() + { + //manages singleton + if (factory == nullptr) + factory = new Models(); + return factory; + } + void Models::registerFactoryFunction(const string& name, + function classFactoryFunction) + { + // register the class factory function + functionRegistry[name] = classFactoryFunction; + } + shared_ptr Models::create(const string& name) { bayesnet::BaseClassifier* instance = nullptr; - if (name == "AODE") { - instance = new bayesnet::AODE(); - } else if (name == "KDB") { - instance = new bayesnet::KDB(2); - } else if (name == "SPODE") { - instance = new bayesnet::SPODE(2); - } else if (name == "TAN") { - instance = new bayesnet::TAN(); - } else { - throw runtime_error("Model " + name + " not found"); - } + + // find name in the registry and call factory method. + auto it = functionRegistry.find(name); + if (it != functionRegistry.end()) + instance = it->second(); + // wrap instance in a shared ptr and return if (instance != nullptr) return shared_ptr(instance); else return nullptr; } + vector Models::getNames() + { + vector names; + transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names), + [](const pair>& pair) { return pair.first; }); + return names; + } + string Models::toString() + { + string result = ""; + for (auto& pair : functionRegistry) { + result += pair.first + ", "; + } + return "{" + result.substr(0, result.size() - 2) + "}"; + } + + Registrar::Registrar(const string& name, function classFactoryFunction) + { + // register the class factory function + Models::instance()->registerFactoryFunction(name, classFactoryFunction); + } } \ No newline at end of file diff --git a/src/Platform/Models.h b/src/Platform/Models.h index 379e00f..0bb8d51 100644 --- a/src/Platform/Models.h +++ b/src/Platform/Models.h @@ -8,18 +8,25 @@ #include "SPODE.h" namespace platform { class Models { + private: + map> functionRegistry; + static Models* factory; //singleton + Models() {}; public: + Models(Models&) = delete; + void operator=(const Models&) = delete; // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory - static shared_ptr createInstance(const string& name); - static vector getNames() - { - return { "aaaaaAODE", "KDB", "SPODE", "TAN" }; - } - static string toString() - { - return "{aaaaae34223AODE, KDB, SPODE, TAN}"; - //return "{" + names.substr(0, names.size() - 2) + "}"; - } + static Models* instance(); + shared_ptr create(const string& name); + void registerFactoryFunction(const string& name, + function classFactoryFunction); + vector getNames(); + string toString(); + + }; + class Registrar { + public: + Registrar(const string& className, function classFactoryFunction); }; } #endif \ No newline at end of file diff --git a/src/Platform/main.cc b/src/Platform/main.cc index d7d040d..3a4b238 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -21,13 +21,13 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) .default_value(string{ PATH_DATASETS } ); program.add_argument("-m", "--model") - .help("Model to use " + platform::Models::toString()) + .help("Model to use " + platform::Models::instance()->toString()) .action([](const std::string& value) { - static const vector choices = platform::Models::getNames(); + static const vector choices = platform::Models::instance()->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw runtime_error("Model must be one of " + platform::Models::toString()); + throw runtime_error("Model must be one of " + platform::Models::instance()->toString()); } ); program.add_argument("--title").default_value("").help("Experiment title"); @@ -76,9 +76,21 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) } return program; } +void registerModels() +{ + static platform::Registrar registrarT("TAN", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();}); + static platform::Registrar registrarS("SPODE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);}); + static platform::Registrar registrarK("KDB", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);}); + static platform::Registrar registrarA("AODE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();}); +} int main(int argc, char** argv) { + registerModels(); auto program = manageArguments(argc, argv); bool saveResults = false; auto file_name = program.get("dataset");