Add Model factory
This commit is contained in:
parent
c4f3e6f19a
commit
07d572a98c
@ -8,6 +8,7 @@ namespace bayesnet {
|
|||||||
void train() override;
|
void train() override;
|
||||||
public:
|
public:
|
||||||
AODE();
|
AODE();
|
||||||
|
virtual ~AODE() {};
|
||||||
vector<string> graph(string title = "AODE") override;
|
vector<string> graph(string title = "AODE") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ namespace bayesnet {
|
|||||||
void train() override;
|
void train() override;
|
||||||
public:
|
public:
|
||||||
KDB(int k, float theta = 0.03);
|
KDB(int k, float theta = 0.03);
|
||||||
|
virtual ~KDB() {};
|
||||||
vector<string> graph(string name = "KDB") override;
|
vector<string> graph(string name = "KDB") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ namespace bayesnet {
|
|||||||
void train() override;
|
void train() override;
|
||||||
public:
|
public:
|
||||||
SPODE(int root);
|
SPODE(int root);
|
||||||
|
virtual ~SPODE() {};
|
||||||
vector<string> graph(string name = "SPODE") override;
|
vector<string> graph(string name = "SPODE") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ namespace bayesnet {
|
|||||||
void train() override;
|
void train() override;
|
||||||
public:
|
public:
|
||||||
TAN();
|
TAN();
|
||||||
|
virtual ~TAN() {};
|
||||||
vector<string> graph(string name = "TAN") override;
|
vector<string> graph(string name = "TAN") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#include "Experiment.h"
|
#include "Experiment.h"
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
|
#include "Models.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@ -91,12 +92,12 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
Result Experiment::cross_validation(const string& path, const string& fileName)
|
Result Experiment::cross_validation(const string& path, const string& fileName)
|
||||||
{
|
{
|
||||||
|
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
@ -119,15 +120,14 @@ namespace platform {
|
|||||||
Timer train_timer, test_timer;
|
Timer train_timer, test_timer;
|
||||||
int item = 0;
|
int item = 0;
|
||||||
for (auto seed : randomSeeds) {
|
for (auto seed : randomSeeds) {
|
||||||
cout << "(" << seed << ") " << flush;
|
cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
if (stratified)
|
if (stratified)
|
||||||
fold = new StratifiedKFold(nfolds, y, seed);
|
fold = new StratifiedKFold(nfolds, y, seed);
|
||||||
else
|
else
|
||||||
fold = new KFold(nfolds, y.size(0), seed);
|
fold = new KFold(nfolds, y.size(0), seed);
|
||||||
cout << "doing Fold: " << flush;
|
|
||||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||||
bayesnet::BaseClassifier* clf = classifiers[model];
|
auto clf = Models::createInstance(model);
|
||||||
setModelVersion(clf->getVersion());
|
setModelVersion(clf->getVersion());
|
||||||
train_timer.start();
|
train_timer.start();
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
|
@ -1,8 +1,28 @@
|
|||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
|
// map<string, bayesnet::BaseClassifier*> Models::classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
// { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
// { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||||
});
|
// });
|
||||||
|
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||||
|
shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
|
||||||
|
{
|
||||||
|
bayesnet::BaseClassifier* instance = nullptr;
|
||||||
|
if (name == "AODE") {
|
||||||
|
instance = new bayesnet::AODE();
|
||||||
|
} else if (name == "KDB") {
|
||||||
|
instance = new bayesnet::KDB(2);
|
||||||
|
} else if (name == "SPODE") {
|
||||||
|
instance = new bayesnet::SPODE(2);
|
||||||
|
} else if (name == "TAN") {
|
||||||
|
instance = new bayesnet::TAN();
|
||||||
|
} else {
|
||||||
|
throw runtime_error("Model " + name + " not found");
|
||||||
|
}
|
||||||
|
if (instance != nullptr)
|
||||||
|
return shared_ptr<bayesnet::BaseClassifier>(instance);
|
||||||
|
else
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
@ -8,25 +8,17 @@
|
|||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Models {
|
class Models {
|
||||||
private:
|
|
||||||
static map<string, bayesnet::BaseClassifier*> classifiers;
|
|
||||||
public:
|
public:
|
||||||
static bayesnet::BaseClassifier* get(string name) { return classifiers[name]; }
|
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||||
|
static shared_ptr<bayesnet::BaseClassifier> createInstance(const string& name);
|
||||||
static vector<string> getNames()
|
static vector<string> getNames()
|
||||||
{
|
{
|
||||||
vector<string> names;
|
return { "aaaaaAODE", "KDB", "SPODE", "TAN" };
|
||||||
for (auto& [name, classifier] : classifiers) {
|
|
||||||
names.push_back(name);
|
|
||||||
}
|
|
||||||
return names;
|
|
||||||
}
|
}
|
||||||
static string toString()
|
static string toString()
|
||||||
{
|
{
|
||||||
string names = "";
|
return "{aaaaae34223AODE, KDB, SPODE, TAN}";
|
||||||
for (auto& [name, classifier] : classifiers) {
|
//return "{" + names.substr(0, names.size() - 2) + "}";
|
||||||
names += name + ", ";
|
|
||||||
}
|
|
||||||
return "{" + names.substr(0, names.size() - 2) + "}";
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user