Add compare to best results in manage

This commit is contained in:
Ricardo Montañana Gómez 2023-09-20 12:51:19 +02:00
parent 68f22a673d
commit 03533461c8
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 73 additions and 27 deletions

View File

@ -6,7 +6,7 @@
namespace platform { namespace platform {
ReportBase::ReportBase(json data_) : margin(0.1), data(data_) ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
{ {
stringstream oss; stringstream oss;
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
@ -49,6 +49,14 @@ namespace platform {
string ReportBase::compareResult(const string& dataset, double result) string ReportBase::compareResult(const string& dataset, double result)
{ {
string status = " "; string status = " ";
if (compare) {
double best = bestResult(dataset, data["model"].get<string>());
if (result == best) {
status = Symbols::equal_best;
} else if (result > best) {
status = Symbols::better_best;
}
} else {
if (data["score_name"].get<string>() == "accuracy") { if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false); auto dt = Datasets(Paths::datasets(), false);
dt.loadDataset(dataset); dt.loadDataset(dataset);
@ -62,6 +70,10 @@ namespace platform {
mark = 0.9995; mark = 0.9995;
} }
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
}
}
}
if (status != " ") {
auto item = summary.find(status); auto item = summary.find(status);
if (item != summary.end()) { if (item != summary.end()) {
summary[status]++; summary[status]++;
@ -69,7 +81,27 @@ namespace platform {
summary[status] = 1; summary[status] = 1;
} }
} }
}
return status; return status;
} }
double ReportBase::bestResult(const string& dataset, const string& model)
{
double value = 0.0;
if (bestResults.size() == 0) {
// try to load the best results
string score = data["score_name"];
replace(score.begin(), score.end(), '_', '-');
string fileName = "best_results_" + score + "_" + model + ".json";
ifstream resultData(Paths::results() + "/" + fileName);
if (resultData.is_open()) {
bestResults = json::parse(resultData);
}
}
try {
value = bestResults.at(dataset).at(0);
}
catch (exception) {
value = 1.0;
}
return value;
}
} }

View File

@ -21,7 +21,7 @@ namespace platform {
}; };
class ReportBase { class ReportBase {
public: public:
explicit ReportBase(json data_); explicit ReportBase(json data_, bool compare);
virtual ~ReportBase() = default; virtual ~ReportBase() = default;
void show(); void show();
protected: protected:
@ -35,6 +35,10 @@ namespace platform {
map<string, int> summary; map<string, int> summary;
double margin; double margin;
map<string, string> meaning; map<string, string> meaning;
private:
double bestResult(const string& dataset, const string& model);
bool compare;
json bestResults;
}; };
}; };
#endif #endif

View File

@ -10,7 +10,7 @@ namespace platform {
const int MAXL = 133; const int MAXL = 133;
class ReportConsole : public ReportBase { class ReportConsole : public ReportBase {
public: public:
explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {}; explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {};
virtual ~ReportConsole() = default; virtual ~ReportConsole() = default;
private: private:
int selectedIndex; int selectedIndex;

View File

@ -13,7 +13,7 @@ namespace platform {
string do_grouping() const { return "\03"; } string do_grouping() const { return "\03"; }
}; };
ReportExcel::ReportExcel(json data_, lxw_workbook* workbook) : ReportBase(data_), row(0), workbook(workbook) ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook)
{ {
normalSize = 14; //font size for report body normalSize = 14; //font size for report body
colorTitle = 0xB1A0C7; colorTitle = 0xB1A0C7;

View File

@ -10,7 +10,7 @@ namespace platform {
class ReportExcel : public ReportBase { class ReportExcel : public ReportBase {
public: public:
explicit ReportExcel(json data_, lxw_workbook* workbook); explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
lxw_workbook* getWorkbook(); lxw_workbook* getWorkbook();
private: private:
void writeString(int row, int col, const string& text, const string& style = ""); void writeString(int row, int col, const string& text, const string& style = "");

View File

@ -109,12 +109,12 @@ namespace platform {
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
auto data = files.at(index).load(); auto data = files.at(index).load();
if (excelReport) { if (excelReport) {
ReportExcel reporter(data, workbook); ReportExcel reporter(data, compare, workbook);
reporter.show(); reporter.show();
openExcel = true; openExcel = true;
workbook = reporter.getWorkbook(); workbook = reporter.getWorkbook();
} else { } else {
ReportConsole reporter(data); ReportConsole reporter(data, compare);
reporter.show(); reporter.show();
} }
} }
@ -150,6 +150,7 @@ namespace platform {
if (all_of(line.begin(), line.end(), ::isdigit)) { if (all_of(line.begin(), line.end(), ::isdigit)) {
int idx = stoi(line); int idx = stoi(line);
if (indexList) { if (indexList) {
// The value is about the files list
index = idx; index = idx;
if (index >= 0 && index < files.size()) { if (index >= 0 && index < files.size()) {
report(index, false); report(index, false);
@ -157,6 +158,7 @@ namespace platform {
continue; continue;
} }
} else { } else {
// The value is about the result showed on screen
showIndex(index, idx); showIndex(index, idx);
continue; continue;
} }

View File

@ -35,7 +35,11 @@ namespace platform {
}; };
class Results { class Results {
public: public:
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); }; Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
{
load();
};
void manage(); void manage();
private: private:
string path; string path;
@ -46,6 +50,7 @@ namespace platform {
bool partial; bool partial;
bool indexList = true; bool indexList = true;
bool openExcel = false; bool openExcel = false;
bool compare;
lxw_workbook* workbook = NULL; lxw_workbook* workbook = NULL;
vector<Result> files; vector<Result> files;
void load(); // Loads the list of results void load(); // Loads the list of results

View File

@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
try { try {
program.parse_args(argc, argv); program.parse_args(argc, argv);
auto number = program.get<int>("number"); auto number = program.get<int>("number");
@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
auto score = program.get<string>("score"); auto score = program.get<string>("score");
auto complete = program.get<bool>("complete"); auto complete = program.get<bool>("complete");
auto partial = program.get<bool>("partial"); auto partial = program.get<bool>("partial");
auto compare = program.get<bool>("compare");
} }
catch (const exception& err) { catch (const exception& err) {
cerr << err.what() << endl; cerr << err.what() << endl;
@ -41,9 +43,10 @@ int main(int argc, char** argv)
auto score = program.get<string>("score"); auto score = program.get<string>("score");
auto complete = program.get<bool>("complete"); auto complete = program.get<bool>("complete");
auto partial = program.get<bool>("partial"); auto partial = program.get<bool>("partial");
auto compare = program.get<bool>("compare");
if (complete) if (complete)
partial = false; partial = false;
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial); auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
results.manage(); results.manage();
return 0; return 0;
} }