From 594adb053421783f25801fa7d1d4890126fd28ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 18 May 2024 21:37:34 +0200 Subject: [PATCH] Begin classification report in excel --- src/CMakeLists.txt | 2 +- src/main/Scores.cpp | 63 ++++++++++++++++++++----------- src/main/Scores.h | 7 ++-- src/reports/ReportExcel.cpp | 75 +++++++++++++++++++++++++++++++------ src/reports/ReportExcel.h | 3 ++ 5 files changed, 113 insertions(+), 37 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 47de3a3..6313e48 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,7 +22,7 @@ add_executable( b_best commands/b_best.cpp best/Statistics.cpp best/BestResultsExcel.cpp best/BestResults.cpp common/Datasets.cpp common/Dataset.cpp - main/Models.cpp + main/Models.cpp main/Scores.cpp reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp results/Result.cpp ) diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 6014504..9a4935d 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -16,7 +16,7 @@ namespace platform { confusion_matrix[actual][predicted] += 1; } } - Scores::Scores(json& confusion_matrix_) + Scores::Scores(const json& confusion_matrix_) { json values; total = 0; @@ -40,7 +40,7 @@ namespace platform { } compute_accuracy_value(); } - Scores Scores::create_aggregate(json& data, std::string key) + Scores Scores::create_aggregate(const json& data, const std::string key) { auto scores = Scores(data[key][0]); for (int i = 1; i < data[key].size(); i++) { @@ -138,6 +138,25 @@ namespace platform { << std::setw(dlen) << std::right << support; return oss.str(); } + std::tuple Scores::compute_averages() + { + float precision_avg = 0; + float recall_avg = 0; + float precision_wavg = 0; + float recall_wavg = 0; + for (int i = 0; i < num_classes; i++) { + int support = confusion_matrix[i].sum().item(); + precision_avg += precision(i); + precision_wavg += precision(i) * support; + recall_avg += recall(i); + recall_wavg += recall(i) * support; + } + precision_wavg /= total; + recall_wavg /= total; + precision_avg /= num_classes; + recall_avg /= num_classes; + return { precision_avg, recall_avg, precision_wavg, recall_wavg }; + } std::vector Scores::classification_report(std::string color, std::string title) { std::stringstream oss; @@ -157,21 +176,7 @@ namespace platform { report.push_back(" "); oss << classification_report_line("accuracy", 0, 0, accuracy(), total); report.push_back(oss.str()); oss.str(""); - float precision_avg = 0; - float recall_avg = 0; - float precision_wavg = 0; - float recall_wavg = 0; - for (int i = 0; i < num_classes; i++) { - int support = confusion_matrix[i].sum().item(); - precision_avg += precision(i); - precision_wavg += precision(i) * support; - recall_avg += recall(i); - recall_wavg += recall(i) * support; - } - precision_wavg /= total; - recall_wavg /= total; - precision_avg /= num_classes; - recall_avg /= num_classes; + auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages(); report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total)); report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total)); report.push_back(""); @@ -189,17 +194,33 @@ namespace platform { } return report; } + json Scores::classification_report_json(std::string title) + { + json output; + output["title"] = "Classification Report using " + title + " dataset"; + output["headers"] = { " ", "precision", "recall", "f1-score", "support" }; + output["body"] = {}; + for (int i = 0; i < num_classes; i++) { + output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item() }); + } + output["accuracy"] = { "accuracy", 0, 0, accuracy(), total }; + auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages(); + output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total }; + output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total }; + output["confusion_matrix"] = get_confusion_matrix_json(); + return output; + } json Scores::get_confusion_matrix_json(bool labels_as_keys) { - json j; + json output; for (int i = 0; i < num_classes; i++) { auto r_ptr = confusion_matrix[i].data_ptr(); if (labels_as_keys) { - j[labels[i]] = std::vector(r_ptr, r_ptr + num_classes); + output[labels[i]] = std::vector(r_ptr, r_ptr + num_classes); } else { - j[i] = std::vector(r_ptr, r_ptr + num_classes); + output[i] = std::vector(r_ptr, r_ptr + num_classes); } } - return j; + return output; } } \ No newline at end of file diff --git a/src/main/Scores.h b/src/main/Scores.h index c6fff97..d0f32f0 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -4,15 +4,14 @@ #include #include #include -#include namespace platform { using json = nlohmann::ordered_json; class Scores { public: Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels = {}); - explicit Scores(json& confusion_matrix_); - static Scores create_aggregate(json& data, std::string key); + explicit Scores(const json& confusion_matrix_); + static Scores create_aggregate(const json& data, const std::string key); float accuracy(); float f1_score(int num_class); float f1_weighted(); @@ -21,6 +20,7 @@ namespace platform { float recall(int num_class); torch::Tensor get_confusion_matrix() { return confusion_matrix; } std::vector classification_report(std::string color = "", std::string title = ""); + json classification_report_json(std::string title = ""); json get_confusion_matrix_json(bool labels_as_keys = false); void aggregate(const Scores& a); private: @@ -28,6 +28,7 @@ namespace platform { void init_confusion_matrix(); void init_default_labels(); void compute_accuracy_value(); + std::tuple compute_averages(); int num_classes; float accuracy_value; int total; diff --git a/src/reports/ReportExcel.cpp b/src/reports/ReportExcel.cpp index 77dd2de..ffcfafb 100644 --- a/src/reports/ReportExcel.cpp +++ b/src/reports/ReportExcel.cpp @@ -195,27 +195,78 @@ namespace platform { } // Classificacion report if (lastResult.find("confusion_matrices") != lastResult.end()) { - // auto score = platform2::Scores::create_aggregate(lastResult, "confusion_matrices"); - // row++; - // writeString(row, 1, "Classification Report", "bodyHeader"); - // row++; - // auto output = platform2::Scores::classification_report("", "test"); - // for (const auto& item : output) { - // writeString(row, 1, item, "text"); - // row++; - // } + create_classification_report(lastResult); } // Set with of columns to show those totals completely - worksheet_set_column(worksheet, 1, 1, 12, NULL); - for (int i = 2; i < 7; ++i) { + for (int i = 0; i < 5; ++i) { // doesn't work with from col to col, so... - worksheet_set_column(worksheet, i, i, 15, NULL); + worksheet_set_column(worksheet, i, i, 12, NULL); } + worksheet_set_column(worksheet, 5, 5, 7, NULL); } else { footer(totalScore, row); } } + void ReportExcel::create_classification_report(const json& result) + { + auto matrix_sheet = workbook_add_worksheet(workbook, "classif_report"); + lxw_worksheet* tmp = worksheet; + worksheet = matrix_sheet; + if (matrix_sheet == NULL) { + throw std::invalid_argument("Couldn't create sheet classif_report"); + } + worksheet_merge_range(matrix_sheet, 0, 0, 0, 5, "Classification Report", efectiveStyle("bodyHeader")); + int row = 3; + if (result.find("confusion_matrices_train") != result.end()) { + auto score = Scores::create_aggregate(result, "confusion_matrices_train"); + auto train = score.classification_report_json("Train"); + row = write_classification_report(train, row); + } + auto score = Scores::create_aggregate(result, "confusion_matrices"); + auto test = score.classification_report_json("Test"); + write_classification_report(test, ++row); + for (int i = 1; i < 6; ++i) { + // doesn't work with from col to col, so... + worksheet_set_column(worksheet, i, i, 15, NULL); + } + worksheet = tmp; + } + int ReportExcel::write_classification_report(const json& result, int row) + { + auto text = result["title"].get().c_str(); + std::cout << "title: " << text << std::endl; + worksheet_merge_range(worksheet, row, 0, row, 5, text, efectiveStyle("bodyHeader")); + int col = 2; + row++; + bool first_item = true; + for (const auto& item : result["headers"]) { + auto text = item.get().c_str(); + if (first_item) { + first_item = false; + worksheet_merge_range(worksheet, row, 0, row, 1, text, efectiveStyle("bodyHeader")); + } else { + writeString(row, col++, text, "bodyHeader"); + } + } + row++; + for (const auto& item : result["body"]) { + col = 2; + for (const auto& value : item) { + if (value.is_string()) { + worksheet_merge_range(worksheet, row, 0, row, 1, value.get().c_str(), efectiveStyle("text")); + } else { + if (value.is_number_integer()) { + writeInt(row, col++, value.get(), "result"); + } else { + writeDouble(row, col++, value.get(), "result"); + } + } + row++; + } + } + return row; + } void ReportExcel::showSummary() { for (const auto& item : summary) { diff --git a/src/reports/ReportExcel.h b/src/reports/ReportExcel.h index e2755a2..02e90e8 100644 --- a/src/reports/ReportExcel.h +++ b/src/reports/ReportExcel.h @@ -1,5 +1,6 @@ #ifndef REPORT_EXCEL_H #define REPORT_EXCEL_H +#include "main/Scores.h" #include "common/Colors.h" #include "ReportBase.h" #include "ExcelFile.h" @@ -19,6 +20,8 @@ namespace platform { void showSummary() override; void footer(double totalScore, int row); void append_notes(const json& r, int row); + void create_classification_report(const json& result); + int write_classification_report(const json& result, int row); void header_notes(int row); }; };