Add Model factory
This commit is contained in:
parent
c4f3e6f19a
commit
07d572a98c
@ -8,6 +8,7 @@ namespace bayesnet {
|
||||
void train() override;
|
||||
public:
|
||||
AODE();
|
||||
virtual ~AODE() {};
|
||||
vector<string> graph(string title = "AODE") override;
|
||||
};
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ namespace bayesnet {
|
||||
void train() override;
|
||||
public:
|
||||
KDB(int k, float theta = 0.03);
|
||||
virtual ~KDB() {};
|
||||
vector<string> graph(string name = "KDB") override;
|
||||
};
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ namespace bayesnet {
|
||||
void train() override;
|
||||
public:
|
||||
SPODE(int root);
|
||||
virtual ~SPODE() {};
|
||||
vector<string> graph(string name = "SPODE") override;
|
||||
};
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ namespace bayesnet {
|
||||
void train() override;
|
||||
public:
|
||||
TAN();
|
||||
virtual ~TAN() {};
|
||||
vector<string> graph(string name = "TAN") override;
|
||||
};
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "Experiment.h"
|
||||
#include "Datasets.h"
|
||||
#include "Models.h"
|
||||
|
||||
namespace platform {
|
||||
using json = nlohmann::json;
|
||||
@ -91,12 +92,12 @@ namespace platform {
|
||||
}
|
||||
Result Experiment::cross_validation(const string& path, const string& fileName)
|
||||
{
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
}
|
||||
);
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
// Get dataset
|
||||
auto [X, y] = datasets.getTensors(fileName);
|
||||
auto states = datasets.getStates(fileName);
|
||||
@ -119,15 +120,14 @@ namespace platform {
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
for (auto seed : randomSeeds) {
|
||||
cout << "(" << seed << ") " << flush;
|
||||
cout << "(" << seed << ") doing Fold: " << flush;
|
||||
Fold* fold;
|
||||
if (stratified)
|
||||
fold = new StratifiedKFold(nfolds, y, seed);
|
||||
else
|
||||
fold = new KFold(nfolds, y.size(0), seed);
|
||||
cout << "doing Fold: " << flush;
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
bayesnet::BaseClassifier* clf = classifiers[model];
|
||||
auto clf = Models::createInstance(model);
|
||||
setModelVersion(clf->getVersion());
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
|
@ -1,8 +1,28 @@
|
||||
#include "Models.h"
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
});
|
||||
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
// });
|
||||
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
|
||||
{
|
||||
bayesnet::BaseClassifier* instance = nullptr;
|
||||
if (name == "AODE") {
|
||||
instance = new bayesnet::AODE();
|
||||
} else if (name == "KDB") {
|
||||
instance = new bayesnet::KDB(2);
|
||||
} else if (name == "SPODE") {
|
||||
instance = new bayesnet::SPODE(2);
|
||||
} else if (name == "TAN") {
|
||||
instance = new bayesnet::TAN();
|
||||
} else {
|
||||
throw runtime_error("Model " + name + " not found");
|
||||
}
|
||||
if (instance != nullptr)
|
||||
return shared_ptr<bayesnet::BaseClassifier>(instance);
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
}
|
@ -8,25 +8,17 @@
|
||||
#include "SPODE.h"
|
||||
namespace platform {
|
||||
class Models {
|
||||
private:
|
||||
static map<string, bayesnet::BaseClassifier*> classifiers;
|
||||
public:
|
||||
static bayesnet::BaseClassifier* get(string name) { return classifiers[name]; }
|
||||
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name);
|
||||
static vector<string> getNames()
|
||||
{
|
||||
vector<string> names;
|
||||
for (auto& [name, classifier] : classifiers) {
|
||||
names.push_back(name);
|
||||
}
|
||||
return names;
|
||||
return { "aaaaaAODE", "KDB", "SPODE", "TAN" };
|
||||
}
|
||||
static string toString()
|
||||
{
|
||||
string names = "";
|
||||
for (auto& [name, classifier] : classifiers) {
|
||||
names += name + ", ";
|
||||
}
|
||||
return "{" + names.substr(0, names.size() - 2) + "}";
|
||||
return "{aaaaae34223AODE, KDB, SPODE, TAN}";
|
||||
//return "{" + names.substr(0, names.size() - 2) + "}";
|
||||
}
|
||||
};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user