Add compare to best results in manage
This commit is contained in:
parent
68f22a673d
commit
03533461c8
@ -6,7 +6,7 @@
|
||||
|
||||
|
||||
namespace platform {
|
||||
ReportBase::ReportBase(json data_) : margin(0.1), data(data_)
|
||||
ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
|
||||
{
|
||||
stringstream oss;
|
||||
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
|
||||
@ -49,27 +49,59 @@ namespace platform {
|
||||
string ReportBase::compareResult(const string& dataset, double result)
|
||||
{
|
||||
string status = " ";
|
||||
if (data["score_name"].get<string>() == "accuracy") {
|
||||
auto dt = Datasets(Paths::datasets(), false);
|
||||
dt.loadDataset(dataset);
|
||||
auto numClasses = dt.getNClasses(dataset);
|
||||
if (numClasses == 2) {
|
||||
vector<int> distribution = dt.getClassesCounts(dataset);
|
||||
double nSamples = dt.getNSamples(dataset);
|
||||
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
|
||||
double mark = *maxValue / nSamples * (1 + margin);
|
||||
if (mark > 1) {
|
||||
mark = 0.9995;
|
||||
}
|
||||
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
|
||||
auto item = summary.find(status);
|
||||
if (item != summary.end()) {
|
||||
summary[status]++;
|
||||
} else {
|
||||
summary[status] = 1;
|
||||
if (compare) {
|
||||
double best = bestResult(dataset, data["model"].get<string>());
|
||||
if (result == best) {
|
||||
status = Symbols::equal_best;
|
||||
} else if (result > best) {
|
||||
status = Symbols::better_best;
|
||||
}
|
||||
} else {
|
||||
if (data["score_name"].get<string>() == "accuracy") {
|
||||
auto dt = Datasets(Paths::datasets(), false);
|
||||
dt.loadDataset(dataset);
|
||||
auto numClasses = dt.getNClasses(dataset);
|
||||
if (numClasses == 2) {
|
||||
vector<int> distribution = dt.getClassesCounts(dataset);
|
||||
double nSamples = dt.getNSamples(dataset);
|
||||
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
|
||||
double mark = *maxValue / nSamples * (1 + margin);
|
||||
if (mark > 1) {
|
||||
mark = 0.9995;
|
||||
}
|
||||
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (status != " ") {
|
||||
auto item = summary.find(status);
|
||||
if (item != summary.end()) {
|
||||
summary[status]++;
|
||||
} else {
|
||||
summary[status] = 1;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
double ReportBase::bestResult(const string& dataset, const string& model)
|
||||
{
|
||||
double value = 0.0;
|
||||
if (bestResults.size() == 0) {
|
||||
// try to load the best results
|
||||
string score = data["score_name"];
|
||||
replace(score.begin(), score.end(), '_', '-');
|
||||
string fileName = "best_results_" + score + "_" + model + ".json";
|
||||
ifstream resultData(Paths::results() + "/" + fileName);
|
||||
if (resultData.is_open()) {
|
||||
bestResults = json::parse(resultData);
|
||||
}
|
||||
}
|
||||
try {
|
||||
value = bestResults.at(dataset).at(0);
|
||||
}
|
||||
catch (exception) {
|
||||
value = 1.0;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
@ -21,7 +21,7 @@ namespace platform {
|
||||
};
|
||||
class ReportBase {
|
||||
public:
|
||||
explicit ReportBase(json data_);
|
||||
explicit ReportBase(json data_, bool compare);
|
||||
virtual ~ReportBase() = default;
|
||||
void show();
|
||||
protected:
|
||||
@ -35,6 +35,10 @@ namespace platform {
|
||||
map<string, int> summary;
|
||||
double margin;
|
||||
map<string, string> meaning;
|
||||
private:
|
||||
double bestResult(const string& dataset, const string& model);
|
||||
bool compare;
|
||||
json bestResults;
|
||||
};
|
||||
};
|
||||
#endif
|
@ -10,7 +10,7 @@ namespace platform {
|
||||
const int MAXL = 133;
|
||||
class ReportConsole : public ReportBase {
|
||||
public:
|
||||
explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {};
|
||||
explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {};
|
||||
virtual ~ReportConsole() = default;
|
||||
private:
|
||||
int selectedIndex;
|
||||
|
@ -13,7 +13,7 @@ namespace platform {
|
||||
string do_grouping() const { return "\03"; }
|
||||
};
|
||||
|
||||
ReportExcel::ReportExcel(json data_, lxw_workbook* workbook) : ReportBase(data_), row(0), workbook(workbook)
|
||||
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook)
|
||||
{
|
||||
normalSize = 14; //font size for report body
|
||||
colorTitle = 0xB1A0C7;
|
||||
|
@ -10,7 +10,7 @@ namespace platform {
|
||||
|
||||
class ReportExcel : public ReportBase {
|
||||
public:
|
||||
explicit ReportExcel(json data_, lxw_workbook* workbook);
|
||||
explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
|
||||
lxw_workbook* getWorkbook();
|
||||
private:
|
||||
void writeString(int row, int col, const string& text, const string& style = "");
|
||||
|
@ -109,12 +109,12 @@ namespace platform {
|
||||
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
|
||||
auto data = files.at(index).load();
|
||||
if (excelReport) {
|
||||
ReportExcel reporter(data, workbook);
|
||||
ReportExcel reporter(data, compare, workbook);
|
||||
reporter.show();
|
||||
openExcel = true;
|
||||
workbook = reporter.getWorkbook();
|
||||
} else {
|
||||
ReportConsole reporter(data);
|
||||
ReportConsole reporter(data, compare);
|
||||
reporter.show();
|
||||
}
|
||||
}
|
||||
@ -150,6 +150,7 @@ namespace platform {
|
||||
if (all_of(line.begin(), line.end(), ::isdigit)) {
|
||||
int idx = stoi(line);
|
||||
if (indexList) {
|
||||
// The value is about the files list
|
||||
index = idx;
|
||||
if (index >= 0 && index < files.size()) {
|
||||
report(index, false);
|
||||
@ -157,6 +158,7 @@ namespace platform {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// The value is about the result showed on screen
|
||||
showIndex(index, idx);
|
||||
continue;
|
||||
}
|
||||
|
@ -35,7 +35,11 @@ namespace platform {
|
||||
};
|
||||
class Results {
|
||||
public:
|
||||
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); };
|
||||
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
|
||||
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
|
||||
{
|
||||
load();
|
||||
};
|
||||
void manage();
|
||||
private:
|
||||
string path;
|
||||
@ -46,6 +50,7 @@ namespace platform {
|
||||
bool partial;
|
||||
bool indexList = true;
|
||||
bool openExcel = false;
|
||||
bool compare;
|
||||
lxw_workbook* workbook = NULL;
|
||||
vector<Result> files;
|
||||
void load(); // Loads the list of results
|
||||
|
@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||
program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
|
||||
program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
|
||||
program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
auto number = program.get<int>("number");
|
||||
@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
auto score = program.get<string>("score");
|
||||
auto complete = program.get<bool>("complete");
|
||||
auto partial = program.get<bool>("partial");
|
||||
auto compare = program.get<bool>("compare");
|
||||
}
|
||||
catch (const exception& err) {
|
||||
cerr << err.what() << endl;
|
||||
@ -41,9 +43,10 @@ int main(int argc, char** argv)
|
||||
auto score = program.get<string>("score");
|
||||
auto complete = program.get<bool>("complete");
|
||||
auto partial = program.get<bool>("partial");
|
||||
auto compare = program.get<bool>("compare");
|
||||
if (complete)
|
||||
partial = false;
|
||||
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial);
|
||||
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
|
||||
results.manage();
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user