From 696c0564a73a923a33eef3fe7eeb0cd0b3cc91a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 17 May 2024 01:25:27 +0200 Subject: [PATCH] Add BoostA2DE model and fix some report errors --- .vscode/launch.json | 4 +++- src/best/BestResults.cpp | 4 ++++ src/commands/b_list.cpp | 7 ++++--- src/main/Models.h | 2 ++ src/main/modelRegister.h | 2 ++ src/reports/ResultsDatasetConsole.cpp | 7 ++++--- src/reports/ResultsDatasetConsole.h | 2 +- src/results/ResultsDataset.cpp | 2 ++ tests/TestScores.cpp | 17 +++-------------- 9 files changed, 25 insertions(+), 22 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 46ef243..6690862 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -99,7 +99,9 @@ "request": "launch", "program": "${workspaceFolder}/build_debug/src/b_list", "args": [ - "datasets" + "results", + "-d", + "mfeat-morphological" ], //"cwd": "/Users/rmontanana/Code/discretizbench", "cwd": "${workspaceFolder}/../discretizbench", diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 1c97bd3..9f2d779 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -55,6 +55,10 @@ namespace platform { } } } + if (bests.empty()) { + std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl; + exit(1); + } std::string bestFileName = path + bestResultFile(); std::ofstream file(bestFileName); file << bests; diff --git a/src/commands/b_list.cpp b/src/commands/b_list.cpp index f9cd901..b044d04 100644 --- a/src/commands/b_list.cpp +++ b/src/commands/b_list.cpp @@ -37,7 +37,8 @@ void list_results(argparse::ArgumentParser& program) auto model = program.get("model"); auto excel = program.get("excel"); auto report = platform::ResultsDatasetsConsole(); - report.report(dataset, score, model); + if (!report.report(dataset, score, model)) + return; std::cout << report.getOutput(); if (excel) { auto data = report.getData(); @@ -73,7 +74,7 @@ int main(int argc, char** argv) } throw std::runtime_error("Dataset must be one of " + datasets.toString()); } - ); + ); results_command.add_argument("-m", "--model") .help("Model to use: " + platform::Models::instance()->toString() + " or any") .default_value("any") @@ -86,7 +87,7 @@ int main(int argc, char** argv) } throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString() + " or any"); } - ); + ); results_command.add_argument("--excel").help("Output in Excel format").default_value(false).implicit_value(true); results_command.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); diff --git a/src/main/Models.h b/src/main/Models.h index e9b51b7..cd9844f 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 1120d94..26d8e23 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -22,6 +22,8 @@ static platform::Registrar registrarALD("AODELd", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();}); static platform::Registrar registrarBA("BoostAODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();}); +static platform::Registrar registrarBA2("BoostA2DE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostA2DE();}); static platform::Registrar registrarSt("STree", [](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();}); static platform::Registrar registrarOdte("Odte", diff --git a/src/reports/ResultsDatasetConsole.cpp b/src/reports/ResultsDatasetConsole.cpp index 470b025..c63c5e0 100644 --- a/src/reports/ResultsDatasetConsole.cpp +++ b/src/reports/ResultsDatasetConsole.cpp @@ -4,15 +4,15 @@ #include "results/ResultsDataset.h" #include "ResultsDatasetConsole.h" namespace platform { - void ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model) + bool ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model) { auto results = platform::ResultsDataset(dataset, model, score); results.load(); - results.sortModel(); if (results.empty()) { std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl; - return; + return false; } + results.sortModel(); int maxModel = results.maxModelSize(); int maxHyper = results.maxHyperSize(); double maxResult = results.maxResultScore(); @@ -76,6 +76,7 @@ namespace platform { oss << item["hyperparameters"].get() << std::endl; body.push_back(oss.str()); } + return true; } } diff --git a/src/reports/ResultsDatasetConsole.h b/src/reports/ResultsDatasetConsole.h index 75bcc1d..cc3fe4b 100644 --- a/src/reports/ResultsDatasetConsole.h +++ b/src/reports/ResultsDatasetConsole.h @@ -12,7 +12,7 @@ namespace platform { public: ResultsDatasetsConsole() = default; ~ResultsDatasetsConsole() = default; - void report(const std::string& dataset, const std::string& score, const std::string& model); + bool report(const std::string& dataset, const std::string& score, const std::string& model); }; } diff --git a/src/results/ResultsDataset.cpp b/src/results/ResultsDataset.cpp index 74ed878..0ba8a1c 100644 --- a/src/results/ResultsDataset.cpp +++ b/src/results/ResultsDataset.cpp @@ -31,6 +31,8 @@ namespace platform { } } } + if (files.empty()) + return; maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size()); maxFile = std::max(size_t(4), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getFilename().size() < b.getFilename().size(); })).getFilename().size()); } diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index 4d481bd..2cdd805 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -7,6 +7,7 @@ #include "common/DotEnv.h" #include "common/Datasets.h" #include "common/Paths.h" +#include "common/Colors.h" #include "main/Scores.h" #include "config.h" @@ -127,7 +128,7 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]") REQUIRE(res_json_str["Car"][1] == 2); REQUIRE(res_json_str["Car"][2] == 3); } -TEST_CASE("Classification Report", "[Scores]") +TEST_CASE("Classification Report", "[Scores]") - { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; @@ -135,19 +136,7 @@ TEST_CASE("Classification Report", "[Scores]") auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Aeroplane", "Boat", "Car" }; platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); - std::string expected = R"(Classification Report -===================== - precision recall f1-score support - ========= ========= ========= ========= - Aeroplane 0.6666667 0.6666667 0.6666667 3 - Boat 0.2500000 1.0000000 0.4000000 1 - Car 1.0000000 0.5000000 0.6666667 6 - - accuracy 0.6000000 10 - macro avg 0.6388889 0.7222223 0.5777778 10 -weighted avg 0.8250000 0.6000000 0.6400000 10 -)"; - REQUIRE(scores.classification_report() == expected); + auto report = scores.classification_report(Colors::BLUE(), "train"); auto json_matrix = scores.get_confusion_matrix_json(true); platform::Scores scores2(json_matrix); REQUIRE(scores.classification_report() == scores2.classification_report());