Add comparison to report console

This commit is contained in:
Ricardo Montañana Gómez 2023-09-20 11:40:01 +02:00
parent b9bc0088f3
commit 68f22a673d
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 82 additions and 62 deletions

View File

@ -1,10 +1,22 @@
#include <sstream>
#include <locale>
#include "Datasets.h"
#include "ReportBase.h"
#include "BestResult.h"
namespace platform {
ReportBase::ReportBase(json data_) : margin(0.1), data(data_)
{
stringstream oss;
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
meaning = {
{Symbols::equal_best, "Equal to best"},
{Symbols::better_best, "Better than best"},
{Symbols::cross, "Less than or equal to ZeroR"},
{Symbols::upward_arrow, oss.str()}
};
}
string ReportBase::fromVector(const string& key)
{
stringstream oss;
@ -34,4 +46,30 @@ namespace platform {
header();
body();
}
string ReportBase::compareResult(const string& dataset, double result)
{
string status = " ";
if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false);
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {
vector<int> distribution = dt.getClassesCounts(dataset);
double nSamples = dt.getNSamples(dataset);
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
double mark = *maxValue / nSamples * (1 + margin);
if (mark > 1) {
mark = 0.9995;
}
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
auto item = summary.find(status);
if (item != summary.end()) {
summary[status]++;
} else {
summary[status] = 1;
}
}
}
return status;
}
}

View File

@ -2,14 +2,26 @@
#define REPORTBASE_H
#include <string>
#include <iostream>
#include "Paths.h"
#include <nlohmann/json.hpp>
using json = nlohmann::json;
namespace platform {
using namespace std;
class Symbols {
public:
inline static const string check_mark{ "\u2714" };
inline static const string exclamation{ "\u2757" };
inline static const string black_star{ "\u2605" };
inline static const string cross{ "\u2717" };
inline static const string upward_arrow{ "\u27B6" };
inline static const string down_arrow{ "\u27B4" };
inline static const string equal_best{ check_mark };
inline static const string better_best{ black_star };
};
class ReportBase {
public:
explicit ReportBase(json data_) { data = data_; };
explicit ReportBase(json data_);
virtual ~ReportBase() = default;
void show();
protected:
@ -18,6 +30,11 @@ namespace platform {
string fVector(const string& title, const json& data, const int width, const int precision);
virtual void header() = 0;
virtual void body() = 0;
virtual void showSummary() = 0;
string compareResult(const string& dataset, double result);
map<string, int> summary;
double margin;
map<string, string> meaning;
};
};
#endif

View File

@ -11,11 +11,11 @@ namespace platform {
string do_grouping() const { return "\03"; }
};
string ReportConsole::headerLine(const string& text)
string ReportConsole::headerLine(const string& text, int utf = 0)
{
int n = MAXL - text.length() - 3;
n = n < 0 ? 0 : n;
return "* " + text + string(n, ' ') + "*\n";
return "* " + text + string(n + utf, ' ') + "*\n";
}
void ReportConsole::header()
@ -36,8 +36,8 @@ namespace platform {
}
void ReportConsole::body()
{
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
cout << "=== ============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl;
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl;
json lastResult;
double totalScore = 0.0;
bool odd = true;
@ -50,15 +50,17 @@ namespace platform {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
cout << color;
cout << setw(3) << index++ << " ";
cout << setw(30) << left << r["dataset"].get<string>() << " ";
cout << setw(25) << left << r["dataset"].get<string>() << " ";
cout << setw(6) << right << r["samples"].get<int>() << " ";
cout << setw(5) << right << r["features"].get<int>() << " ";
cout << setw(3) << right << r["classes"].get<int>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>() << " ";
cout << setw(11) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>();
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
cout << status;
cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
try {
cout << r["hyperparameters"].get<string>();
}
@ -81,9 +83,21 @@ namespace platform {
footer(totalScore);
}
}
void ReportConsole::showSummary()
{
for (const auto& item : summary) {
stringstream oss;
oss << setw(3) << left << item.first;
oss << setw(3) << right << item.second << " ";
oss << left << meaning.at(item.first);
cout << headerLine(oss.str(), 2);
}
}
void ReportConsole::footer(double totalScore)
{
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
showSummary();
auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) {
stringstream oss;

View File

@ -7,17 +7,18 @@
namespace platform {
using namespace std;
const int MAXL = 132;
const int MAXL = 133;
class ReportConsole : public ReportBase {
public:
explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {};
virtual ~ReportConsole() = default;
private:
int selectedIndex;
string headerLine(const string& text);
string headerLine(const string& text, int utf);
void header() override;
void body() override;
void footer(double totalScore);
void showSummary();
};
};
#endif

View File

@ -1,6 +1,5 @@
#include <sstream>
#include <locale>
#include "Datasets.h"
#include "ReportExcel.h"
#include "BestResult.h"
@ -20,7 +19,6 @@ namespace platform {
colorTitle = 0xB1A0C7;
colorOdd = 0xDCE6F1;
colorEven = 0xFDE9D9;
margin = .1; // margin to add to ZeroR comparison
createFile();
}
@ -308,43 +306,9 @@ namespace platform {
footer(totalScore, row);
}
}
string ReportExcel::compareResult(const string& dataset, double result)
{
string status = " ";
if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false);
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {
vector<int> distribution = dt.getClassesCounts(dataset);
double nSamples = dt.getNSamples(dataset);
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
double mark = *maxValue / nSamples * (1 + margin);
if (mark > 1) {
mark = 0.9995;
}
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
auto item = summary.find(status);
if (item != summary.end()) {
summary[status]++;
} else {
summary[status] = 1;
}
}
}
return status;
}
void ReportExcel::showSummary()
{
stringstream oss;
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
map<string, string> meaning = {
{Symbols::equal_best, "Equal to best"},
{Symbols::better_best, "Better than best"},
{Symbols::cross, "Less than or equal to ZeroR"},
{Symbols::upward_arrow, oss.str()}
};
for (const auto& item : summary) {
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);

View File

@ -3,22 +3,11 @@
#include<map>
#include "xlsxwriter.h"
#include "ReportBase.h"
#include "Paths.h"
#include "Colors.h"
namespace platform {
using namespace std;
const int MAXLL = 128;
class Symbols {
public:
inline static const string check_mark{ "\u2714" };
inline static const string exclamation{ "\u2757" };
inline static const string black_star{ "\u2605" };
inline static const string cross{ "\u2717" };
inline static const string upward_arrow{ "\u27B6" };
inline static const string down_arrow{ "\u27B4" };
inline static const string equal_best{ check_mark };
inline static const string better_best{ black_star };
};
class ReportExcel : public ReportBase {
public:
explicit ReportExcel(json data_, lxw_workbook* workbook);
@ -36,13 +25,11 @@ namespace platform {
lxw_workbook* workbook;
lxw_worksheet* worksheet;
map<string, lxw_format*> styles;
map<string, int> summary;
int row;
int normalSize; //font size for report body
uint32_t colorTitle;
uint32_t colorOdd;
uint32_t colorEven;
double margin;
const string fileName = "some_results.xlsx";
void header() override;
void body() override;
@ -50,7 +37,6 @@ namespace platform {
void createStyle(const string& name, lxw_format* style, bool odd);
void addColor(lxw_format* style, bool odd);
lxw_format* efectiveStyle(const string& name);
string compareResult(const string& dataset, double result);
};
};
#endif // !REPORTEXCEL_H