Refactor experiment crossvalidation
This commit is contained in:
parent
cb54f61a69
commit
7222119dfb
@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc)
|
||||
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
|
||||
target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
@ -4,16 +4,12 @@
|
||||
#include <thread>
|
||||
#include <map>
|
||||
#include <argparse/argparse.hpp>
|
||||
#include "BaseClassifier.h"
|
||||
#include "ArffFiles.h"
|
||||
#include "Network.h"
|
||||
#include "BayesMetrics.h"
|
||||
#include "CPPFImdlp.h"
|
||||
#include "KDB.h"
|
||||
#include "SPODE.h"
|
||||
#include "AODE.h"
|
||||
#include "TAN.h"
|
||||
#include "Folding.h"
|
||||
#include "Models.h"
|
||||
#include "modelRegister.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@ -73,9 +69,8 @@ int main(int argc, char** argv)
|
||||
{"mfeat-factors", true},
|
||||
};
|
||||
auto valid_datasets = vector<string>();
|
||||
for (auto dataset : datasets) {
|
||||
valid_datasets.push_back(dataset.first);
|
||||
}
|
||||
transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets),
|
||||
[](const pair<string, bool>& pair) { return pair.first; });
|
||||
argparse::ArgumentParser program("BayesNetSample");
|
||||
program.add_argument("-d", "--dataset")
|
||||
.help("Dataset file name")
|
||||
@ -91,13 +86,13 @@ int main(int argc, char** argv)
|
||||
.default_value(string{ PATH }
|
||||
);
|
||||
program.add_argument("-m", "--model")
|
||||
.help("Model to use {AODE, KDB, SPODE, TAN}")
|
||||
.help("Model to use " + platform::Models::instance()->toString())
|
||||
.action([](const std::string& value) {
|
||||
static const vector<string> choices = { "AODE", "KDB", "SPODE", "TAN" };
|
||||
static const vector<string> choices = platform::Models::instance()->getNames();
|
||||
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||
return value;
|
||||
}
|
||||
throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}");
|
||||
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
|
||||
}
|
||||
);
|
||||
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);
|
||||
|
@ -85,11 +85,25 @@ namespace platform {
|
||||
file << data;
|
||||
file.close();
|
||||
}
|
||||
|
||||
void Experiment::show()
|
||||
{
|
||||
json data = build_json();
|
||||
cout << data.dump(4) << endl;
|
||||
}
|
||||
|
||||
void Experiment::go(vector<string> filesToProcess, const string& path)
|
||||
{
|
||||
cout << "*** Starting experiment: " << title << " ***" << endl;
|
||||
for (auto fileName : filesToProcess) {
|
||||
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
||||
auto result = cross_validation(path, fileName);
|
||||
result.setDataset(fileName);
|
||||
addResult(result);
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
Result Experiment::cross_validation(const string& path, const string& fileName)
|
||||
{
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
|
@ -106,6 +106,7 @@ namespace platform {
|
||||
string get_file_name();
|
||||
void save(string path);
|
||||
Result cross_validation(const string& path, const string& fileName);
|
||||
void go(vector<string> filesToProcess, const string& path);
|
||||
void show();
|
||||
};
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include "Datasets.h"
|
||||
#include "DotEnv.h"
|
||||
#include "Models.h"
|
||||
|
||||
#include "modelRegister.h"
|
||||
|
||||
using namespace std;
|
||||
const string PATH_RESULTS = "results";
|
||||
@ -78,19 +78,11 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
}
|
||||
void registerModels()
|
||||
{
|
||||
static platform::Registrar registrarT("TAN",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
||||
static platform::Registrar registrarS("SPODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||
static platform::Registrar registrarK("KDB",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
||||
static platform::Registrar registrarA("AODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
registerModels();
|
||||
auto program = manageArguments(argc, argv);
|
||||
bool saveResults = false;
|
||||
auto file_name = program.get<string>("dataset");
|
||||
@ -128,15 +120,8 @@ int main(int argc, char** argv)
|
||||
experiment.addRandomSeed(seed);
|
||||
}
|
||||
platform::Timer timer;
|
||||
cout << "*** Starting experiment: " << title << " ***" << endl;
|
||||
timer.start();
|
||||
for (auto fileName : filesToProcess) {
|
||||
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
||||
auto result = experiment.cross_validation(path, fileName);
|
||||
result.setDataset(fileName);
|
||||
experiment.addResult(result);
|
||||
cout << endl;
|
||||
}
|
||||
experiment.go(filesToProcess, path);
|
||||
experiment.setDuration(timer.getDuration());
|
||||
if (saveResults)
|
||||
experiment.save(PATH_RESULTS);
|
||||
|
11
src/Platform/modelRegister.h
Normal file
11
src/Platform/modelRegister.h
Normal file
@ -0,0 +1,11 @@
|
||||
#ifndef MODEL_REGISTER_H
|
||||
#define MODEL_REGISTER_H
|
||||
static platform::Registrar registrarT("TAN",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
||||
static platform::Registrar registrarS("SPODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||
static platform::Registrar registrarK("KDB",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
||||
static platform::Registrar registrarA("AODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user