From 889668bf006e3035b16a0a1f2cfd29eaab1ad01d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 29 Jan 2024 18:50:53 +0100 Subject: [PATCH] Enhance b_main experiment output --- src/Platform/modules/BestResults.cc | 9 +++----- src/Platform/modules/Experiment.cc | 35 ++++++++++++++++++++++------- src/Platform/modules/Experiment.h | 1 + 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/Platform/modules/BestResults.cc b/src/Platform/modules/BestResults.cc index d6a3991..c361955 100644 --- a/src/Platform/modules/BestResults.cc +++ b/src/Platform/modules/BestResults.cc @@ -39,15 +39,12 @@ namespace platform { auto result = Result(path, file); auto data = result.load(); for (auto const& item : data.at("results")) { - bool update = false; - // Check if results file contains only one dataset + bool update = true; auto datasetName = item.at("dataset").get(); if (bests.contains(datasetName)) { - if (item.at("score").get() > bests[datasetName].at(0).get()) { - update = true; + if (item.at("score").get() < bests[datasetName].at(0).get()) { + update = false; } - } else { - update = true; } if (update) { bests[datasetName] = { item.at("score").get(), item.at("hyperparameters"), file }; diff --git a/src/Platform/modules/Experiment.cc b/src/Platform/modules/Experiment.cc index 86b138f..a48a290 100644 --- a/src/Platform/modules/Experiment.cc +++ b/src/Platform/modules/Experiment.cc @@ -103,12 +103,31 @@ namespace platform { void Experiment::go(std::vector filesToProcess, bool quiet) { - std::cout << "*** Starting experiment: " << title << " ***" << std::endl; for (auto fileName : filesToProcess) { - std::cout << "- " << setw(20) << left << fileName << " " << right << flush; - cross_validation(fileName, quiet); - std::cout << std::endl; + if (fileName.size() > max_name) + max_name = fileName.size(); } + std::cout << Colors::MAGENTA() << "*** Starting experiment: " << title << " ***" << Colors::RESET() << std::endl << std::endl; + if (!quiet) { + std::cout << Colors::GREEN() << " Status Meaning" << std::endl; + std::cout << " ------ -----------------------------" << Colors::RESET() << std::endl; + std::cout << " ( " << Colors::GREEN() << "a" << Colors::RESET() << " ) Fitting model with train dataset" << std::endl; + std::cout << " ( " << Colors::GREEN() << "b" << Colors::RESET() << " ) Scoring train dataset" << std::endl; + std::cout << " ( " << Colors::GREEN() << "c" << Colors::RESET() << " ) Scoring test dataset" << std::endl << std::endl; + std::cout << Colors::YELLOW() << "Note: fold number in this color means fitting had issues such as not using all features in BoostAODE classifier" << std::endl << std::endl; + std::cout << Colors::GREEN() << left << " # " << setw(max_name) << "Dataset" << " #Samp #Feat Seed Status" << std::endl; + std::cout << " --- " << string(max_name, '-') << " ----- ----- ---- " << string(4 + 3 * nfolds, '-') << Colors::RESET() << std::endl; + } + int num = 0; + for (auto fileName : filesToProcess) { + if (!quiet) + std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush; + cross_validation(fileName, quiet); + if (!quiet) + std::cout << std::endl; + } + if (!quiet) + std::cout << std::endl; } std::string getColor(bayesnet::status_t status) @@ -141,7 +160,7 @@ namespace platform { auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); if (!quiet) { - std::cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; + std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } // Prepare Result auto result = Result(); @@ -162,11 +181,11 @@ namespace platform { bool first_seed = true; for (auto seed : randomSeeds) { if (!quiet) { - string prefix = ""; + string prefix = " "; if (!first_seed) { - prefix = "\n" + string(36, ' '); + prefix = "\n" + string(18 + max_name, ' '); } - std::cout << prefix << "(" << setw(4) << seed << ") doing Fold: " << flush; + std::cout << prefix << setw(4) << right << seed << " " << flush; first_seed = false; } folding::Fold* fold; diff --git a/src/Platform/modules/Experiment.h b/src/Platform/modules/Experiment.h index 440f711..c0a1570 100644 --- a/src/Platform/modules/Experiment.h +++ b/src/Platform/modules/Experiment.h @@ -96,6 +96,7 @@ namespace platform { std::vector randomSeeds; HyperParameters hyperparameters; int nfolds{ 0 }; + int max_name{ 7 }; // max length of dataset name for formatting (default 7) float duration{ 0 }; json build_json(); };