Add Model factory

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 17:27:43 +02:00
parent c4f3e6f19a
commit 07d572a98c
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 39 additions and 23 deletions

View File

@ -8,6 +8,7 @@ namespace bayesnet {
void train() override;
public:
AODE();
virtual ~AODE() {};
vector<string> graph(string title = "AODE") override;
};
}

View File

@ -14,6 +14,7 @@ namespace bayesnet {
void train() override;
public:
KDB(int k, float theta = 0.03);
virtual ~KDB() {};
vector<string> graph(string name = "KDB") override;
};
}

View File

@ -9,6 +9,7 @@ namespace bayesnet {
void train() override;
public:
SPODE(int root);
virtual ~SPODE() {};
vector<string> graph(string name = "SPODE") override;
};
}

View File

@ -10,6 +10,7 @@ namespace bayesnet {
void train() override;
public:
TAN();
virtual ~TAN() {};
vector<string> graph(string name = "TAN") override;
};
}

View File

@ -1,5 +1,6 @@
#include "Experiment.h"
#include "Datasets.h"
#include "Models.h"
namespace platform {
using json = nlohmann::json;
@ -91,12 +92,12 @@ namespace platform {
}
Result Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
}
);
auto datasets = platform::Datasets(path, true, platform::ARFF);
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
@ -119,15 +120,14 @@ namespace platform {
Timer train_timer, test_timer;
int item = 0;
for (auto seed : randomSeeds) {
cout << "(" << seed << ") " << flush;
cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (stratified)
fold = new StratifiedKFold(nfolds, y, seed);
else
fold = new KFold(nfolds, y.size(0), seed);
cout << "doing Fold: " << flush;
for (int nfold = 0; nfold < nfolds; nfold++) {
bayesnet::BaseClassifier* clf = classifiers[model];
auto clf = Models::createInstance(model);
setModelVersion(clf->getVersion());
train_timer.start();
auto [train, test] = fold->getFold(nfold);

View File

@ -1,8 +1,28 @@
#include "Models.h"
namespace platform {
using namespace std;
map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
});
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
// });
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
{
bayesnet::BaseClassifier* instance = nullptr;
if (name == "AODE") {
instance = new bayesnet::AODE();
} else if (name == "KDB") {
instance = new bayesnet::KDB(2);
} else if (name == "SPODE") {
instance = new bayesnet::SPODE(2);
} else if (name == "TAN") {
instance = new bayesnet::TAN();
} else {
throw runtime_error("Model " + name + " not found");
}
if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance);
else
return nullptr;
}
}

View File

@ -8,25 +8,17 @@
#include "SPODE.h"
namespace platform {
class Models {
private:
static map<string, bayesnet::BaseClassifier*> classifiers;
public:
static bayesnet::BaseClassifier* get(string name) { return classifiers[name]; }
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name);
static vector<string> getNames()
{
vector<string> names;
for (auto& [name, classifier] : classifiers) {
names.push_back(name);
}
return names;
return { "aaaaaAODE", "KDB", "SPODE", "TAN" };
}
static string toString()
{
string names = "";
for (auto& [name, classifier] : classifiers) {
names += name + ", ";
}
return "{" + names.substr(0, names.size() - 2) + "}";
return "{aaaaae34223AODE, KDB, SPODE, TAN}";
//return "{" + names.substr(0, names.size() - 2) + "}";
}
};
}