Refactor Models to be a singleton factory
Add Registrar of models
This commit is contained in:
parent
07d572a98c
commit
cb54f61a69
@ -1,6 +1,7 @@
|
|||||||
#ifndef SPODE_H
|
#ifndef SPODE_H
|
||||||
#define SPODE_H
|
#define SPODE_H
|
||||||
#include "Classifier.h"
|
#include "Classifier.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class SPODE : public Classifier {
|
class SPODE : public Classifier {
|
||||||
private:
|
private:
|
||||||
|
@ -93,11 +93,6 @@ namespace platform {
|
|||||||
Result Experiment::cross_validation(const string& path, const string& fileName)
|
Result Experiment::cross_validation(const string& path, const string& fileName)
|
||||||
{
|
{
|
||||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
|
||||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
|
||||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
|
||||||
}
|
|
||||||
);
|
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
@ -127,7 +122,7 @@ namespace platform {
|
|||||||
else
|
else
|
||||||
fold = new KFold(nfolds, y.size(0), seed);
|
fold = new KFold(nfolds, y.size(0), seed);
|
||||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||||
auto clf = Models::createInstance(model);
|
auto clf = Models::instance()->create(model);
|
||||||
setModelVersion(clf->getVersion());
|
setModelVersion(clf->getVersion());
|
||||||
train_timer.start();
|
train_timer.start();
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
|
@ -1,28 +1,73 @@
|
|||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
|
|
||||||
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
|
||||||
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
|
||||||
// });
|
|
||||||
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||||
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
|
// shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
|
||||||
|
// {
|
||||||
|
// bayesnet::BaseClassifier* instance = nullptr;
|
||||||
|
// if (name == "AODE") {
|
||||||
|
// instance = new bayesnet::AODE();
|
||||||
|
// } else if (name == "KDB") {
|
||||||
|
// instance = new bayesnet::KDB(2);
|
||||||
|
// } else if (name == "SPODE") {
|
||||||
|
// instance = new bayesnet::SPODE(2);
|
||||||
|
// } else if (name == "TAN") {
|
||||||
|
// instance = new bayesnet::TAN();
|
||||||
|
// } else {
|
||||||
|
// throw runtime_error("Model " + name + " not found");
|
||||||
|
// }
|
||||||
|
// if (instance != nullptr)
|
||||||
|
// return shared_ptr<bayesnet::BaseClassifier>(instance);
|
||||||
|
// else
|
||||||
|
// return nullptr;
|
||||||
|
// }
|
||||||
|
Models* Models::factory = nullptr;;
|
||||||
|
Models* Models::instance()
|
||||||
|
{
|
||||||
|
//manages singleton
|
||||||
|
if (factory == nullptr)
|
||||||
|
factory = new Models();
|
||||||
|
return factory;
|
||||||
|
}
|
||||||
|
void Models::registerFactoryFunction(const string& name,
|
||||||
|
function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
|
||||||
|
{
|
||||||
|
// register the class factory function
|
||||||
|
functionRegistry[name] = classFactoryFunction;
|
||||||
|
}
|
||||||
|
shared_ptr<bayesnet::BaseClassifier> Models::create(const string& name)
|
||||||
{
|
{
|
||||||
bayesnet::BaseClassifier* instance = nullptr;
|
bayesnet::BaseClassifier* instance = nullptr;
|
||||||
if (name == "AODE") {
|
|
||||||
instance = new bayesnet::AODE();
|
// find name in the registry and call factory method.
|
||||||
} else if (name == "KDB") {
|
auto it = functionRegistry.find(name);
|
||||||
instance = new bayesnet::KDB(2);
|
if (it != functionRegistry.end())
|
||||||
} else if (name == "SPODE") {
|
instance = it->second();
|
||||||
instance = new bayesnet::SPODE(2);
|
// wrap instance in a shared ptr and return
|
||||||
} else if (name == "TAN") {
|
|
||||||
instance = new bayesnet::TAN();
|
|
||||||
} else {
|
|
||||||
throw runtime_error("Model " + name + " not found");
|
|
||||||
}
|
|
||||||
if (instance != nullptr)
|
if (instance != nullptr)
|
||||||
return shared_ptr<bayesnet::BaseClassifier>(instance);
|
return shared_ptr<bayesnet::BaseClassifier>(instance);
|
||||||
else
|
else
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
vector<string> Models::getNames()
|
||||||
|
{
|
||||||
|
vector<string> names;
|
||||||
|
transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
|
||||||
|
[](const pair<string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
|
||||||
|
return names;
|
||||||
|
}
|
||||||
|
string Models::toString()
|
||||||
|
{
|
||||||
|
string result = "";
|
||||||
|
for (auto& pair : functionRegistry) {
|
||||||
|
result += pair.first + ", ";
|
||||||
|
}
|
||||||
|
return "{" + result.substr(0, result.size() - 2) + "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
Registrar::Registrar(const string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
|
||||||
|
{
|
||||||
|
// register the class factory function
|
||||||
|
Models::instance()->registerFactoryFunction(name, classFactoryFunction);
|
||||||
|
}
|
||||||
}
|
}
|
@ -8,18 +8,25 @@
|
|||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Models {
|
class Models {
|
||||||
|
private:
|
||||||
|
map<string, function<bayesnet::BaseClassifier* (void)>> functionRegistry;
|
||||||
|
static Models* factory; //singleton
|
||||||
|
Models() {};
|
||||||
public:
|
public:
|
||||||
|
Models(Models&) = delete;
|
||||||
|
void operator=(const Models&) = delete;
|
||||||
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||||
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name);
|
static Models* instance();
|
||||||
static vector<string> getNames()
|
shared_ptr<bayesnet::BaseClassifier> create(const string& name);
|
||||||
{
|
void registerFactoryFunction(const string& name,
|
||||||
return { "aaaaaAODE", "KDB", "SPODE", "TAN" };
|
function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
|
||||||
}
|
vector<string> getNames();
|
||||||
static string toString()
|
string toString();
|
||||||
{
|
|
||||||
return "{aaaaae34223AODE, KDB, SPODE, TAN}";
|
};
|
||||||
//return "{" + names.substr(0, names.size() - 2) + "}";
|
class Registrar {
|
||||||
}
|
public:
|
||||||
|
Registrar(const string& className, function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -21,13 +21,13 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
.default_value(string{ PATH_DATASETS }
|
.default_value(string{ PATH_DATASETS }
|
||||||
);
|
);
|
||||||
program.add_argument("-m", "--model")
|
program.add_argument("-m", "--model")
|
||||||
.help("Model to use " + platform::Models::toString())
|
.help("Model to use " + platform::Models::instance()->toString())
|
||||||
.action([](const std::string& value) {
|
.action([](const std::string& value) {
|
||||||
static const vector<string> choices = platform::Models::getNames();
|
static const vector<string> choices = platform::Models::instance()->getNames();
|
||||||
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
throw runtime_error("Model must be one of " + platform::Models::toString());
|
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
program.add_argument("--title").default_value("").help("Experiment title");
|
program.add_argument("--title").default_value("").help("Experiment title");
|
||||||
@ -76,9 +76,21 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
return program;
|
return program;
|
||||||
}
|
}
|
||||||
|
void registerModels()
|
||||||
|
{
|
||||||
|
static platform::Registrar registrarT("TAN",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
||||||
|
static platform::Registrar registrarS("SPODE",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||||
|
static platform::Registrar registrarK("KDB",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
||||||
|
static platform::Registrar registrarA("AODE",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
|
registerModels();
|
||||||
auto program = manageArguments(argc, argv);
|
auto program = manageArguments(argc, argv);
|
||||||
bool saveResults = false;
|
bool saveResults = false;
|
||||||
auto file_name = program.get<string>("dataset");
|
auto file_name = program.get<string>("dataset");
|
||||||
|
Loading…
Reference in New Issue
Block a user