From 5c190d7c66506370dc08d199d9becbe191b57c11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 14 May 2024 11:45:54 +0200 Subject: [PATCH] Add train classification report --- src/main/Experiment.cpp | 11 +++++-- src/main/PartialResult.h | 1 + src/main/Scores.cpp | 34 +++++++++++--------- src/main/Scores.h | 4 +-- src/reports/ReportConsole.cpp | 60 +++++++++++++++++++++++++++++------ src/reports/ReportConsole.h | 3 ++ 6 files changed, 85 insertions(+), 28 deletions(-) diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 07e095e..8657c12 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -102,6 +102,7 @@ namespace platform { auto edges = torch::zeros({ nResults }, torch::kFloat64); auto num_states = torch::zeros({ nResults }, torch::kFloat64); json confusion_matrices = json::array(); + json confusion_matrices_train = json::array(); std::vector notes; Timer train_timer, test_timer; int item = 0; @@ -150,8 +151,12 @@ namespace platform { train_time[item] = train_timer.getDuration(); double accuracy_train_value = 0.0; // Score train - if (!no_train_score) - accuracy_train_value = clf->score(X_train, y_train); + if (!no_train_score) { + auto y_predict = clf->predict(X_train); + Scores scores(y_train, y_predict, states[className].size(), labels); + accuracy_train_value = scores.accuracy(); + confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); + } // Test model if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); @@ -183,6 +188,8 @@ namespace platform { partial_result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); partial_result.setDataset(fileName).setNotes(notes); partial_result.setConfusionMatrices(confusion_matrices); + if (!no_train_score) + partial_result.setConfusionMatricesTrain(confusion_matrices_train); addResult(partial_result); } } \ No newline at end of file diff --git a/src/main/PartialResult.h b/src/main/PartialResult.h index f5f0b2c..c141046 100644 --- a/src/main/PartialResult.h +++ b/src/main/PartialResult.h @@ -28,6 +28,7 @@ namespace platform { return *this; } PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; } + PartialResult& setConfusionMatricesTrain(const json& confusion_matrices) { data["confusion_matrices_train"] = confusion_matrices; return *this; } PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; } PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; } PartialResult& setFeatures(int features) { data["features"] = features; return *this; } diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 737af29..ebf0914 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -126,24 +126,28 @@ namespace platform { oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " "; } oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " " - << std::setw(dlen) << std::right << support << std::endl; + << std::setw(dlen) << std::right << support; return oss.str(); } - std::string Scores::classification_report(std::string color) + std::vector Scores::classification_report(std::string color, std::string title) { std::stringstream oss; + std::vector report; for (int i = 0; i < num_classes; i++) { label_len = std::max(label_len, (int)labels[i].size()); } - oss << Colors::GREEN() << "Classification Report" << std::endl; - oss << "=====================" << std::endl << color; - oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl; - oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl; + report.push_back("Classification Report using " + title + " dataset"); + report.push_back("========================================="); + oss << std::string(label_len, ' ') << " precision recall f1-score support"; + report.push_back(oss.str()); oss.str(""); + oss << std::string(label_len, ' ') << " ========= ========= ========= ========="; + report.push_back(oss.str()); oss.str(""); for (int i = 0; i < num_classes; i++) { - oss << classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item()); + report.push_back(classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item())); } - oss << std::endl; + report.push_back(" "); oss << classification_report_line("accuracy", 0, 0, accuracy(), total); + report.push_back(oss.str()); oss.str(""); float precision_avg = 0; float recall_avg = 0; float precision_wavg = 0; @@ -159,10 +163,11 @@ namespace platform { recall_wavg /= total; precision_avg /= num_classes; recall_avg /= num_classes; - oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total); - oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total); - oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl; - oss << "================" << std::endl << color; + report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total)); + report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total)); + report.push_back(""); + report.push_back("Confusion Matrix"); + report.push_back("================"); auto number = total > 1000 ? 4 : 3; for (int i = 0; i < num_classes; i++) { oss << std::right << std::setw(label_len) << labels[i] << " "; @@ -171,10 +176,9 @@ namespace platform { oss << std::setw(number) << confusion_matrix[i][j].item() << " "; if (i == j) oss << color; } - oss << std::endl; + report.push_back(oss.str()); oss.str(""); } - oss << Colors::RESET(); - return oss.str(); + return report; } json Scores::get_confusion_matrix_json(bool labels_as_keys) { diff --git a/src/main/Scores.h b/src/main/Scores.h index 1f5bb8f..3053d4f 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -17,7 +17,7 @@ namespace platform { float precision(int num_class); float recall(int num_class); torch::Tensor get_confusion_matrix() { return confusion_matrix; } - std::string classification_report(std::string color = ""); + std::vector classification_report(std::string color = "", std::string title = ""); json get_confusion_matrix_json(bool labels_as_keys = false); void aggregate(const Scores& a); private: @@ -30,7 +30,7 @@ namespace platform { int total; std::vector labels; torch::Tensor confusion_matrix; // Rows ar actual, columns are predicted - int label_len = 12; + int label_len = 16; int dlen = 9; int ndec = 7; }; diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index a44be15..e903a53 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -3,7 +3,6 @@ #include "best/BestScore.h" #include "common/CLocale.h" #include "ReportConsole.h" -#include "main/Scores.h" namespace platform { std::string ReportConsole::headerLine(const std::string& text, int utf = 0) @@ -135,7 +134,7 @@ namespace platform { } sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl; vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n"); - if (lastResult.find("confusion_matrices") != lastResult.end() && (data["results"].size() == 1 || selectedIndex != -1)) { + if (data["results"].size() == 1 || selectedIndex != -1) { vbody.push_back(showClassificationReport(Colors::BLUE())); } } @@ -169,17 +168,60 @@ namespace platform { std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!"); } } - std::string ReportConsole::showClassificationReport(std::string color) + Scores ReportConsole::aggregateScore(std::string key) { auto lastResult = data["results"][0]; - if (data["results"].size() > 1 || lastResult.find("confusion_matrices") == lastResult.end()) - return ""; auto item = data["results"][0]; - auto scores = Scores(item["confusion_matrices"][0]); - for (int i = 1; i < item["confusion_matrices"].size(); i++) { - auto score = Scores(item["confusion_matrices"][i]); + auto scores = Scores(item[key][0]); + for (int i = 1; i < item[key].size(); i++) { + auto score = Scores(item[key][i]); scores.aggregate(score); } - return scores.classification_report(color); + return scores; + } + std::string ReportConsole::showClassificationReport(std::string color) + { + std::stringstream oss; + auto result = data["results"][0]; + if (result.find("confusion_matrices") == result.end()) + return ""; + auto scores = aggregateScore("confusion_matrices"); + auto output_test = scores.classification_report(color, "Test"); + oss << Colors::BLUE(); + if (result.find("confusion_matrices_train") == result.end()) { + for (auto& line : output_test) { + + oss << line << std::endl; + } + oss << Colors::RESET(); + return oss.str(); + } + auto scores_train = aggregateScore("confusion_matrices_train"); + auto output_train = scores_train.classification_report(color, "Train"); + int maxLine = (*std::max_element(output_train.begin(), output_train.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); + bool second_header = false; + int lines_header = 0; + std::string color_line; + std::string suffix = ""; + for (int i = 0; i < output_train.size(); i++) { + if (i < 2 || second_header) { + color_line = Colors::GREEN(); + } else { + color_line = Colors::BLUE(); + if (lines_header > 1) + suffix = std::string(14, ' '); // compensate for the color + } + oss << color_line << std::left << std::setw(maxLine) << output_train[i] + << suffix << Colors::BLUE() << " | " << color_line << std::left << std::setw(maxLine) + << output_test[i] << std::endl; + if (output_train[i] == "" || (second_header && lines_header < 2)) { + lines_header++; + second_header = true; + } else { + second_header = false; + } + } + oss << Colors::RESET(); + return oss.str(); } } \ No newline at end of file diff --git a/src/reports/ReportConsole.h b/src/reports/ReportConsole.h index a332e31..a7ec851 100644 --- a/src/reports/ReportConsole.h +++ b/src/reports/ReportConsole.h @@ -4,6 +4,8 @@ #include "common/Colors.h" #include #include "ReportBase.h" +#include "main/Scores.h" + namespace platform { const int MAXL = 133; @@ -24,6 +26,7 @@ namespace platform { void do_body(); void footer(double totalScore); void showSummary() override; + Scores aggregateScore(std::string key); std::stringstream sheader; std::stringstream sbody; std::vector vbody;