From f69f415b92f8807fad5502f9aa7089352891899b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Tue, 19 Sep 2023 17:55:03 +0200 Subject: [PATCH] Complete comparison with ZeroR --- src/Platform/CMakeLists.txt | 6 +-- src/Platform/ReportExcel.cc | 77 +++++++++++++++++++++++++++++++++++-- src/Platform/ReportExcel.h | 27 ++++++++++--- src/Platform/main.cc | 4 +- 4 files changed, 99 insertions(+), 15 deletions(-) diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index b885792..2a506b8 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) -add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux") - target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs) - target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so) + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs) else() + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp) endif() target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index 37f9cb5..acb2eaa 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -1,5 +1,6 @@ #include #include +#include "Datasets.h" #include "ReportExcel.h" #include "BestResult.h" @@ -12,6 +13,15 @@ namespace platform { string do_grouping() const { return "\03"; } }; + ReportExcel::ReportExcel(json data_) : ReportBase(data_), row(0) + { + normalSize = 14; //font size for report body + colorTitle = 0xB1A0C7; + colorOdd = 0xDCE6F1; + colorEven = 0xFDE9D9; + margin = .1; // margin to add to ZeroR comparison + createFile(); + } lxw_format* ReportExcel::efectiveStyle(const string& style) { @@ -41,7 +51,7 @@ namespace platform { void ReportExcel::formatColumns() { worksheet_freeze_panes(worksheet, 6, 1); - vector columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 14, 12, 50 }; + vector columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 }; for (int i = 0; i < columns_sizes.size(); ++i) { worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL); } @@ -65,9 +75,9 @@ namespace platform { } else if (name == "bodyHeader") { format_set_bold(style); format_set_font_size(style, normalSize); - format_set_align(style, LXW_ALIGN_VERTICAL_CENTER); format_set_align(style, LXW_ALIGN_CENTER); format_set_align(style, LXW_ALIGN_VERTICAL_CENTER); + format_set_border(style, LXW_BORDER_THIN); format_set_bg_color(style, lxw_color_t(colorTitle)); } else if (name == "result") { format_set_font_size(style, normalSize); @@ -129,9 +139,17 @@ namespace platform { format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER); format_set_bg_color(headerSmall, lxw_color_t(colorOdd)); + // Summary style + lxw_format* summaryStyle = workbook_add_format(workbook); + format_set_bold(summaryStyle); + format_set_font_size(summaryStyle, 16); + format_set_border(summaryStyle, LXW_BORDER_THIN); + format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER); + styles["headerFirst"] = headerFirst; styles["headerRest"] = headerRest; styles["headerSmall"] = headerSmall; + styles["summaryStyle"] = summaryStyle; } void ReportExcel::setProperties() @@ -173,7 +191,7 @@ namespace platform { locale::global(mylocale); cout.imbue(mylocale); stringstream oss; - string message = data["model"].get() + " ver. " + data["version"].get() + + string message = data["model"].get() + " ver. " + data["version"].get() + " " + data["language"].get() + " ver. " + data["language_version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get() + " " + data["time"].get(); @@ -211,6 +229,7 @@ namespace platform { } row = 6; col = 0; + int hypSize = 22; json lastResult; double totalScore = 0.0; string hyperparameters; @@ -224,7 +243,7 @@ namespace platform { writeDouble(row, col + 6, r["depth"].get(), "floats"); writeDouble(row, col + 7, r["score"].get(), "result"); writeDouble(row, col + 8, r["score_std"].get(), "result"); - const string status = "X"; + const string status = compareResult(r["dataset"].get(), r["score"].get()); writeString(row, col + 9, status, "textCentered"); writeDouble(row, col + 10, r["time"].get(), "time"); writeDouble(row, col + 11, r["time_std"].get(), "time"); @@ -236,11 +255,18 @@ namespace platform { oss << r["hyperparameters"]; hyperparameters = oss.str(); } + if (hyperparameters.size() > hypSize) { + hypSize = hyperparameters.size(); + } writeString(row, col + 12, hyperparameters, "text"); lastResult = r; totalScore += r["score"].get(); row++; + } + // Set the right column width of hyperparameters with the maximum length + worksheet_set_column(worksheet, 12, 12, hypSize + 1, NULL); + // Show totals if only one dataset is present in the result if (data["results"].size() == 1) { for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { row++; @@ -254,9 +280,52 @@ namespace platform { footer(totalScore, row); } } + string ReportExcel::compareResult(const string& dataset, double result) + { + string status = " "; + if (data["score_name"].get() == "accuracy") { + auto dt = Datasets(Paths::datasets(), false); + dt.loadDataset(dataset); + auto numClasses = dt.getNClasses(dataset); + if (numClasses == 2) { + vector distribution = dt.getClassesCounts(dataset); + vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); + int maxCategory = distance(distribution.begin(), maxValue); + double mark = maxCategory * (1 + margin); + status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; + auto item = summary.find(status); + if (item != summary.end()) { + summary[status]++; + } else { + summary[status] = 1; + } + } + } + return status; + } + void ReportExcel::showSummary() + { + stringstream oss; + oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; + + map meaning = { + {Symbols::equal_best, "Equal to best"}, + {Symbols::better_best, "Better than best"}, + {Symbols::cross, "Less than or equal to ZeroR"}, + {Symbols::upward_arrow, oss.str()} + }; + for (const auto& item : summary) { + worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]); + worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]); + worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]); + row += 1; + } + } void ReportExcel::footer(double totalScore, int row) { + showSummary(); + row += 2 + summary.size(); auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]); diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 9dc3aa3..8c86788 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -8,9 +8,20 @@ namespace platform { using namespace std; const int MAXLL = 128; + class Symbols { + public: + inline static const string check_mark{ "\u2714" }; + inline static const string exclamation{ "\u2757" }; + inline static const string black_star{ "\u2605" }; + inline static const string cross{ "\u2717" }; + inline static const string upward_arrow{ "\u27B6" }; + inline static const string down_arrow{ "\u27B4" }; + inline static const string equal_best{ check_mark }; + inline static const string better_best{ black_star }; + }; class ReportExcel : public ReportBase { public: - explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); }; + explicit ReportExcel(json data_); virtual ~ReportExcel() { closeFile(); }; private: void writeString(int row, int col, const string& text, const string& style = ""); @@ -21,14 +32,17 @@ namespace platform { void setProperties(); void createFile(); void closeFile(); + void showSummary(); lxw_workbook* workbook; lxw_worksheet* worksheet; map styles; - int row = 0; - int normalSize = 14; //font size for report body - uint32_t colorTitle = 0xB1A0C7; - uint32_t colorOdd = 0xDCE6F1; - uint32_t colorEven = 0xFDE9D9; + map summary; + int row; + int normalSize; //font size for report body + uint32_t colorTitle; + uint32_t colorOdd; + uint32_t colorEven; + double margin; const string fileName = "some_results.xlsx"; void header() override; void body() override; @@ -36,6 +50,7 @@ namespace platform { void createStyle(const string& name, lxw_format* style, bool odd); void addColor(lxw_format* style, bool odd); lxw_format* efectiveStyle(const string& name); + string compareResult(const string& dataset, double result); }; }; #endif // !REPORTEXCEL_H \ No newline at end of file diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 9f8e00b..a122ad2 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -87,7 +87,7 @@ int main(int argc, char** argv) auto stratified = program.get("stratified"); auto n_folds = program.get("folds"); auto seeds = program.get>("seeds"); - auto hyperparameters =program.get("hyperparameters"); + auto hyperparameters = program.get("hyperparameters"); vector filesToTest; auto datasets = platform::Datasets(path, true, platform::ARFF); auto title = program.get("title"); @@ -102,7 +102,7 @@ int main(int argc, char** argv) } filesToTest.push_back(file_name); } else { - filesToTest = platform::Datasets(path, true, platform::ARFF).getNames(); + filesToTest = datasets.getNames(); saveResults = true; } /*