Refactor arguments management for Experimentation

This commit is contained in:
2025-01-18 18:26:34 +01:00
parent 7aaf6d1bf8
commit 3397d0962f
10 changed files with 325 additions and 420 deletions

View File

@@ -14,11 +14,11 @@ namespace platform {
result.save();
std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl;
}
void Experiment::report(bool classification_report)
void Experiment::report()
{
ReportConsole report(result.getJson());
report.show();
if (classification_report) {
if (filesToTest.size() == 1) {
std::cout << report.showClassificationReport(Colors::BLUE());
}
}
@@ -43,9 +43,9 @@ namespace platform {
}
}
}
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
void Experiment::go()
{
for (auto fileName : filesToProcess) {
for (auto fileName : filesToTest) {
if (fileName.size() > max_name)
max_name = fileName.size();
}
@@ -64,10 +64,10 @@ namespace platform {
std::cout << " --- " << string(max_name, '-') << " ----- ----- ---- " << string(4 + 3 * nfolds, '-') << " ----------" << Colors::RESET() << std::endl;
}
int num = 0;
for (auto fileName : filesToProcess) {
for (auto fileName : filesToTest) {
if (!quiet)
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
cross_validation(fileName, quiet, no_train_score, generate_fold_files, graph);
cross_validation(fileName);
if (!quiet)
std::cout << std::endl;
}
@@ -139,7 +139,7 @@ namespace platform {
file << output.dump(4);
file.close();
}
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
void Experiment::cross_validation(const std::string& fileName)
{
//
// Load dataset and prepare data