Add compare to best results in manage

Ricardo Montañana Gómez 2023-09-20 12:51:19 +02:00
parent 68f22a673d
commit 03533461c8
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 73 additions and 27 deletions
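For orientation, here is a condensed, standalone sketch (not part of the commit) of how the new compare flag is threaded from the caller down to ReportBase across the files changed below. The classes are stubs and the JSON payload is invented; only the constructor signatures follow the diff.

    // Standalone sketch, not project code: stubs showing how `compare`
    // travels caller -> ReportConsole/ReportExcel -> ReportBase.
    #include <iostream>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;
    using namespace std;

    class ReportBase {
    public:
        explicit ReportBase(json data_, bool compare) : data(data_), compare(compare) {}
    protected:
        json data;
        bool compare;   // when true, results are checked against the stored best results
    };

    class ReportConsole : public ReportBase {
    public:
        // `compare` defaults to false, so pre-existing ReportConsole(data) call sites keep compiling
        explicit ReportConsole(json data_, bool compare = false, int index = -1)
            : ReportBase(data_, compare), selectedIndex(index) {}
    private:
        int selectedIndex;
    };

    int main()
    {
        json data = { { "model", "SomeModel" }, { "score_name", "accuracy" } };  // made-up payload
        ReportConsole reporter(data, /*compare=*/true);
        cout << "reporter constructed with compare enabled" << endl;
        return 0;
    }

Note the design choice visible in the diff: ReportConsole gives compare a default value, while ReportExcel does not, so its call sites have to be updated explicitly.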

View File

@@ -6,7 +6,7 @@
 namespace platform {
-    ReportBase::ReportBase(json data_) : margin(0.1), data(data_)
+    ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
     {
         stringstream oss;
         oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
@@ -49,27 +49,59 @@ namespace platform {
     string ReportBase::compareResult(const string& dataset, double result)
     {
         string status = " ";
-        if (data["score_name"].get<string>() == "accuracy") {
-            auto dt = Datasets(Paths::datasets(), false);
-            dt.loadDataset(dataset);
-            auto numClasses = dt.getNClasses(dataset);
-            if (numClasses == 2) {
-                vector<int> distribution = dt.getClassesCounts(dataset);
-                double nSamples = dt.getNSamples(dataset);
-                vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
-                double mark = *maxValue / nSamples * (1 + margin);
-                if (mark > 1) {
-                    mark = 0.9995;
-                }
-                status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
-                auto item = summary.find(status);
-                if (item != summary.end()) {
-                    summary[status]++;
-                } else {
-                    summary[status] = 1;
+        if (compare) {
+            double best = bestResult(dataset, data["model"].get<string>());
+            if (result == best) {
+                status = Symbols::equal_best;
+            } else if (result > best) {
+                status = Symbols::better_best;
+            }
+        } else {
+            if (data["score_name"].get<string>() == "accuracy") {
+                auto dt = Datasets(Paths::datasets(), false);
+                dt.loadDataset(dataset);
+                auto numClasses = dt.getNClasses(dataset);
+                if (numClasses == 2) {
+                    vector<int> distribution = dt.getClassesCounts(dataset);
+                    double nSamples = dt.getNSamples(dataset);
+                    vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
+                    double mark = *maxValue / nSamples * (1 + margin);
+                    if (mark > 1) {
+                        mark = 0.9995;
+                    }
+                    status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
                 }
             }
         }
+        if (status != " ") {
+            auto item = summary.find(status);
+            if (item != summary.end()) {
+                summary[status]++;
+            } else {
+                summary[status] = 1;
+            }
+        }
         return status;
     }
+    double ReportBase::bestResult(const string& dataset, const string& model)
+    {
+        double value = 0.0;
+        if (bestResults.size() == 0) {
+            // try to load the best results
+            string score = data["score_name"];
+            replace(score.begin(), score.end(), '_', '-');
+            string fileName = "best_results_" + score + "_" + model + ".json";
+            ifstream resultData(Paths::results() + "/" + fileName);
+            if (resultData.is_open()) {
+                bestResults = json::parse(resultData);
+            }
+        }
+        try {
+            value = bestResults.at(dataset).at(0);
+        }
+        catch (exception) {
+            value = 1.0;
+        }
+        return value;
+    }
 }
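The new bestResult() helper leaves the expected file format implicit. As a hedged illustration, inferred from the file-name construction and from the lookup bestResults.at(dataset).at(0) (the model name "SomeModel" and dataset "iris" below are invented), a minimal standalone sketch:

    // Standalone sketch, not project code: the file name and JSON layout that
    // bestResult() appears to expect, plus the same 1.0 fallback.
    #include <algorithm>
    #include <exception>
    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;
    using namespace std;

    int main()
    {
        string score = "accuracy";
        string model = "SomeModel";                        // hypothetical model name
        replace(score.begin(), score.end(), '_', '-');     // same normalization as bestResult()
        string fileName = "best_results_" + score + "_" + model + ".json";

        // Stand-in for the contents of <results>/<fileName>: each dataset keys an
        // array whose first element is taken as the best score.
        json bestResults = json::parse(R"({ "iris": [0.98] })");

        double best = 1.0;                                 // same fallback as the catch branch
        try {
            best = bestResults.at("iris").at(0).get<double>();
        }
        catch (const exception&) {
            // dataset not present in the file: keep the fallback
        }
        cout << fileName << " -> best score for iris = " << best << endl;
        return 0;
    }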

View File

@@ -21,7 +21,7 @@ namespace platform {
     };
     class ReportBase {
     public:
-        explicit ReportBase(json data_);
+        explicit ReportBase(json data_, bool compare);
         virtual ~ReportBase() = default;
         void show();
     protected:
@@ -35,6 +35,10 @@ namespace platform {
         map<string, int> summary;
         double margin;
         map<string, string> meaning;
+    private:
+        double bestResult(const string& dataset, const string& model);
+        bool compare;
+        json bestResults;
     };
 };
 #endif

View File

@@ -10,7 +10,7 @@ namespace platform {
     const int MAXL = 133;
     class ReportConsole : public ReportBase {
     public:
-        explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {};
+        explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {};
         virtual ~ReportConsole() = default;
     private:
         int selectedIndex;

View File

@@ -13,7 +13,7 @@ namespace platform {
         string do_grouping() const { return "\03"; }
     };
-    ReportExcel::ReportExcel(json data_, lxw_workbook* workbook) : ReportBase(data_), row(0), workbook(workbook)
+    ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook)
     {
         normalSize = 14; //font size for report body
         colorTitle = 0xB1A0C7;

View File

@@ -10,7 +10,7 @@ namespace platform {
     class ReportExcel : public ReportBase {
     public:
-        explicit ReportExcel(json data_, lxw_workbook* workbook);
+        explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
         lxw_workbook* getWorkbook();
     private:
         void writeString(int row, int col, const string& text, const string& style = "");

View File

@@ -109,12 +109,12 @@ namespace platform {
         cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
         auto data = files.at(index).load();
         if (excelReport) {
-            ReportExcel reporter(data, workbook);
+            ReportExcel reporter(data, compare, workbook);
             reporter.show();
             openExcel = true;
             workbook = reporter.getWorkbook();
         } else {
-            ReportConsole reporter(data);
+            ReportConsole reporter(data, compare);
             reporter.show();
         }
     }
@@ -150,6 +150,7 @@ namespace platform {
             if (all_of(line.begin(), line.end(), ::isdigit)) {
                 int idx = stoi(line);
                 if (indexList) {
+                    // The value is about the files list
                     index = idx;
                     if (index >= 0 && index < files.size()) {
                         report(index, false);
@@ -157,6 +158,7 @@ namespace platform {
                         continue;
                     }
                 } else {
+                    // The value is about the result showed on screen
                     showIndex(index, idx);
                     continue;
                 }

View File

@@ -35,7 +35,11 @@ namespace platform {
     };
     class Results {
     public:
-        Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); };
+        Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
+            path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
+        {
+            load();
+        };
         void manage();
     private:
         string path;
@@ -46,6 +50,7 @@ namespace platform {
         bool partial;
         bool indexList = true;
         bool openExcel = false;
+        bool compare;
        lxw_workbook* workbook = NULL;
         vector<Result> files;
         void load(); // Loads the list of results

View File

@@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
     program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
     program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
     program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
+    program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
     try {
         program.parse_args(argc, argv);
         auto number = program.get<int>("number");
@@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
         auto score = program.get<string>("score");
         auto complete = program.get<bool>("complete");
         auto partial = program.get<bool>("partial");
+        auto compare = program.get<bool>("compare");
     }
     catch (const exception& err) {
         cerr << err.what() << endl;
@@ -41,9 +43,10 @@ int main(int argc, char** argv)
     auto score = program.get<string>("score");
     auto complete = program.get<bool>("complete");
     auto partial = program.get<bool>("partial");
+    auto compare = program.get<bool>("compare");
     if (complete)
         partial = false;
-    auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial);
+    auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
     results.manage();
     return 0;
 }
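As a hedged, standalone illustration of the boolean-flag pattern used above with the argparse library (the program name "manage" is assumed from the commit title; the flag definition and get<bool> lookup mirror the diff's own calls):

    // Standalone sketch, not the project's main.cc: --compare defaults to false
    // and flips to true when the flag is present, with no value required.
    #include <iostream>
    #include <argparse/argparse.hpp>
    using namespace std;

    int main(int argc, char** argv)
    {
        argparse::ArgumentParser program("manage");  // program name assumed
        program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
        try {
            program.parse_args(argc, argv);
        }
        catch (const exception& err) {
            cerr << err.what() << endl;
            return 1;
        }
        cout << boolalpha << "compare = " << program.get<bool>("compare") << endl;
        return 0;
    }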