From 6d4117d18858a0f6e26d22d678efc1501ea4bb10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 10 May 2024 14:11:51 +0200 Subject: [PATCH] Add Classification report to end of experiment if only one dataset is tested --- src/CMakeLists.txt | 1 + src/commands/b_main.cpp | 6 ++++-- src/main/Experiment.cpp | 5 ++++- src/main/Experiment.h | 2 +- src/reports/ReportConsole.cpp | 13 +++++++++++++ src/reports/ReportConsole.h | 1 + 6 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7d23ba3..90244b7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -64,5 +64,6 @@ add_executable( common/Datasets.cpp common/Dataset.cpp reports/ReportConsole.cpp reports/ReportExcel.cpp reports/ReportExcelCompared.cpp reports/ReportBase.cpp reports/ExcelFile.cpp reports/DatasetsConsole.cpp reports/ResultsDatasetConsole.cpp reports/ReportsPaged.cpp results/Result.cpp results/ResultsDataset.cpp + main/Scores.cpp ) target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp "${BayesNet}") diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 1aaa792..c8fd606 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -189,8 +189,10 @@ int main(int argc, char** argv) if (saveResults) { experiment.saveResult(); } - if (!quiet) - experiment.report(); + if (!quiet) { + // Classification report if only one dataset is tested + experiment.report(filesToTest.size() == 1); + } std::cout << "Done!" << std::endl; return 0; } diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index aea2dce..1f67592 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -11,10 +11,13 @@ namespace platform { { result.save(); } - void Experiment::report() + void Experiment::report(bool classification_report) { ReportConsole report(result.getJson()); report.show(); + if (classification_report) { + report.showClassificationReport(); + } } void Experiment::show() { diff --git a/src/main/Experiment.h b/src/main/Experiment.h index a24ef5d..1863aa5 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -32,7 +32,7 @@ namespace platform { void go(std::vector filesToProcess, bool quiet, bool no_train_score); void saveResult(); void show(); - void report(); + void report(bool classification_report = false); private: Result result; bool discretized{ false }, stratified{ false }; diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index 21328b2..c3d5c82 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -3,6 +3,7 @@ #include "best/BestScore.h" #include "common/CLocale.h" #include "ReportConsole.h" +#include "main/Scores.h" namespace platform { std::string ReportConsole::headerLine(const std::string& text, int utf = 0) @@ -164,4 +165,16 @@ namespace platform { std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!"); } } + void ReportConsole::showClassificationReport() + { + if (data["results"].size() > 1) + return; + auto item = data["results"][0]; + auto scores = Scores(item["confusion_matrices"][0]); + for (int i = 1; i < item["confusion_matrices"].size(); i++) { + auto score = Scores(item["confusion_matrices"][i]); + scores.aggregate(score); + } + std::cout << Colors::BLUE() << scores.classification_report() << Colors::RESET(); + } } \ No newline at end of file diff --git a/src/reports/ReportConsole.h b/src/reports/ReportConsole.h index 01859a1..c6ba6d6 100644 --- a/src/reports/ReportConsole.h +++ b/src/reports/ReportConsole.h @@ -14,6 +14,7 @@ namespace platform { std::string fileReport(); std::string getHeader() { do_header(); do_body(); return sheader.str(); } std::vector& getBody() { return vbody; } + void showClassificationReport(); private: int selectedIndex; std::string headerLine(const std::string& text, int utf);