Refactor experiment crossvalidation

This commit is contained in:
2023-07-29 19:00:39 +02:00
parent cb54f61a69
commit 7222119dfb
6 changed files with 37 additions and 31 deletions

View File

@@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc)
add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@@ -4,16 +4,12 @@
#include <thread>
#include <map>
#include <argparse/argparse.hpp>
#include "BaseClassifier.h"
#include "ArffFiles.h"
#include "Network.h"
#include "BayesMetrics.h"
#include "CPPFImdlp.h"
#include "KDB.h"
#include "SPODE.h"
#include "AODE.h"
#include "TAN.h"
#include "Folding.h"
#include "Models.h"
#include "modelRegister.h"
using namespace std;
@@ -73,9 +69,8 @@ int main(int argc, char** argv)
{"mfeat-factors", true},
};
auto valid_datasets = vector<string>();
for (auto dataset : datasets) {
valid_datasets.push_back(dataset.first);
}
transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets),
[](const pair<string, bool>& pair) { return pair.first; });
argparse::ArgumentParser program("BayesNetSample");
program.add_argument("-d", "--dataset")
.help("Dataset file name")
@@ -91,13 +86,13 @@ int main(int argc, char** argv)
.default_value(string{ PATH }
);
program.add_argument("-m", "--model")
.help("Model to use {AODE, KDB, SPODE, TAN}")
.help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const vector<string> choices = { "AODE", "KDB", "SPODE", "TAN" };
static const vector<string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}");
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);