Refactor experiment crossvalidation
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
#include "Datasets.h"
|
||||
#include "DotEnv.h"
|
||||
#include "Models.h"
|
||||
|
||||
#include "modelRegister.h"
|
||||
|
||||
using namespace std;
|
||||
const string PATH_RESULTS = "results";
|
||||
@@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
}
|
||||
void registerModels()
|
||||
{
|
||||
static platform::Registrar registrarT("TAN",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
||||
static platform::Registrar registrarS("SPODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||
static platform::Registrar registrarK("KDB",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
||||
static platform::Registrar registrarA("AODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
registerModels();
|
||||
auto program = manageArguments(argc, argv);
|
||||
bool saveResults = false;
|
||||
auto file_name = program.get<string>("dataset");
|
||||
@@ -128,15 +120,8 @@ int main(int argc, char** argv)
|
||||
experiment.addRandomSeed(seed);
|
||||
}
|
||||
platform::Timer timer;
|
||||
cout << "*** Starting experiment: " << title << " ***" << endl;
|
||||
timer.start();
|
||||
for (auto fileName : filesToProcess) {
|
||||
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
||||
auto result = experiment.cross_validation(path, fileName);
|
||||
result.setDataset(fileName);
|
||||
experiment.addResult(result);
|
||||
cout << endl;
|
||||
}
|
||||
experiment.go(filesToProcess, path);
|
||||
experiment.setDuration(timer.getDuration());
|
||||
if (saveResults)
|
||||
experiment.save(PATH_RESULTS);
|
||||
|
Reference in New Issue
Block a user