Refactor Models to be a singleton factory

Add Registrar of models
This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 18:22:15 +02:00
parent 07d572a98c
commit cb54f61a69
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 95 additions and 35 deletions

View File

@ -1,6 +1,7 @@
#ifndef SPODE_H
#define SPODE_H
#include "Classifier.h"
namespace bayesnet {
class SPODE : public Classifier {
private:

View File

@ -93,11 +93,6 @@ namespace platform {
Result Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
}
);
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
@ -127,7 +122,7 @@ namespace platform {
else
fold = new KFold(nfolds, y.size(0), seed);
for (int nfold = 0; nfold < nfolds; nfold++) {
auto clf = Models::createInstance(model);
auto clf = Models::instance()->create(model);
setModelVersion(clf->getVersion());
train_timer.start();
auto [train, test] = fold->getFold(nfold);

View File

@ -1,28 +1,73 @@
#include "Models.h"
namespace platform {
using namespace std;
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
// });
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// {
// bayesnet::BaseClassifier* instance = nullptr;
// if (name == "AODE") {
// instance = new bayesnet::AODE();
// } else if (name == "KDB") {
// instance = new bayesnet::KDB(2);
// } else if (name == "SPODE") {
// instance = new bayesnet::SPODE(2);
// } else if (name == "TAN") {
// instance = new bayesnet::TAN();
// } else {
// throw runtime_error("Model " + name + " not found");
// }
// if (instance != nullptr)
// return shared_ptr<bayesnet::BaseClassifier>(instance);
// else
// return nullptr;
// }
Models* Models::factory = nullptr;;
Models* Models::instance()
{
//manages singleton
if (factory == nullptr)
factory = new Models();
return factory;
}
void Models::registerFactoryFunction(const string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
functionRegistry[name] = classFactoryFunction;
}
shared_ptr<bayesnet::BaseClassifier> Models::create(const string& name)
{
bayesnet::BaseClassifier* instance = nullptr;
if (name == "AODE") {
instance = new bayesnet::AODE();
} else if (name == "KDB") {
instance = new bayesnet::KDB(2);
} else if (name == "SPODE") {
instance = new bayesnet::SPODE(2);
} else if (name == "TAN") {
instance = new bayesnet::TAN();
} else {
throw runtime_error("Model " + name + " not found");
}
// find name in the registry and call factory method.
auto it = functionRegistry.find(name);
if (it != functionRegistry.end())
instance = it->second();
// wrap instance in a shared ptr and return
if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance);
else
return nullptr;
}
vector<string> Models::getNames()
{
vector<string> names;
transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
[](const pair<string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
return names;
}
string Models::toString()
{
string result = "";
for (auto& pair : functionRegistry) {
result += pair.first + ", ";
}
return "{" + result.substr(0, result.size() - 2) + "}";
}
Registrar::Registrar(const string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
Models::instance()->registerFactoryFunction(name, classFactoryFunction);
}
}

View File

@ -8,18 +8,25 @@
#include "SPODE.h"
namespace platform {
class Models {
private:
map<string, function<bayesnet::BaseClassifier* (void)>> functionRegistry;
static Models* factory; //singleton
Models() {};
public:
Models(Models&) = delete;
void operator=(const Models&) = delete;
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name);
static vector<string> getNames()
{
return { "aaaaaAODE", "KDB", "SPODE", "TAN" };
}
static string toString()
{
return "{aaaaae34223AODE, KDB, SPODE, TAN}";
//return "{" + names.substr(0, names.size() - 2) + "}";
}
static Models* instance();
shared_ptr<bayesnet::BaseClassifier> create(const string& name);
void registerFactoryFunction(const string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
vector<string> getNames();
string toString();
};
class Registrar {
public:
Registrar(const string& className, function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
};
}
#endif

View File

@ -21,13 +21,13 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
.default_value(string{ PATH_DATASETS }
);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::toString())
.help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const vector<string> choices = platform::Models::getNames();
static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw runtime_error("Model must be one of " + platform::Models::toString());
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--title").default_value("").help("Experiment title");
@ -76,9 +76,21 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}
return program;
}
void registerModels()
{
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
}
int main(int argc, char** argv)
{
registerModels();
auto program = manageArguments(argc, argv);
bool saveResults = false;
auto file_name = program.get<string>("dataset");