Refactor experiment crossvalidation

This commit is contained in:
2023-07-29 19:00:39 +02:00
parent cb54f61a69
commit 7222119dfb
6 changed files with 37 additions and 31 deletions

View File

@@ -5,7 +5,7 @@
#include "Datasets.h"
#include "DotEnv.h"
#include "Models.h"
#include "modelRegister.h"
using namespace std;
const string PATH_RESULTS = "results";
@@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}
void registerModels()
{
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
}
int main(int argc, char** argv)
{
registerModels();
auto program = manageArguments(argc, argv);
bool saveResults = false;
auto file_name = program.get<string>("dataset");
@@ -128,15 +120,8 @@ int main(int argc, char** argv)
experiment.addRandomSeed(seed);
}
platform::Timer timer;
cout << "*** Starting experiment: " << title << " ***" << endl;
timer.start();
for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush;
auto result = experiment.cross_validation(path, fileName);
result.setDataset(fileName);
experiment.addResult(result);
cout << endl;
}
experiment.go(filesToProcess, path);
experiment.setDuration(timer.getDuration());
if (saveResults)
experiment.save(PATH_RESULTS);