From 8d20545fd211ad1df0246c4b54e05e73582b982a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 13 May 2024 10:40:25 +0200 Subject: [PATCH] Git add Confusion Matrix to console report --- lib/Files/ArffFiles.cc | 8 +++++++- src/main/Scores.cpp | 10 ++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/Files/ArffFiles.cc b/lib/Files/ArffFiles.cc index 1460bfa..826bd86 100644 --- a/lib/Files/ArffFiles.cc +++ b/lib/Files/ArffFiles.cc @@ -2,6 +2,8 @@ #include #include #include +#include // std::isdigit +#include // std::all_of #include ArffFiles::ArffFiles() = default; @@ -162,7 +164,11 @@ std::vector ArffFiles::factorize(const std::vector& labels_t) for (const std::string& label : labels_t) { if (labelMap.find(label) == labelMap.end()) { labelMap[label] = i++; - labels.push_back(label); + bool allDigits = std::all_of(label.begin(), label.end(), isdigit); + if (allDigits) + labels.push_back("Class " + label); + else + labels.push_back(label); } yy.push_back(labelMap[label]); } diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index e1fb9a4..72095be 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -160,6 +160,16 @@ namespace platform { recall_avg /= num_classes; oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total); oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total); + oss << std::endl << "Confusion Matrix" << std::endl; + oss << "================" << std::endl; + auto number = total > 1000 ? 4 : 3; + for (int i = 0; i < num_classes; i++) { + oss << std::right << std::setw(label_len) << labels[i] << " "; + for (int j = 0; j < num_classes; j++) { + oss << std::setw(number) << confusion_matrix[i][j].item() << " "; + } + oss << std::endl; + } return oss.str(); } json Scores::get_confusion_matrix_json(bool labels_as_keys)