Refactor experiment crossvalidation

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 19:00:39 +02:00
parent cb54f61a69
commit 7222119dfb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 37 additions and 31 deletions

View File

@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc) add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -4,16 +4,12 @@
#include <thread> #include <thread>
#include <map> #include <map>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include "BaseClassifier.h"
#include "ArffFiles.h" #include "ArffFiles.h"
#include "Network.h"
#include "BayesMetrics.h" #include "BayesMetrics.h"
#include "CPPFImdlp.h" #include "CPPFImdlp.h"
#include "KDB.h"
#include "SPODE.h"
#include "AODE.h"
#include "TAN.h"
#include "Folding.h" #include "Folding.h"
#include "Models.h"
#include "modelRegister.h"
using namespace std; using namespace std;
@ -73,9 +69,8 @@ int main(int argc, char** argv)
{"mfeat-factors", true}, {"mfeat-factors", true},
}; };
auto valid_datasets = vector<string>(); auto valid_datasets = vector<string>();
for (auto dataset : datasets) { transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets),
valid_datasets.push_back(dataset.first); [](const pair<string, bool>& pair) { return pair.first; });
}
argparse::ArgumentParser program("BayesNetSample"); argparse::ArgumentParser program("BayesNetSample");
program.add_argument("-d", "--dataset") program.add_argument("-d", "--dataset")
.help("Dataset file name") .help("Dataset file name")
@ -91,13 +86,13 @@ int main(int argc, char** argv)
.default_value(string{ PATH } .default_value(string{ PATH }
); );
program.add_argument("-m", "--model") program.add_argument("-m", "--model")
.help("Model to use {AODE, KDB, SPODE, TAN}") .help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) { .action([](const std::string& value) {
static const vector<string> choices = { "AODE", "KDB", "SPODE", "TAN" }; static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) { if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value; return value;
} }
throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}"); throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
} }
); );
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true); program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);

View File

@ -85,11 +85,25 @@ namespace platform {
file << data; file << data;
file.close(); file.close();
} }
void Experiment::show() void Experiment::show()
{ {
json data = build_json(); json data = build_json();
cout << data.dump(4) << endl; cout << data.dump(4) << endl;
} }
/**
 * Run the experiment over a list of datasets.
 *
 * Prints a banner containing the experiment title, then for each entry of
 * filesToProcess runs cross_validation(path, fileName), tags the returned
 * result with the dataset name and hands it to addResult().
 *
 * @param filesToProcess names of the datasets to evaluate, processed in order
 * @param path           location passed through unchanged to cross_validation()
 */
void Experiment::go(vector<string> filesToProcess, const string& path)
{
    cout << "*** Starting experiment: " << title << " ***" << endl;
    // Bind by const reference: each name is only read, so copying the string
    // on every iteration is unnecessary.
    for (const auto& fileName : filesToProcess) {
        // Flush so the dataset name is visible before cross_validation() starts.
        cout << "- " << setw(20) << left << fileName << " " << right << flush;
        auto result = cross_validation(path, fileName);
        result.setDataset(fileName);
        addResult(result);
        cout << endl;
    }
}
Result Experiment::cross_validation(const string& path, const string& fileName) Result Experiment::cross_validation(const string& path, const string& fileName)
{ {
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, true, platform::ARFF);

View File

@ -106,6 +106,7 @@ namespace platform {
string get_file_name(); string get_file_name();
void save(string path); void save(string path);
Result cross_validation(const string& path, const string& fileName); Result cross_validation(const string& path, const string& fileName);
void go(vector<string> filesToProcess, const string& path);
void show(); void show();
}; };
} }

View File

@ -5,7 +5,7 @@
#include "Datasets.h" #include "Datasets.h"
#include "DotEnv.h" #include "DotEnv.h"
#include "Models.h" #include "Models.h"
#include "modelRegister.h"
using namespace std; using namespace std;
const string PATH_RESULTS = "results"; const string PATH_RESULTS = "results";
@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
} }
void registerModels() void registerModels()
{ {
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
} }
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
registerModels();
auto program = manageArguments(argc, argv); auto program = manageArguments(argc, argv);
bool saveResults = false; bool saveResults = false;
auto file_name = program.get<string>("dataset"); auto file_name = program.get<string>("dataset");
@ -128,15 +120,8 @@ int main(int argc, char** argv)
experiment.addRandomSeed(seed); experiment.addRandomSeed(seed);
} }
platform::Timer timer; platform::Timer timer;
cout << "*** Starting experiment: " << title << " ***" << endl;
timer.start(); timer.start();
for (auto fileName : filesToProcess) { experiment.go(filesToProcess, path);
cout << "- " << setw(20) << left << fileName << " " << right << flush;
auto result = experiment.cross_validation(path, fileName);
result.setDataset(fileName);
experiment.addResult(result);
cout << endl;
}
experiment.setDuration(timer.getDuration()); experiment.setDuration(timer.getDuration());
if (saveResults) if (saveResults)
experiment.save(PATH_RESULTS); experiment.save(PATH_RESULTS);

View File

@ -0,0 +1,11 @@
#ifndef MODEL_REGISTER_H
#define MODEL_REGISTER_H
// Declares one static platform::Registrar per classifier, each built from a
// model name and a factory lambda that returns a newly allocated classifier.
// NOTE(review): presumably the Registrar constructor records the factory in
// platform::Models so the model can be looked up by name -- confirm in Models.h.
//
// This header includes nothing itself: platform::Registrar and the bayesnet
// classifier types must already be declared by the including file (the
// includers place it right after "Models.h").
// The objects are `static`, so every translation unit that includes this
// header gets its own set of registrars.

// "TAN" -> default-constructed bayesnet::TAN.
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
// "SPODE" -> bayesnet::SPODE constructed with 2.
// NOTE(review): meaning of the argument 2 is not visible here -- confirm against SPODE's constructor.
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
// "KDB" -> bayesnet::KDB constructed with 2.
// NOTE(review): meaning of the argument 2 is not visible here -- confirm against KDB's constructor.
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
// "AODE" -> default-constructed bayesnet::AODE.
static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
#endif