Refactor Models to be a singleton factory

Add Registrar of models
This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 18:22:15 +02:00
parent 07d572a98c
commit cb54f61a69
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 95 additions and 35 deletions

View File

@ -1,6 +1,7 @@
#ifndef SPODE_H #ifndef SPODE_H
#define SPODE_H #define SPODE_H
#include "Classifier.h" #include "Classifier.h"
namespace bayesnet { namespace bayesnet {
class SPODE : public Classifier { class SPODE : public Classifier {
private: private:

View File

@ -93,11 +93,6 @@ namespace platform {
Result Experiment::cross_validation(const string& path, const string& fileName) Result Experiment::cross_validation(const string& path, const string& fileName)
{ {
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, true, platform::ARFF);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
}
);
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName); auto states = datasets.getStates(fileName);
@ -127,7 +122,7 @@ namespace platform {
else else
fold = new KFold(nfolds, y.size(0), seed); fold = new KFold(nfolds, y.size(0), seed);
for (int nfold = 0; nfold < nfolds; nfold++) { for (int nfold = 0; nfold < nfolds; nfold++) {
auto clf = Models::createInstance(model); auto clf = Models::instance()->create(model);
setModelVersion(clf->getVersion()); setModelVersion(clf->getVersion());
train_timer.start(); train_timer.start();
auto [train, test] = fold->getFold(nfold); auto [train, test] = fold->getFold(nfold);

View File

@ -1,28 +1,73 @@
#include "Models.h" #include "Models.h"
namespace platform { namespace platform {
using namespace std; using namespace std;
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
// });
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name) // shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// {
// bayesnet::BaseClassifier* instance = nullptr;
// if (name == "AODE") {
// instance = new bayesnet::AODE();
// } else if (name == "KDB") {
// instance = new bayesnet::KDB(2);
// } else if (name == "SPODE") {
// instance = new bayesnet::SPODE(2);
// } else if (name == "TAN") {
// instance = new bayesnet::TAN();
// } else {
// throw runtime_error("Model " + name + " not found");
// }
// if (instance != nullptr)
// return shared_ptr<bayesnet::BaseClassifier>(instance);
// else
// return nullptr;
// }
Models* Models::factory = nullptr;;
Models* Models::instance()
{
//manages singleton
if (factory == nullptr)
factory = new Models();
return factory;
}
void Models::registerFactoryFunction(const string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
functionRegistry[name] = classFactoryFunction;
}
shared_ptr<bayesnet::BaseClassifier> Models::create(const string& name)
{ {
bayesnet::BaseClassifier* instance = nullptr; bayesnet::BaseClassifier* instance = nullptr;
if (name == "AODE") {
instance = new bayesnet::AODE(); // find name in the registry and call factory method.
} else if (name == "KDB") { auto it = functionRegistry.find(name);
instance = new bayesnet::KDB(2); if (it != functionRegistry.end())
} else if (name == "SPODE") { instance = it->second();
instance = new bayesnet::SPODE(2); // wrap instance in a shared ptr and return
} else if (name == "TAN") {
instance = new bayesnet::TAN();
} else {
throw runtime_error("Model " + name + " not found");
}
if (instance != nullptr) if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance); return shared_ptr<bayesnet::BaseClassifier>(instance);
else else
return nullptr; return nullptr;
} }
vector<string> Models::getNames()
{
vector<string> names;
transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
[](const pair<string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
return names;
}
string Models::toString()
{
string result = "";
for (auto& pair : functionRegistry) {
result += pair.first + ", ";
}
return "{" + result.substr(0, result.size() - 2) + "}";
}
Registrar::Registrar(const string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
Models::instance()->registerFactoryFunction(name, classFactoryFunction);
}
} }

View File

@ -8,18 +8,25 @@
#include "SPODE.h" #include "SPODE.h"
namespace platform { namespace platform {
class Models { class Models {
private:
map<string, function<bayesnet::BaseClassifier* (void)>> functionRegistry;
static Models* factory; //singleton
Models() {};
public: public:
Models(Models&) = delete;
void operator=(const Models&) = delete;
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name); static Models* instance();
static vector<string> getNames() shared_ptr<bayesnet::BaseClassifier> create(const string& name);
{ void registerFactoryFunction(const string& name,
return { "aaaaaAODE", "KDB", "SPODE", "TAN" }; function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
} vector<string> getNames();
static string toString() string toString();
{
return "{aaaaae34223AODE, KDB, SPODE, TAN}"; };
//return "{" + names.substr(0, names.size() - 2) + "}"; class Registrar {
} public:
Registrar(const string& className, function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
}; };
} }
#endif #endif

View File

@ -21,13 +21,13 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
.default_value(string{ PATH_DATASETS } .default_value(string{ PATH_DATASETS }
); );
program.add_argument("-m", "--model") program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::toString()) .help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) { .action([](const std::string& value) {
static const vector<string> choices = platform::Models::getNames(); static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) { if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value; return value;
} }
throw runtime_error("Model must be one of " + platform::Models::toString()); throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
} }
); );
program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--title").default_value("").help("Experiment title");
@ -76,9 +76,21 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
} }
return program; return program;
} }
void registerModels()
{
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
registerModels();
auto program = manageArguments(argc, argv); auto program = manageArguments(argc, argv);
bool saveResults = false; bool saveResults = false;
auto file_name = program.get<string>("dataset"); auto file_name = program.get<string>("dataset");