From 68f22a673dc25fc53c33dff5dd541a61cece14b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 20 Sep 2023 11:40:01 +0200 Subject: [PATCH] Add comparison to report console --- src/Platform/ReportBase.cc | 38 +++++++++++++++++++++++++++++++++++ src/Platform/ReportBase.h | 19 +++++++++++++++++- src/Platform/ReportConsole.cc | 28 +++++++++++++++++++------- src/Platform/ReportConsole.h | 5 +++-- src/Platform/ReportExcel.cc | 38 +---------------------------------- src/Platform/ReportExcel.h | 16 +-------------- 6 files changed, 82 insertions(+), 62 deletions(-) diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index 24125f8..2289640 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -1,10 +1,22 @@ #include #include +#include "Datasets.h" #include "ReportBase.h" #include "BestResult.h" namespace platform { + ReportBase::ReportBase(json data_) : margin(0.1), data(data_) + { + stringstream oss; + oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; + meaning = { + {Symbols::equal_best, "Equal to best"}, + {Symbols::better_best, "Better than best"}, + {Symbols::cross, "Less than or equal to ZeroR"}, + {Symbols::upward_arrow, oss.str()} + }; + } string ReportBase::fromVector(const string& key) { stringstream oss; @@ -34,4 +46,30 @@ namespace platform { header(); body(); } + string ReportBase::compareResult(const string& dataset, double result) + { + string status = " "; + if (data["score_name"].get() == "accuracy") { + auto dt = Datasets(Paths::datasets(), false); + dt.loadDataset(dataset); + auto numClasses = dt.getNClasses(dataset); + if (numClasses == 2) { + vector distribution = dt.getClassesCounts(dataset); + double nSamples = dt.getNSamples(dataset); + vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); + double mark = *maxValue / nSamples * (1 + margin); + if (mark > 1) { + mark = 0.9995; + } + status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; + auto item = summary.find(status); + if (item != summary.end()) { + summary[status]++; + } else { + summary[status] = 1; + } + } + } + return status; + } } \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h index 2acbbc7..f587c2e 100644 --- a/src/Platform/ReportBase.h +++ b/src/Platform/ReportBase.h @@ -2,14 +2,26 @@ #define REPORTBASE_H #include #include +#include "Paths.h" #include using json = nlohmann::json; namespace platform { using namespace std; + class Symbols { + public: + inline static const string check_mark{ "\u2714" }; + inline static const string exclamation{ "\u2757" }; + inline static const string black_star{ "\u2605" }; + inline static const string cross{ "\u2717" }; + inline static const string upward_arrow{ "\u27B6" }; + inline static const string down_arrow{ "\u27B4" }; + inline static const string equal_best{ check_mark }; + inline static const string better_best{ black_star }; + }; class ReportBase { public: - explicit ReportBase(json data_) { data = data_; }; + explicit ReportBase(json data_); virtual ~ReportBase() = default; void show(); protected: @@ -18,6 +30,11 @@ namespace platform { string fVector(const string& title, const json& data, const int width, const int precision); virtual void header() = 0; virtual void body() = 0; + virtual void showSummary() = 0; + string compareResult(const string& dataset, double result); + map summary; + double margin; + map meaning; }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index acbb602..0de1c11 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -11,11 +11,11 @@ namespace platform { string do_grouping() const { return "\03"; } }; - string ReportConsole::headerLine(const string& text) + string ReportConsole::headerLine(const string& text, int utf = 0) { int n = MAXL - text.length() - 3; n = n < 0 ? 0 : n; - return "* " + text + string(n, ' ') + "*\n"; + return "* " + text + string(n + utf, ' ') + "*\n"; } void ReportConsole::header() @@ -36,8 +36,8 @@ namespace platform { } void ReportConsole::body() { - cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; - cout << "=== ============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; + cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; + cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl; json lastResult; double totalScore = 0.0; bool odd = true; @@ -50,15 +50,17 @@ namespace platform { auto color = odd ? Colors::CYAN() : Colors::BLUE(); cout << color; cout << setw(3) << index++ << " "; - cout << setw(30) << left << r["dataset"].get() << " "; + cout << setw(25) << left << r["dataset"].get() << " "; cout << setw(6) << right << r["samples"].get() << " "; cout << setw(5) << right << r["features"].get() << " "; cout << setw(3) << right << r["classes"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["nodes"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["leaves"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["depth"].get() << " "; - cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; - cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get(); + const string status = compareResult(r["dataset"].get(), r["score"].get()); + cout << status; + cout << setw(12) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; try { cout << r["hyperparameters"].get(); } @@ -81,9 +83,21 @@ namespace platform { footer(totalScore); } } + void ReportConsole::showSummary() + { + for (const auto& item : summary) { + stringstream oss; + oss << setw(3) << left << item.first; + oss << setw(3) << right << item.second << " "; + oss << left << meaning.at(item.first); + cout << headerLine(oss.str(), 2); + } + } + void ReportConsole::footer(double totalScore) { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; + showSummary(); auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { stringstream oss; diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index b34e71f..7b3906c 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -7,17 +7,18 @@ namespace platform { using namespace std; - const int MAXL = 132; + const int MAXL = 133; class ReportConsole : public ReportBase { public: explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {}; virtual ~ReportConsole() = default; private: int selectedIndex; - string headerLine(const string& text); + string headerLine(const string& text, int utf); void header() override; void body() override; void footer(double totalScore); + void showSummary(); }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index bfda442..c7d6ded 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -1,6 +1,5 @@ #include #include -#include "Datasets.h" #include "ReportExcel.h" #include "BestResult.h" @@ -20,7 +19,6 @@ namespace platform { colorTitle = 0xB1A0C7; colorOdd = 0xDCE6F1; colorEven = 0xFDE9D9; - margin = .1; // margin to add to ZeroR comparison createFile(); } @@ -308,43 +306,9 @@ namespace platform { footer(totalScore, row); } } - string ReportExcel::compareResult(const string& dataset, double result) - { - string status = " "; - if (data["score_name"].get() == "accuracy") { - auto dt = Datasets(Paths::datasets(), false); - dt.loadDataset(dataset); - auto numClasses = dt.getNClasses(dataset); - if (numClasses == 2) { - vector distribution = dt.getClassesCounts(dataset); - double nSamples = dt.getNSamples(dataset); - vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); - double mark = *maxValue / nSamples * (1 + margin); - if (mark > 1) { - mark = 0.9995; - } - status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; - auto item = summary.find(status); - if (item != summary.end()) { - summary[status]++; - } else { - summary[status] = 1; - } - } - } - return status; - } + void ReportExcel::showSummary() { - stringstream oss; - oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; - - map meaning = { - {Symbols::equal_best, "Equal to best"}, - {Symbols::better_best, "Better than best"}, - {Symbols::cross, "Less than or equal to ZeroR"}, - {Symbols::upward_arrow, oss.str()} - }; for (const auto& item : summary) { worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]); worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]); diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 6e0f6eb..289c808 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -3,22 +3,11 @@ #include #include "xlsxwriter.h" #include "ReportBase.h" -#include "Paths.h" #include "Colors.h" namespace platform { using namespace std; const int MAXLL = 128; - class Symbols { - public: - inline static const string check_mark{ "\u2714" }; - inline static const string exclamation{ "\u2757" }; - inline static const string black_star{ "\u2605" }; - inline static const string cross{ "\u2717" }; - inline static const string upward_arrow{ "\u27B6" }; - inline static const string down_arrow{ "\u27B4" }; - inline static const string equal_best{ check_mark }; - inline static const string better_best{ black_star }; - }; + class ReportExcel : public ReportBase { public: explicit ReportExcel(json data_, lxw_workbook* workbook); @@ -36,13 +25,11 @@ namespace platform { lxw_workbook* workbook; lxw_worksheet* worksheet; map styles; - map summary; int row; int normalSize; //font size for report body uint32_t colorTitle; uint32_t colorOdd; uint32_t colorEven; - double margin; const string fileName = "some_results.xlsx"; void header() override; void body() override; @@ -50,7 +37,6 @@ namespace platform { void createStyle(const string& name, lxw_format* style, bool odd); void addColor(lxw_format* style, bool odd); lxw_format* efectiveStyle(const string& name); - string compareResult(const string& dataset, double result); }; }; #endif // !REPORTEXCEL_H \ No newline at end of file