Add totals and filter by scoreName and model

This commit is contained in:
Ricardo Montañana Gómez 2023-08-13 18:13:00 +02:00
parent 054567c65a
commit 3691cb4a61
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 64 additions and 12 deletions

10
src/Platform/BestResult.h Normal file
View File

@ -0,0 +1,10 @@
#ifndef BESTRESULT_H
#define BESTRESULT_H
#include <string>
class BestResult {
public:
static std::string title() { return "STree_default (linear-ovo)"; }
static double score() { return 22.109799; }
static std::string scoreName() { return "accuracy"; }
};
#endif

View File

@ -1,4 +1,5 @@
#include "Report.h"
#include "BestResult.h"
namespace platform {
string headerLine(const string& text)
@ -28,6 +29,7 @@ namespace platform {
{
header();
body();
footer();
}
void Report::header()
{
@ -44,6 +46,8 @@ namespace platform {
{
cout << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
cout << "============================== ====== ===== === ======= ======= ======= =============== ================= ===============" << endl;
json lastResult;
totalScore = 0;
for (const auto& r : data["results"]) {
cout << setw(30) << left << r["dataset"].get<string>() << " ";
cout << setw(6) << right << r["samples"].get<int>() << " ";
@ -56,12 +60,26 @@ namespace platform {
cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get<double>() << " ";
cout << " " << r["hyperparameters"].get<string>();
cout << endl;
lastResult = r;
totalScore += r["score_test"].get<double>();
}
if (data["results"].size() == 1) {
cout << string(MAXL, '*') << endl;
cout << headerLine("Train scores: " + fVector(r["scores_train"]));
cout << headerLine("Test scores: " + fVector(r["scores_test"]));
cout << headerLine("Train times: " + fVector(r["times_train"]));
cout << headerLine("Test times: " + fVector(r["times_test"]));
cout << headerLine("Train scores: " + fVector(lastResult["scores_train"]));
cout << headerLine("Test scores: " + fVector(lastResult["scores_test"]));
cout << headerLine("Train times: " + fVector(lastResult["times_train"]));
cout << headerLine("Test times: " + fVector(lastResult["times_test"]));
cout << string(MAXL, '*') << endl;
}
}
void Report::footer()
{
cout << string(MAXL, '*') << endl;
auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) {
cout << headerLine(score + " compared to " + BestResult::title() + " .: " + to_string(totalScore / BestResult::score()));
}
cout << string(MAXL, '*') << endl;
}
}

View File

@ -16,8 +16,10 @@ namespace platform {
private:
void header();
void body();
void footer();
string fromVector(const string& key);
json data;
double totalScore; // Total score of all results in a report
};
};
#endif

View File

@ -2,8 +2,8 @@
#include "platformUtils.h"
#include "Results.h"
#include "Report.h"
#include "BestResult.h"
namespace platform {
const double REFERENCE_SCORE = 22.109799;
Result::Result(const string& path, const string& filename)
: path(path)
, filename(filename)
@ -14,7 +14,10 @@ namespace platform {
for (const auto& result : data["results"]) {
score += result["score"].get<double>();
}
score /= REFERENCE_SCORE;
scoreName = data["score_name"];
if (scoreName == BestResult::scoreName()) {
score /= BestResult::score();
}
title = data["title"];
duration = data["duration"];
model = data["model"];
@ -35,7 +38,11 @@ namespace platform {
auto filename = file.path().filename().string();
if (filename.find(".json") != string::npos && filename.find("results_") == 0) {
auto result = Result(path, filename);
files.push_back(result);
bool addResult = true;
if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName())
addResult = false;
if (addResult)
files.push_back(result);
}
}
}
@ -44,7 +51,8 @@ namespace platform {
stringstream oss;
oss << date << " ";
oss << setw(12) << left << model << " ";
oss << right << setw(9) << setprecision(7) << fixed << score << " ";
oss << setw(11) << left << scoreName << " ";
oss << right << setw(11) << setprecision(7) << fixed << score << " ";
oss << setw(9) << setprecision(3) << fixed << duration << " ";
oss << setw(50) << left << title << " ";
return oss.str();
@ -54,8 +62,8 @@ namespace platform {
cout << "Results found: " << files.size() << endl;
cout << "-------------------" << endl;
auto i = 0;
cout << " # Date Model Score Duration Title" << endl;
cout << "=== ========== ============ ========= ========= =============================================================" << endl;
cout << " # Date Model Score Name Score Duration Title" << endl;
cout << "=== ========== ============ =========== =========== ========= =============================================================" << endl;
for (const auto& result : files) {
cout << setw(3) << fixed << right << i++ << " ";
cout << result.to_string() << endl;
@ -181,6 +189,10 @@ namespace platform {
}
void Results::manage()
{
if (files.size() == 0) {
cout << "No results found!" << endl;
exit(0);
}
show();
menu();
}

View File

@ -19,6 +19,7 @@ namespace platform {
string getTitle() const { return title; };
double getDuration() const { return duration; };
string getModel() const { return model; };
string getScoreName() const { return scoreName; };
private:
string path;
string filename;
@ -27,14 +28,17 @@ namespace platform {
string title;
double duration;
string model;
string scoreName;
};
class Results {
public:
explicit Results(const string& path, const int max) : path(path), max(max) { load(); };
Results(const string& path, const int max, const string& model, const string& score) : path(path), max(max), model(model), scoreName(score) { load(); };
void manage();
private:
string path;
int max;
string model;
string scoreName;
vector<Result> files;
void load(); // Loads the list of results
void show() const;

View File

@ -10,12 +10,16 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
{
argparse::ArgumentParser program("manage");
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
try {
program.parse_args(argc, argv);
auto number = program.get<int>("number");
if (number < 0) {
throw runtime_error("Number of results must be greater than or equal to 0");
}
auto model = program.get<string>("model");
auto score = program.get<string>("score");
}
catch (const exception& err) {
cerr << err.what() << endl;
@ -29,7 +33,9 @@ int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
auto number = program.get<int>("number");
auto results = platform::Results(PATH_RESULTS, number);
auto model = program.get<string>("model");
auto score = program.get<string>("score");
auto results = platform::Results(PATH_RESULTS, number, model, score);
results.manage();
return 0;
}