From 03533461c824349e146eed3080b4ec480147d117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 20 Sep 2023 12:51:19 +0200 Subject: [PATCH] Add compare to best results in manage --- src/Platform/ReportBase.cc | 70 ++++++++++++++++++++++++++---------- src/Platform/ReportBase.h | 6 +++- src/Platform/ReportConsole.h | 2 +- src/Platform/ReportExcel.cc | 2 +- src/Platform/ReportExcel.h | 2 +- src/Platform/Results.cc | 6 ++-- src/Platform/Results.h | 7 +++- src/Platform/manage.cc | 5 ++- 8 files changed, 73 insertions(+), 27 deletions(-) diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index 2289640..6a5b885 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -6,7 +6,7 @@ namespace platform { - ReportBase::ReportBase(json data_) : margin(0.1), data(data_) + ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1) { stringstream oss; oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; @@ -49,27 +49,59 @@ namespace platform { string ReportBase::compareResult(const string& dataset, double result) { string status = " "; - if (data["score_name"].get() == "accuracy") { - auto dt = Datasets(Paths::datasets(), false); - dt.loadDataset(dataset); - auto numClasses = dt.getNClasses(dataset); - if (numClasses == 2) { - vector distribution = dt.getClassesCounts(dataset); - double nSamples = dt.getNSamples(dataset); - vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); - double mark = *maxValue / nSamples * (1 + margin); - if (mark > 1) { - mark = 0.9995; - } - status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; - auto item = summary.find(status); - if (item != summary.end()) { - summary[status]++; - } else { - summary[status] = 1; + if (compare) { + double best = bestResult(dataset, data["model"].get()); + if (result == best) { + status = Symbols::equal_best; + } else if (result > best) { + status = Symbols::better_best; + } + } else { + if (data["score_name"].get() == "accuracy") { + auto dt = Datasets(Paths::datasets(), false); + dt.loadDataset(dataset); + auto numClasses = dt.getNClasses(dataset); + if (numClasses == 2) { + vector distribution = dt.getClassesCounts(dataset); + double nSamples = dt.getNSamples(dataset); + vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); + double mark = *maxValue / nSamples * (1 + margin); + if (mark > 1) { + mark = 0.9995; + } + status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; } } } + if (status != " ") { + auto item = summary.find(status); + if (item != summary.end()) { + summary[status]++; + } else { + summary[status] = 1; + } + } return status; } + double ReportBase::bestResult(const string& dataset, const string& model) + { + double value = 0.0; + if (bestResults.size() == 0) { + // try to load the best results + string score = data["score_name"]; + replace(score.begin(), score.end(), '_', '-'); + string fileName = "best_results_" + score + "_" + model + ".json"; + ifstream resultData(Paths::results() + "/" + fileName); + if (resultData.is_open()) { + bestResults = json::parse(resultData); + } + } + try { + value = bestResults.at(dataset).at(0); + } + catch (exception) { + value = 1.0; + } + return value; + } } \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h index f587c2e..7695102 100644 --- a/src/Platform/ReportBase.h +++ b/src/Platform/ReportBase.h @@ -21,7 +21,7 @@ namespace platform { }; class ReportBase { public: - explicit ReportBase(json data_); + explicit ReportBase(json data_, bool compare); virtual ~ReportBase() = default; void show(); protected: @@ -35,6 +35,10 @@ namespace platform { map summary; double margin; map meaning; + private: + double bestResult(const string& dataset, const string& model); + bool compare; + json bestResults; }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index 7b3906c..3dcc719 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -10,7 +10,7 @@ namespace platform { const int MAXL = 133; class ReportConsole : public ReportBase { public: - explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {}; + explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {}; virtual ~ReportConsole() = default; private: int selectedIndex; diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index c7d6ded..ad816a2 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -13,7 +13,7 @@ namespace platform { string do_grouping() const { return "\03"; } }; - ReportExcel::ReportExcel(json data_, lxw_workbook* workbook) : ReportBase(data_), row(0), workbook(workbook) + ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook) { normalSize = 14; //font size for report body colorTitle = 0xB1A0C7; diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 289c808..c5d462f 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -10,7 +10,7 @@ namespace platform { class ReportExcel : public ReportBase { public: - explicit ReportExcel(json data_, lxw_workbook* workbook); + explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook); lxw_workbook* getWorkbook(); private: void writeString(int row, int col, const string& text, const string& style = ""); diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index e136a62..8568600 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -109,12 +109,12 @@ namespace platform { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); if (excelReport) { - ReportExcel reporter(data, workbook); + ReportExcel reporter(data, compare, workbook); reporter.show(); openExcel = true; workbook = reporter.getWorkbook(); } else { - ReportConsole reporter(data); + ReportConsole reporter(data, compare); reporter.show(); } } @@ -150,6 +150,7 @@ namespace platform { if (all_of(line.begin(), line.end(), ::isdigit)) { int idx = stoi(line); if (indexList) { + // The value is about the files list index = idx; if (index >= 0 && index < files.size()) { report(index, false); @@ -157,6 +158,7 @@ namespace platform { continue; } } else { + // The value is about the result showed on screen showIndex(index, idx); continue; } diff --git a/src/Platform/Results.h b/src/Platform/Results.h index c418135..60748ba 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -35,7 +35,11 @@ namespace platform { }; class Results { public: - Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); }; + Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) : + path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare) + { + load(); + }; void manage(); private: string path; @@ -46,6 +50,7 @@ namespace platform { bool partial; bool indexList = true; bool openExcel = false; + bool compare; lxw_workbook* workbook = NULL; vector files; void load(); // Loads the list of results diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index aec19e7..cf699d6 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); + program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true); try { program.parse_args(argc, argv); auto number = program.get("number"); @@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) auto score = program.get("score"); auto complete = program.get("complete"); auto partial = program.get("partial"); + auto compare = program.get("compare"); } catch (const exception& err) { cerr << err.what() << endl; @@ -41,9 +43,10 @@ int main(int argc, char** argv) auto score = program.get("score"); auto complete = program.get("complete"); auto partial = program.get("partial"); + auto compare = program.get("compare"); if (complete) partial = false; - auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial); + auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare); results.manage(); return 0; }