From 1f705f6018894fd9abc627c0b5522f33b8c208d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 23 Oct 2023 16:12:52 +0200 Subject: [PATCH] Refactor BestScore and add experiment to .env --- src/Platform/BestScore.h | 30 ++++++++++++++++++++++++------ src/Platform/ReportBase.cc | 1 - src/Platform/ReportConsole.cc | 6 ++++-- src/Platform/ReportConsole.h | 1 - src/Platform/ReportExcel.cc | 7 ++++--- src/Platform/Result.cc | 10 ++++++---- src/Platform/Results.cc | 1 - 7 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/Platform/BestScore.h b/src/Platform/BestScore.h index 4e649b2..bc3a719 100644 --- a/src/Platform/BestScore.h +++ b/src/Platform/BestScore.h @@ -1,10 +1,28 @@ #ifndef BESTSCORE_H #define BESTSCORE_H #include -class BestScore { -public: - static std::string title() { return "STree_default (linear-ovo)"; } - static double score() { return 22.109799; } - static std::string scoreName() { return "accuracy"; } -}; +#include +#include +#include "DotEnv.h" +namespace platform { + class BestScore { + public: + static pair getScore(const std::string& metric) + { + static map, pair> data = { + {{"discretiz", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}}, + //{{"odte", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}}, + }; + auto env = platform::DotEnv(); + string experiment = env.get("experiment"); + try { + return data[{experiment, metric}]; + } + catch (...) { + return { "", 0.0 }; + } + } + }; +} + #endif \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index acb5581..2be08a5 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -2,7 +2,6 @@ #include #include "Datasets.h" #include "ReportBase.h" -#include "BestScore.h" #include "DotEnv.h" namespace platform { diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index e6144cc..c336d50 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -1,3 +1,4 @@ +#include #include #include #include "ReportConsole.h" @@ -94,9 +95,10 @@ namespace platform { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; showSummary(); auto score = data["score_name"].get(); - if (score == BestScore::scoreName()) { + auto best = BestScore::getScore(score); + if (best.first != "") { stringstream oss; - oss << score << " compared to " << BestScore::title() << " .: " << totalScore / BestScore::score(); + oss << score << " compared to " << best.first << " .: " << totalScore / best.second; cout << headerLine(oss.str()); } if (!getExistBestFile() && compare) { diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index 127e269..a36dd03 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -1,7 +1,6 @@ #ifndef REPORTCONSOLE_H #define REPORTCONSOLE_H #include -#include #include "ReportBase.h" #include "Colors.h" diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index b2a900a..d091b10 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -163,9 +163,10 @@ namespace platform { showSummary(); row += 4 + summary.size(); auto score = data["score_name"].get(); - if (score == BestScore::scoreName()) { - worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestScore::title() + " .:").c_str(), efectiveStyle("text")); - writeDouble(row, 6, totalScore / BestScore::score(), "result"); + auto best = BestScore::getScore(score); + if (best.first != "") { + worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + best.first + " .:").c_str(), efectiveStyle("text")); + writeDouble(row, 6, totalScore / best.second, "result"); } if (!getExistBestFile() && compare) { worksheet_write_string(worksheet, row + 1, 0, "*** Best Results File not found. Couldn't compare any result!", styles["summaryStyle"]); diff --git a/src/Platform/Result.cc b/src/Platform/Result.cc index a185b56..a444877 100644 --- a/src/Platform/Result.cc +++ b/src/Platform/Result.cc @@ -1,9 +1,10 @@ +#include "Result.h" +#include "BestScore.h" #include #include #include -#include "Result.h" #include "Colors.h" -#include "BestScore.h" +#include "DotEnv.h" #include "CLocale.h" namespace platform { @@ -18,8 +19,9 @@ namespace platform { score += result["score"].get(); } scoreName = data["score_name"]; - if (scoreName == BestScore::scoreName()) { - score /= BestScore::score(); + auto best = BestScore::getScore(scoreName); + if (best.first != "") { + score /= best.second; } title = data["title"]; duration = data["duration"]; diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 6f73a4c..e15dd8e 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -1,6 +1,5 @@ #include "Results.h" #include -#include "BestScore.h" namespace platform { Results::Results(const string& path, const string& model, const string& score, bool complete, bool partial, bool compare) :