From 3691cb4a614b96d0e0145fadb8d365c25f2a23fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 13 Aug 2023 18:13:00 +0200 Subject: [PATCH] Add totals and filter by scoreName and model --- src/Platform/BestResult.h | 10 ++++++++++ src/Platform/Report.cc | 26 ++++++++++++++++++++++---- src/Platform/Report.h | 2 ++ src/Platform/Results.cc | 24 ++++++++++++++++++------ src/Platform/Results.h | 6 +++++- src/Platform/manage.cc | 8 +++++++- 6 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 src/Platform/BestResult.h diff --git a/src/Platform/BestResult.h b/src/Platform/BestResult.h new file mode 100644 index 0000000..8b3f1cb --- /dev/null +++ b/src/Platform/BestResult.h @@ -0,0 +1,10 @@ +#ifndef BESTRESULT_H +#define BESTRESULT_H +#include +class BestResult { +public: + static std::string title() { return "STree_default (linear-ovo)"; } + static double score() { return 22.109799; } + static std::string scoreName() { return "accuracy"; } +}; +#endif \ No newline at end of file diff --git a/src/Platform/Report.cc b/src/Platform/Report.cc index 3693248..7bd7d69 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/Report.cc @@ -1,4 +1,5 @@ #include "Report.h" +#include "BestResult.h" namespace platform { string headerLine(const string& text) @@ -28,6 +29,7 @@ namespace platform { { header(); body(); + footer(); } void Report::header() { @@ -44,6 +46,8 @@ namespace platform { { cout << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << "============================== ====== ===== === ======= ======= ======= =============== ================= ===============" << endl; + json lastResult; + totalScore = 0; for (const auto& r : data["results"]) { cout << setw(30) << left << r["dataset"].get() << " "; cout << setw(6) << right << r["samples"].get() << " "; @@ -56,12 +60,26 @@ namespace platform { cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get() << " "; cout << " " << r["hyperparameters"].get(); cout << endl; + lastResult = r; + totalScore += r["score_test"].get(); + } + if (data["results"].size() == 1) { cout << string(MAXL, '*') << endl; - cout << headerLine("Train scores: " + fVector(r["scores_train"])); - cout << headerLine("Test scores: " + fVector(r["scores_test"])); - cout << headerLine("Train times: " + fVector(r["times_train"])); - cout << headerLine("Test times: " + fVector(r["times_test"])); + cout << headerLine("Train scores: " + fVector(lastResult["scores_train"])); + cout << headerLine("Test scores: " + fVector(lastResult["scores_test"])); + cout << headerLine("Train times: " + fVector(lastResult["times_train"])); + cout << headerLine("Test times: " + fVector(lastResult["times_test"])); cout << string(MAXL, '*') << endl; } } + void Report::footer() + { + cout << string(MAXL, '*') << endl; + auto score = data["score_name"].get(); + if (score == BestResult::scoreName()) { + cout << headerLine(score + " compared to " + BestResult::title() + " .: " + to_string(totalScore / BestResult::score())); + } + cout << string(MAXL, '*') << endl; + + } } \ No newline at end of file diff --git a/src/Platform/Report.h b/src/Platform/Report.h index c6ea8a1..302ac60 100644 --- a/src/Platform/Report.h +++ b/src/Platform/Report.h @@ -16,8 +16,10 @@ namespace platform { private: void header(); void body(); + void footer(); string fromVector(const string& key); json data; + double totalScore; // Total score of all results in a report }; }; #endif \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 056f0b1..c33cf37 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -2,8 +2,8 @@ #include "platformUtils.h" #include "Results.h" #include "Report.h" +#include "BestResult.h" namespace platform { - const double REFERENCE_SCORE = 22.109799; Result::Result(const string& path, const string& filename) : path(path) , filename(filename) @@ -14,7 +14,10 @@ namespace platform { for (const auto& result : data["results"]) { score += result["score"].get(); } - score /= REFERENCE_SCORE; + scoreName = data["score_name"]; + if (scoreName == BestResult::scoreName()) { + score /= BestResult::score(); + } title = data["title"]; duration = data["duration"]; model = data["model"]; @@ -35,7 +38,11 @@ namespace platform { auto filename = file.path().filename().string(); if (filename.find(".json") != string::npos && filename.find("results_") == 0) { auto result = Result(path, filename); - files.push_back(result); + bool addResult = true; + if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName()) + addResult = false; + if (addResult) + files.push_back(result); } } } @@ -44,7 +51,8 @@ namespace platform { stringstream oss; oss << date << " "; oss << setw(12) << left << model << " "; - oss << right << setw(9) << setprecision(7) << fixed << score << " "; + oss << setw(11) << left << scoreName << " "; + oss << right << setw(11) << setprecision(7) << fixed << score << " "; oss << setw(9) << setprecision(3) << fixed << duration << " "; oss << setw(50) << left << title << " "; return oss.str(); @@ -54,8 +62,8 @@ namespace platform { cout << "Results found: " << files.size() << endl; cout << "-------------------" << endl; auto i = 0; - cout << " # Date Model Score Duration Title" << endl; - cout << "=== ========== ============ ========= ========= =============================================================" << endl; + cout << " # Date Model Score Name Score Duration Title" << endl; + cout << "=== ========== ============ =========== =========== ========= =============================================================" << endl; for (const auto& result : files) { cout << setw(3) << fixed << right << i++ << " "; cout << result.to_string() << endl; @@ -181,6 +189,10 @@ namespace platform { } void Results::manage() { + if (files.size() == 0) { + cout << "No results found!" << endl; + exit(0); + } show(); menu(); } diff --git a/src/Platform/Results.h b/src/Platform/Results.h index 945901f..bd4768b 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -19,6 +19,7 @@ namespace platform { string getTitle() const { return title; }; double getDuration() const { return duration; }; string getModel() const { return model; }; + string getScoreName() const { return scoreName; }; private: string path; string filename; @@ -27,14 +28,17 @@ namespace platform { string title; double duration; string model; + string scoreName; }; class Results { public: - explicit Results(const string& path, const int max) : path(path), max(max) { load(); }; + Results(const string& path, const int max, const string& model, const string& score) : path(path), max(max), model(model), scoreName(score) { load(); }; void manage(); private: string path; int max; + string model; + string scoreName; vector files; void load(); // Loads the list of results void show() const; diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index f97dae3..74e4a2c 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -10,12 +10,16 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) { argparse::ArgumentParser program("manage"); program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>(); + program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); + program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); try { program.parse_args(argc, argv); auto number = program.get("number"); if (number < 0) { throw runtime_error("Number of results must be greater than or equal to 0"); } + auto model = program.get("model"); + auto score = program.get("score"); } catch (const exception& err) { cerr << err.what() << endl; @@ -29,7 +33,9 @@ int main(int argc, char** argv) { auto program = manageArguments(argc, argv); auto number = program.get("number"); - auto results = platform::Results(PATH_RESULTS, number); + auto model = program.get("model"); + auto score = program.get("score"); + auto results = platform::Results(PATH_RESULTS, number, model, score); results.manage(); return 0; }