Add Model factory

This commit is contained in:
2023-07-29 17:27:43 +02:00
parent c4f3e6f19a
commit 07d572a98c
7 changed files with 39 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
#include "Experiment.h"
#include "Datasets.h"
#include "Models.h"
namespace platform {
using json = nlohmann::json;
@@ -91,12 +92,12 @@ namespace platform {
}
Result Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
}
);
auto datasets = platform::Datasets(path, true, platform::ARFF);
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
@@ -119,15 +120,14 @@ namespace platform {
Timer train_timer, test_timer;
int item = 0;
for (auto seed : randomSeeds) {
cout << "(" << seed << ") " << flush;
cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (stratified)
fold = new StratifiedKFold(nfolds, y, seed);
else
fold = new KFold(nfolds, y.size(0), seed);
cout << "doing Fold: " << flush;
for (int nfold = 0; nfold < nfolds; nfold++) {
bayesnet::BaseClassifier* clf = classifiers[model];
auto clf = Models::createInstance(model);
setModelVersion(clf->getVersion());
train_timer.start();
auto [train, test] = fold->getFold(nfold);