Complete comparison with ZeroR

This commit is contained in:
2023-09-19 17:55:03 +02:00
parent 1bdfbd1620
commit f69f415b92
4 changed files with 99 additions and 15 deletions

View File

@@ -1,5 +1,6 @@
#include <sstream>
#include <locale>
#include "Datasets.h"
#include "ReportExcel.h"
#include "BestResult.h"
@@ -12,6 +13,15 @@ namespace platform {
string do_grouping() const { return "\03"; }
};
ReportExcel::ReportExcel(json data_) : ReportBase(data_), row(0)
{
normalSize = 14; //font size for report body
colorTitle = 0xB1A0C7;
colorOdd = 0xDCE6F1;
colorEven = 0xFDE9D9;
margin = .1; // margin to add to ZeroR comparison
createFile();
}
lxw_format* ReportExcel::efectiveStyle(const string& style)
{
@@ -41,7 +51,7 @@ namespace platform {
void ReportExcel::formatColumns()
{
worksheet_freeze_panes(worksheet, 6, 1);
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 14, 12, 50 };
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
for (int i = 0; i < columns_sizes.size(); ++i) {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
}
@@ -65,9 +75,9 @@ namespace platform {
} else if (name == "bodyHeader") {
format_set_bold(style);
format_set_font_size(style, normalSize);
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
format_set_align(style, LXW_ALIGN_CENTER);
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
format_set_border(style, LXW_BORDER_THIN);
format_set_bg_color(style, lxw_color_t(colorTitle));
} else if (name == "result") {
format_set_font_size(style, normalSize);
@@ -129,9 +139,17 @@ namespace platform {
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
// Summary style
lxw_format* summaryStyle = workbook_add_format(workbook);
format_set_bold(summaryStyle);
format_set_font_size(summaryStyle, 16);
format_set_border(summaryStyle, LXW_BORDER_THIN);
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
styles["headerFirst"] = headerFirst;
styles["headerRest"] = headerRest;
styles["headerSmall"] = headerSmall;
styles["summaryStyle"] = summaryStyle;
}
void ReportExcel::setProperties()
@@ -173,7 +191,7 @@ namespace platform {
locale::global(mylocale);
cout.imbue(mylocale);
stringstream oss;
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() +
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
@@ -211,6 +229,7 @@ namespace platform {
}
row = 6;
col = 0;
int hypSize = 22;
json lastResult;
double totalScore = 0.0;
string hyperparameters;
@@ -224,7 +243,7 @@ namespace platform {
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
writeDouble(row, col + 7, r["score"].get<double>(), "result");
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
const string status = "X";
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
writeString(row, col + 9, status, "textCentered");
writeDouble(row, col + 10, r["time"].get<double>(), "time");
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
@@ -236,11 +255,18 @@ namespace platform {
oss << r["hyperparameters"];
hyperparameters = oss.str();
}
if (hyperparameters.size() > hypSize) {
hypSize = hyperparameters.size();
}
writeString(row, col + 12, hyperparameters, "text");
lastResult = r;
totalScore += r["score"].get<double>();
row++;
}
// Set the right column width of hyperparameters with the maximum length
worksheet_set_column(worksheet, 12, 12, hypSize + 1, NULL);
// Show totals if only one dataset is present in the result
if (data["results"].size() == 1) {
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
row++;
@@ -254,9 +280,52 @@ namespace platform {
footer(totalScore, row);
}
}
string ReportExcel::compareResult(const string& dataset, double result)
{
string status = " ";
if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false);
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {
vector<int> distribution = dt.getClassesCounts(dataset);
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
int maxCategory = distance(distribution.begin(), maxValue);
double mark = maxCategory * (1 + margin);
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
auto item = summary.find(status);
if (item != summary.end()) {
summary[status]++;
} else {
summary[status] = 1;
}
}
}
return status;
}
void ReportExcel::showSummary()
{
stringstream oss;
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
map<string, string> meaning = {
{Symbols::equal_best, "Equal to best"},
{Symbols::better_best, "Better than best"},
{Symbols::cross, "Less than or equal to ZeroR"},
{Symbols::upward_arrow, oss.str()}
};
for (const auto& item : summary) {
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
row += 1;
}
}
void ReportExcel::footer(double totalScore, int row)
{
showSummary();
row += 2 + summary.size();
auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) {
worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]);