Refactor Models to be a singleton factory

Add Registrar of models
This commit is contained in:
2023-07-29 18:22:15 +02:00
parent 07d572a98c
commit cb54f61a69
5 changed files with 95 additions and 35 deletions

View File

@@ -21,13 +21,13 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
.default_value(string{ PATH_DATASETS }
);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::toString())
.help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const vector<string> choices = platform::Models::getNames();
static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw runtime_error("Model must be one of " + platform::Models::toString());
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--title").default_value("").help("Experiment title");
@@ -76,9 +76,21 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}
return program;
}
void registerModels()
{
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
}
int main(int argc, char** argv)
{
registerModels();
auto program = manageArguments(argc, argv);
bool saveResults = false;
auto file_name = program.get<string>("dataset");