Complete Experiment

This commit is contained in:
2023-07-27 15:49:58 +02:00
parent bc214a496c
commit 3d8fea7a37
6 changed files with 80 additions and 67 deletions

View File

@@ -6,12 +6,12 @@
using namespace std;
const string PATH_RESULTS = "results";
argparse::ArgumentParser manageArguments(int argc, char** argv)
{
argparse::ArgumentParser program("BayesNetSample");
program.add_argument("-d", "--dataset")
.help("Dataset file name");
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
program.add_argument("-p", "--path")
.help("folder where the data files are located, default")
.default_value(string{ PATH }
@@ -59,9 +59,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
complete_file_name = path + file_name + ".arff";
class_last = false;//datasets[file_name];
title = program.get<string>("title");
if (!file_exists(complete_file_name)) {
throw runtime_error("Data File " + path + file_name + ".arff" + " does not exist");
}
}
catch (const exception& err) {
cerr << err.what() << endl;
@@ -98,26 +95,29 @@ int main(int argc, char** argv)
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform("BayesNet");
experiment.setStratified(stratified).setNFolds(n_folds).addRandomSeed(seed).setScoreName("accuracy");
platform::Timer timer;
cout << "*** Starting experiment: " << title << " ***" << endl;
timer.start();
for (auto fileName : filesToProcess) {
cout << "Processing " << fileName << endl;
cout << "- " << fileName << " ";
auto [X, y] = datasets.getTensors(fileName);
// auto states = datasets.getStates(fileName);
// auto features = datasets.getFeatures(fileName);
// auto className = datasets.getDataset(fileName).getClassName();
// Fold* fold;
// if (stratified)
// fold = new StratifiedKFold(n_folds, y, seed);
// else
// fold = new KFold(n_folds, y.numel(), seed);
// auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
// result.setDataset(file_name);
// experiment.setModelVersion(result.getModelVersion());
// experiment.addResult(result);
// delete fold;
auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
cout << " (" << samples << ", " << features.size() << ") " << flush;
Fold* fold;
if (stratified)
fold = new StratifiedKFold(n_folds, y, seed);
else
fold = new KFold(n_folds, samples, seed);
auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
result.setDataset(file_name);
experiment.setModelVersion(result.getModelVersion());
experiment.addResult(result);
delete fold;
}
experiment.setDuration(timer.getDuration());
experiment.save(path);
experiment.show();
experiment.save(PATH_RESULTS);
cout << "Done!" << endl;
return 0;
}