From ad5c3319bd35eece5a90966ee261c8f5a9e983b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 18 May 2024 22:59:37 +0200 Subject: [PATCH] Complete excel classification report --- src/reports/ReportExcel.cpp | 107 ++++++++++++++++++++++++++++-------- src/reports/ReportExcel.h | 3 +- 2 files changed, 86 insertions(+), 24 deletions(-) diff --git a/src/reports/ReportExcel.cpp b/src/reports/ReportExcel.cpp index ffcfafb..9205d65 100644 --- a/src/reports/ReportExcel.cpp +++ b/src/reports/ReportExcel.cpp @@ -198,11 +198,11 @@ namespace platform { create_classification_report(lastResult); } // Set with of columns to show those totals completely - for (int i = 0; i < 5; ++i) { + worksheet_set_column(worksheet, 1, 1, 12, NULL); + for (int i = 2; i < 7; ++i) { // doesn't work with from col to col, so... - worksheet_set_column(worksheet, i, i, 12, NULL); + worksheet_set_column(worksheet, i, i, 15, NULL); } - worksheet_set_column(worksheet, 5, 5, 7, NULL); } else { footer(totalScore, row); } @@ -216,56 +216,118 @@ namespace platform { throw std::invalid_argument("Couldn't create sheet classif_report"); } worksheet_merge_range(matrix_sheet, 0, 0, 0, 5, "Classification Report", efectiveStyle("bodyHeader")); - int row = 3; + int row = 2; + int col = 0; if (result.find("confusion_matrices_train") != result.end()) { + // Train classification report auto score = Scores::create_aggregate(result, "confusion_matrices_train"); auto train = score.classification_report_json("Train"); - row = write_classification_report(train, row); + std::tie(row, col) = write_classification_report(train, row, 0); + int new_row = 0; + int new_col = col + 1; + for (int i = 0; i < result["confusion_matrices_train"].size(); ++i) { + auto item = result["confusion_matrices_train"][i]; + auto score_item = Scores(item); + auto title = "Train Fold " + std::to_string(i); + std::tie(new_row, new_col) = write_classification_report(score_item.classification_report_json(title), 2, new_col); + new_col++; + } } + // Test classification report auto score = Scores::create_aggregate(result, "confusion_matrices"); auto test = score.classification_report_json("Test"); - write_classification_report(test, ++row); - for (int i = 1; i < 6; ++i) { + int init_row = ++row; + std::tie(row, col) = write_classification_report(test, init_row, 0); + int new_row = 0; + int new_col = col + 1; + for (int i = 0; i < result["confusion_matrices"].size(); ++i) { + auto item = result["confusion_matrices"][i]; + auto score_item = Scores(item); + auto title = "Test Fold " + std::to_string(i); + std::tie(new_row, new_col) = write_classification_report(score_item.classification_report_json(title), init_row, new_col); + new_col++; + } + // Format columns (change size to fit the content) + for (int i = 0; i < new_col; ++i) { // doesn't work with from col to col, so... - worksheet_set_column(worksheet, i, i, 15, NULL); + worksheet_set_column(worksheet, i, i, 12, NULL); } worksheet = tmp; } - int ReportExcel::write_classification_report(const json& result, int row) + std::pair ReportExcel::write_classification_report(const json& result, int init_row, int init_col) { - auto text = result["title"].get().c_str(); - std::cout << "title: " << text << std::endl; - worksheet_merge_range(worksheet, row, 0, row, 5, text, efectiveStyle("bodyHeader")); - int col = 2; - row++; + int row = init_row; + auto text = result["title"].get(); + worksheet_merge_range(worksheet, row++, init_col, row, init_col + 5, text.c_str(), efectiveStyle("bodyHeader")); + int col = init_col + 2; + // Headers bool first_item = true; for (const auto& item : result["headers"]) { - auto text = item.get().c_str(); + auto text = item.get(); if (first_item) { first_item = false; - worksheet_merge_range(worksheet, row, 0, row, 1, text, efectiveStyle("bodyHeader")); + worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, text.c_str(), efectiveStyle("bodyHeader")); } else { writeString(row, col++, text, "bodyHeader"); } } row++; + // Classes f1-score for (const auto& item : result["body"]) { - col = 2; + col = init_col + 2; for (const auto& value : item) { if (value.is_string()) { - worksheet_merge_range(worksheet, row, 0, row, 1, value.get().c_str(), efectiveStyle("text")); + worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, value.get().c_str(), efectiveStyle("text")); } else { if (value.is_number_integer()) { - writeInt(row, col++, value.get(), "result"); + writeInt(row, col++, value.get(), "ints"); } else { writeDouble(row, col++, value.get(), "result"); } } - row++; } + row++; } - return row; - + worksheet_merge_range(worksheet, row, init_col, row, init_col + 5, "", efectiveStyle("text")); + row++; + // Accuracy and average f1-score + for (const auto& item : { "accuracy", "averages", "weighted" }) { + col = init_col + 2; + for (const auto& value : result[item]) { + if (value.is_string()) { + worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, value.get().c_str(), efectiveStyle("text")); + } else { + if (value.is_number_integer()) { + writeInt(row, col++, value.get(), "ints"); + } else { + writeDouble(row, col++, value.get(), "result"); + } + } + } + row++; + } + // Confusion matrix + worksheet_merge_range(worksheet, row, init_col, row, init_col + 5, "", efectiveStyle("bodyHeader")); + row++; + auto n_items = result["confusion_matrix"].size(); + worksheet_merge_range(worksheet, row, init_col, row, init_col + n_items + 1, "Confusion Matrix", efectiveStyle("bodyHeader")); + row++; + for (int i = 0; i < n_items; ++i) { + col = init_col + 2; + auto label = result["body"][i][0].get(); + worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, label.c_str(), efectiveStyle("text")); + for (int j = 0; j < result["confusion_matrix"][i].size(); ++j) { + auto value = result["confusion_matrix"][i][j]; + if (i == j) { + writeInt(row, col++, value.get(), "ints_bold"); + } else { + writeInt(row, col++, value.get(), "ints"); + } + } + row++; + } + int maxcol = std::max(5, int(init_col + n_items + 1)); + return { row, maxcol }; } void ReportExcel::showSummary() { @@ -276,7 +338,6 @@ namespace platform { row += 1; } } - void ReportExcel::footer(double totalScore, int row) { showSummary(); diff --git a/src/reports/ReportExcel.h b/src/reports/ReportExcel.h index 02e90e8..740ac69 100644 --- a/src/reports/ReportExcel.h +++ b/src/reports/ReportExcel.h @@ -1,5 +1,6 @@ #ifndef REPORT_EXCEL_H #define REPORT_EXCEL_H +#include #include "main/Scores.h" #include "common/Colors.h" #include "ReportBase.h" @@ -21,7 +22,7 @@ namespace platform { void footer(double totalScore, int row); void append_notes(const json& r, int row); void create_classification_report(const json& result); - int write_classification_report(const json& result, int row); + std::pair write_classification_report(const json& result, int init_row, int init_col); void header_notes(int row); }; };