diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index 4f9d087..000a88b 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) -add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc) +add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc) target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/sample/sample.cc b/sample/sample.cc index f515405..502dcfc 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -4,16 +4,12 @@ #include #include #include -#include "BaseClassifier.h" #include "ArffFiles.h" -#include "Network.h" #include "BayesMetrics.h" #include "CPPFImdlp.h" -#include "KDB.h" -#include "SPODE.h" -#include "AODE.h" -#include "TAN.h" #include "Folding.h" +#include "Models.h" +#include "modelRegister.h" using namespace std; @@ -73,9 +69,8 @@ int main(int argc, char** argv) {"mfeat-factors", true}, }; auto valid_datasets = vector(); - for (auto dataset : datasets) { - valid_datasets.push_back(dataset.first); - } + transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets), + [](const pair& pair) { return pair.first; }); argparse::ArgumentParser program("BayesNetSample"); program.add_argument("-d", "--dataset") .help("Dataset file name") @@ -91,13 +86,13 @@ int main(int argc, char** argv) .default_value(string{ PATH } ); program.add_argument("-m", "--model") - .help("Model to use {AODE, KDB, SPODE, TAN}") + .help("Model to use " + platform::Models::instance()->toString()) .action([](const std::string& value) { - static const vector choices = { "AODE", "KDB", "SPODE", "TAN" }; + static const vector choices = platform::Models::instance()->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}"); + throw runtime_error("Model must be one of " + platform::Models::instance()->toString()); } ); program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true); diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index ab62ab2..71e8cf4 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -85,11 +85,25 @@ namespace platform { file << data; file.close(); } + void Experiment::show() { json data = build_json(); cout << data.dump(4) << endl; } + + void Experiment::go(vector filesToProcess, const string& path) + { + cout << "*** Starting experiment: " << title << " ***" << endl; + for (auto fileName : filesToProcess) { + cout << "- " << setw(20) << left << fileName << " " << right << flush; + auto result = cross_validation(path, fileName); + result.setDataset(fileName); + addResult(result); + cout << endl; + } + } + Result Experiment::cross_validation(const string& path, const string& fileName) { auto datasets = platform::Datasets(path, true, platform::ARFF); diff --git a/src/Platform/Experiment.h b/src/Platform/Experiment.h index 84b1627..951ac4a 100644 --- a/src/Platform/Experiment.h +++ b/src/Platform/Experiment.h @@ -106,6 +106,7 @@ namespace platform { string get_file_name(); void save(string path); Result cross_validation(const string& path, const string& fileName); + void go(vector filesToProcess, const string& path); void show(); }; } diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 3a4b238..29c8505 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -5,7 +5,7 @@ #include "Datasets.h" #include "DotEnv.h" #include "Models.h" - +#include "modelRegister.h" using namespace std; const string PATH_RESULTS = "results"; @@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) } void registerModels() { - static platform::Registrar registrarT("TAN", - [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();}); - static platform::Registrar registrarS("SPODE", - [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);}); - static platform::Registrar registrarK("KDB", - [](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);}); - static platform::Registrar registrarA("AODE", - [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();}); + } int main(int argc, char** argv) { - registerModels(); auto program = manageArguments(argc, argv); bool saveResults = false; auto file_name = program.get("dataset"); @@ -128,15 +120,8 @@ int main(int argc, char** argv) experiment.addRandomSeed(seed); } platform::Timer timer; - cout << "*** Starting experiment: " << title << " ***" << endl; timer.start(); - for (auto fileName : filesToProcess) { - cout << "- " << setw(20) << left << fileName << " " << right << flush; - auto result = experiment.cross_validation(path, fileName); - result.setDataset(fileName); - experiment.addResult(result); - cout << endl; - } + experiment.go(filesToProcess, path); experiment.setDuration(timer.getDuration()); if (saveResults) experiment.save(PATH_RESULTS); diff --git a/src/Platform/modelRegister.h b/src/Platform/modelRegister.h new file mode 100644 index 0000000..a4188bc --- /dev/null +++ b/src/Platform/modelRegister.h @@ -0,0 +1,11 @@ +#ifndef MODEL_REGISTER_H +#define MODEL_REGISTER_H +static platform::Registrar registrarT("TAN", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();}); +static platform::Registrar registrarS("SPODE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);}); +static platform::Registrar registrarK("KDB", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);}); +static platform::Registrar registrarA("AODE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();}); +#endif \ No newline at end of file