From 4c847fc3f6e887fd81eb05275161f8605923af5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 9 Mar 2024 20:19:27 +0100 Subject: [PATCH] Add model selection to b_best to filter results --- src/best/BestResults.cc | 19 +++++++++++-------- src/best/BestResults.h | 5 +++-- src/best/b_best.cc | 10 ++++++---- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/best/BestResults.cc b/src/best/BestResults.cc index c4a624d..6ea29c9 100644 --- a/src/best/BestResults.cc +++ b/src/best/BestResults.cc @@ -42,6 +42,9 @@ namespace platform { for (auto const& item : data.at("results")) { bool update = true; auto datasetName = item.at("dataset").get(); + if (dataset != "any" && dataset != datasetName) { + continue; + } if (bests.contains(datasetName)) { if (item.at("score").get() < bests[datasetName].at(0).get()) { update = false; @@ -122,8 +125,8 @@ namespace platform { std::vector BestResults::getDatasets(json table) { std::vector datasets; - for (const auto& dataset : table.items()) { - datasets.push_back(dataset.key()); + for (const auto& dataset_ : table.items()) { + datasets.push_back(dataset_.key()); } maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); maxDatasetName = std::max(7, maxDatasetName); @@ -232,16 +235,16 @@ namespace platform { totals[model] = 0.0; } auto datasets = getDatasets(table.begin().value()); - for (auto const& dataset : datasets) { + for (auto const& dataset_ : datasets) { auto color = odd ? Colors::BLUE() : Colors::CYAN(); std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " "; - std::cout << std::setw(maxDatasetName) << std::left << dataset << " "; + std::cout << std::setw(maxDatasetName) << std::left << dataset_ << " "; double maxValue = 0; // Find out the max value for this dataset for (const auto& model : models) { double value; try { - value = table[model].at(dataset).at(0).get(); + value = table[model].at(dataset_).at(0).get(); } catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) { value = -1.0; @@ -255,7 +258,7 @@ namespace platform { std::string efectiveColor = color; double value; try { - value = table[model].at(dataset).at(0).get(); + value = table[model].at(dataset_).at(0).get(); } catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) { value = -1.0; @@ -331,9 +334,9 @@ namespace platform { double min = 2000; // Find out the control model auto totals = std::vector(models.size(), 0.0); - for (const auto& dataset : datasets) { + for (const auto& dataset_ : datasets) { for (int i = 0; i < models.size(); ++i) { - totals[i] += ranksModels[dataset][models[i]]; + totals[i] += ranksModels[dataset_][models[i]]; } } for (int i = 0; i < models.size(); ++i) { diff --git a/src/best/BestResults.h b/src/best/BestResults.h index 7d576b0..1937df7 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -6,8 +6,8 @@ using json = nlohmann::json; namespace platform { class BestResults { public: - explicit BestResults(const std::string& path, const std::string& score, const std::string& model, bool friedman, double significance = 0.05) - : path(path), score(score), model(model), friedman(friedman), significance(significance) + explicit BestResults(const std::string& path, const std::string& score, const std::string& model, const std::string& dataset, bool friedman, double significance = 0.05) + : path(path), score(score), model(model), dataset(dataset), friedman(friedman), significance(significance) { } std::string build(); @@ -27,6 +27,7 @@ namespace platform { std::string path; std::string score; std::string model; + std::string dataset; bool friedman; double significance; int maxModelName = 0; diff --git a/src/best/b_best.cc b/src/best/b_best.cc index 1e83969..9ca4370 100644 --- a/src/best/b_best.cc +++ b/src/best/b_best.cc @@ -8,6 +8,7 @@ void manageArguments(argparse::ArgumentParser& program) { program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)"); + program.add_argument("-d", "--dataset").default_value("any").help("Filter results of the selected model) (any for all datasets)"); program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true); program.add_argument("--excel").help("Output to excel").default_value(false).implicit_value(true); @@ -31,12 +32,13 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_best", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string model, score; + std::string model, dataset, score; bool build, report, friedman, excel; double level; try { program.parse_args(argc, argv); model = program.get("model"); + dataset = program.get("dataset"); score = program.get("score"); friedman = program.get("friedman"); excel = program.get("excel"); @@ -44,8 +46,8 @@ int main(int argc, char** argv) if (model == "" || score == "") { throw std::runtime_error("Model and score name must be supplied"); } - if (friedman && model != "any") { - std::cerr << "Friedman test can only be used with all models" << std::endl; + if (friedman && (model != "any" || dataset != "any")) { + std::cerr << "Friedman test can only be used with all models and all the datasets" << std::endl; std::cerr << program; exit(1); } @@ -56,7 +58,7 @@ int main(int argc, char** argv) exit(1); } // Generate report - auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level); + auto results = platform::BestResults(platform::Paths::results(), score, model, dataset, friedman, level); if (model == "any") { results.buildAll(); results.reportAll(excel);