Refactor experiment crossvalidation

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 19:00:39 +02:00
parent cb54f61a69
commit 7222119dfb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 37 additions and 31 deletions

View File

@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc)
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -4,16 +4,12 @@
#include <thread>
#include <map>
#include <argparse/argparse.hpp>
#include "BaseClassifier.h"
#include "ArffFiles.h"
#include "Network.h"
#include "BayesMetrics.h"
#include "CPPFImdlp.h"
#include "KDB.h"
#include "SPODE.h"
#include "AODE.h"
#include "TAN.h"
#include "Folding.h"
#include "Models.h"
#include "modelRegister.h"
using namespace std;
@ -73,9 +69,8 @@ int main(int argc, char** argv)
{"mfeat-factors", true},
};
auto valid_datasets = vector<string>();
for (auto dataset : datasets) {
valid_datasets.push_back(dataset.first);
}
transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets),
[](const pair<string, bool>& pair) { return pair.first; });
argparse::ArgumentParser program("BayesNetSample");
program.add_argument("-d", "--dataset")
.help("Dataset file name")
@ -91,13 +86,13 @@ int main(int argc, char** argv)
.default_value(string{ PATH }
);
program.add_argument("-m", "--model")
.help("Model to use {AODE, KDB, SPODE, TAN}")
.help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const vector<string> choices = { "AODE", "KDB", "SPODE", "TAN" };
static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}");
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);

View File

@ -85,11 +85,25 @@ namespace platform {
file << data;
file.close();
}
void Experiment::show()
{
json data = build_json();
cout << data.dump(4) << endl;
}
/**
 * Run the experiment over a list of datasets.
 *
 * Prints a start banner containing `title`, then, for every entry of
 * filesToProcess in order: prints the name left-aligned in a 20-column field,
 * calls cross_validation(path, name), stores the name in the returned result
 * through setDataset() and hands the result to addResult().
 *
 * @param filesToProcess names of the datasets to process, in processing order
 * @param path           forwarded unchanged to cross_validation() for each dataset
 */
void Experiment::go(vector<string> filesToProcess, const string& path)
{
    cout << "*** Starting experiment: " << title << " ***" << endl;
    // Bind by const reference: the name is only read here, so copying a
    // std::string on every iteration is unnecessary.
    for (const auto& fileName : filesToProcess) {
        // Flushed without a newline so the dataset name is already on screen
        // before cross_validation() is entered; the line is closed afterwards.
        cout << "- " << setw(20) << left << fileName << " " << right << flush;
        auto result = cross_validation(path, fileName);
        result.setDataset(fileName);
        addResult(result);
        cout << endl;
    }
}
Result Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, true, platform::ARFF);

View File

@ -106,6 +106,7 @@ namespace platform {
string get_file_name();
void save(string path);
Result cross_validation(const string& path, const string& fileName);
// Runs the experiment over every name in filesToProcess: for each one it calls
// cross_validation(path, name), tags the returned Result with the name and adds
// it to this experiment, printing one progress line per dataset to stdout.
void go(vector<string> filesToProcess, const string& path);
// Prints the experiment to stdout as the text produced by build_json().dump(4).
void show();
};
}

View File

@ -5,7 +5,7 @@
#include "Datasets.h"
#include "DotEnv.h"
#include "Models.h"
#include "modelRegister.h"
using namespace std;
const string PATH_RESULTS = "results";
@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}
// Creates one platform::Registrar per classifier name ("TAN", "SPODE", "KDB",
// "AODE"), each paired with a factory lambda that heap-allocates the matching
// bayesnet classifier and returns it as a bayesnet::BaseClassifier*.
//
// The Registrar objects are function-local statics, so they are constructed
// only on the first call; later calls construct nothing.
// NOTE(review): presumably the Registrar constructor records the factory in
// platform::Models — its definition is not visible here, confirm.
// NOTE(review): the factories return raw pointers obtained with `new`; who
// releases them is not visible here.
void registerModels()
{
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
// NOTE(review): the meaning of the constructor argument 2 is not visible here.
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
// NOTE(review): the meaning of the constructor argument 2 is not visible here.
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
}
int main(int argc, char** argv)
{
registerModels();
auto program = manageArguments(argc, argv);
bool saveResults = false;
auto file_name = program.get<string>("dataset");
@ -128,15 +120,8 @@ int main(int argc, char** argv)
experiment.addRandomSeed(seed);
}
platform::Timer timer;
cout << "*** Starting experiment: " << title << " ***" << endl;
timer.start();
for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush;
auto result = experiment.cross_validation(path, fileName);
result.setDataset(fileName);
experiment.addResult(result);
cout << endl;
}
experiment.go(filesToProcess, path);
experiment.setDuration(timer.getDuration());
if (saveResults)
experiment.save(PATH_RESULTS);

View File

@ -0,0 +1,11 @@
#ifndef MODEL_REGISTER_H
#define MODEL_REGISTER_H
// Defines one namespace-scope static platform::Registrar per classifier name
// ("TAN", "SPODE", "KDB", "AODE"). Each is given a factory lambda that
// heap-allocates the matching bayesnet classifier and returns it as a
// bayesnet::BaseClassifier*.
//
// NOTE(review): this header includes nothing itself, so platform::Registrar and
// the bayesnet classifier types must already be declared by the including file
// before this header is reached (e.g. by including "Models.h" first) — confirm.
// NOTE(review): the objects have internal linkage (`static`), so every
// translation unit that includes this header gets its own set of Registrar
// objects; presumably it is meant to be included from a single .cc file per
// executable — confirm.
// NOTE(review): the factories return raw pointers obtained with `new`; who
// releases them is not visible here.
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
// NOTE(review): the meaning of the constructor argument 2 is not visible here.
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
// NOTE(review): the meaning of the constructor argument 2 is not visible here.
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
#endif