Refactor Models to be a singleton factory

Add Registrar of models
This commit is contained in:
2023-07-29 18:22:15 +02:00
parent 07d572a98c
commit cb54f61a69
5 changed files with 95 additions and 35 deletions

View File

@@ -1,28 +1,73 @@
#include "Models.h"
namespace platform {
using namespace std;
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
// });
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// {
// bayesnet::BaseClassifier* instance = nullptr;
// if (name == "AODE") {
// instance = new bayesnet::AODE();
// } else if (name == "KDB") {
// instance = new bayesnet::KDB(2);
// } else if (name == "SPODE") {
// instance = new bayesnet::SPODE(2);
// } else if (name == "TAN") {
// instance = new bayesnet::TAN();
// } else {
// throw runtime_error("Model " + name + " not found");
// }
// if (instance != nullptr)
// return shared_ptr<bayesnet::BaseClassifier>(instance);
// else
// return nullptr;
// }
Models* Models::factory = nullptr;;
Models* Models::instance()
{
//manages singleton
if (factory == nullptr)
factory = new Models();
return factory;
}
void Models::registerFactoryFunction(const string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
functionRegistry[name] = classFactoryFunction;
}
shared_ptr<bayesnet::BaseClassifier> Models::create(const string& name)
{
bayesnet::BaseClassifier* instance = nullptr;
if (name == "AODE") {
instance = new bayesnet::AODE();
} else if (name == "KDB") {
instance = new bayesnet::KDB(2);
} else if (name == "SPODE") {
instance = new bayesnet::SPODE(2);
} else if (name == "TAN") {
instance = new bayesnet::TAN();
} else {
throw runtime_error("Model " + name + " not found");
}
// find name in the registry and call factory method.
auto it = functionRegistry.find(name);
if (it != functionRegistry.end())
instance = it->second();
// wrap instance in a shared ptr and return
if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance);
else
return nullptr;
}
vector<string> Models::getNames()
{
vector<string> names;
transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
[](const pair<string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
return names;
}
string Models::toString()
{
string result = "";
for (auto& pair : functionRegistry) {
result += pair.first + ", ";
}
return "{" + result.substr(0, result.size() - 2) + "}";
}
Registrar::Registrar(const string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{
// register the class factory function
Models::instance()->registerFactoryFunction(name, classFactoryFunction);
}
}