From ec0268c514590680e6bd5a4c42b200accd164201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 10 May 2024 13:42:38 +0200 Subject: [PATCH] Add confusion matrix to json results Add Aggregate method to Scores --- src/main/Experiment.cpp | 8 +++++++- src/main/PartialResult.h | 1 + src/main/Scores.cpp | 10 +++++++++- src/main/Scores.h | 1 + tests/TestScores.cpp | 43 ++++++++++++++++++++++++++++++++++++++-- 5 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 60a2401..aea2dce 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -2,6 +2,7 @@ #include "reports/ReportConsole.h" #include "common/Paths.h" #include "Models.h" +#include "Scores.h" #include "Experiment.h" namespace platform { using json = nlohmann::json; @@ -96,6 +97,7 @@ namespace platform { auto nodes = torch::zeros({ nResults }, torch::kFloat64); auto edges = torch::zeros({ nResults }, torch::kFloat64); auto num_states = torch::zeros({ nResults }, torch::kFloat64); + json confusion_matrices = json::array(); std::vector notes; Timer train_timer, test_timer; int item = 0; @@ -150,10 +152,13 @@ namespace platform { if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); test_timer.start(); - auto accuracy_test_value = clf->score(X_test, y_test); + auto y_predict = clf->predict(X_test); + Scores scores(y_test, y_predict, states[className].size()); + auto accuracy_test_value = scores.accuracy(); test_time[item] = test_timer.getDuration(); accuracy_train[item] = accuracy_train_value; accuracy_test[item] = accuracy_test_value; + confusion_matrices.push_back(scores.get_confusion_matrix_json()); if (!quiet) std::cout << "\b\b\b, " << flush; // Store results and times in std::vector @@ -173,6 +178,7 @@ namespace platform { partial_result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); partial_result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); partial_result.setDataset(fileName).setNotes(notes); + partial_result.setConfusionMatrices(confusion_matrices); addResult(partial_result); } } \ No newline at end of file diff --git a/src/main/PartialResult.h b/src/main/PartialResult.h index 1b1bb4e..c948b0a 100644 --- a/src/main/PartialResult.h +++ b/src/main/PartialResult.h @@ -27,6 +27,7 @@ namespace platform { data["notes"].insert(data["notes"].end(), notes_.begin(), notes_.end()); return *this; } + PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; } PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; } PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; } PartialResult& setFeatures(int features) { data["features"] = features; return *this; } diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 2c013d1..a735605 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -25,6 +25,15 @@ namespace platform { labels.push_back("Class " + std::to_string(i)); } } + void Scores::aggregate(const Scores& a) + { + if (a.num_classes != num_classes) + throw std::invalid_argument("The number of classes must be the same"); + confusion_matrix += a.confusion_matrix; + total += a.total; + accuracy_value += a.accuracy_value; + accuracy_value /= 2; + } Scores::Scores(json& confusion_matrix_) { json values; @@ -46,7 +55,6 @@ namespace platform { confusion_matrix[i][j] = value_int; total += value_int; } - std::cout << std::endl; i++; } // Compute accuracy with the confusion matrix diff --git a/src/main/Scores.h b/src/main/Scores.h index e47ffad..9f2990d 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -19,6 +19,7 @@ namespace platform { torch::Tensor get_confusion_matrix() { return confusion_matrix; } std::string classification_report(); json get_confusion_matrix_json(bool labels_as_keys = false); + void aggregate(const Scores& a); private: std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support); void init_confusion_matrix(); diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index 2bde2cf..a58f75b 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -36,7 +36,7 @@ void make_test_bin(int TP, int TN, int FP, int FN, std::vector& y_test, std } } -TEST_CASE("TestScores binary", "[Scores]") +TEST_CASE("Scores binary", "[Scores]") { std::vector y_test; std::vector y_pred; @@ -59,7 +59,7 @@ TEST_CASE("TestScores binary", "[Scores]") REQUIRE(confusion_matrix[1][0].item() == 41); REQUIRE(confusion_matrix[1][1].item() == 197); } -TEST_CASE("TestScores multiclass", "[Scores]") +TEST_CASE("Scores multiclass", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; @@ -176,4 +176,43 @@ TEST_CASE("JSON constructor", "[Scores]") } REQUIRE(scores.f1_weighted() == scores3.f1_weighted()); REQUIRE(scores.f1_macro() == scores3.f1_macro()); +} +TEST_CASE("Aggregate", "[Scores]") +{ + std::vector y_test; + std::vector y_pred; + make_test_bin(197, 210, 52, 41, y_test, y_pred); + auto y_test_tensor = torch::tensor(y_test, torch::kInt32); + auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); + platform::Scores scores(y_test_tensor, y_pred_tensor, 2); + y_test.clear(); + y_pred.clear(); + make_test_bin(227, 187, 39, 47, y_test, y_pred); + auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32); + auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32); + platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2); + scores.aggregate(scores2); + REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon)); + REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329)); + REQUIRE(scores.f1_score(1) == Catch::Approx(0.8257059)); + REQUIRE(scores.precision(0) == Catch::Approx(0.8185567)); + REQUIRE(scores.precision(1) == Catch::Approx(0.8233010)); + REQUIRE(scores.recall(0) == Catch::Approx(0.8135246)); + REQUIRE(scores.recall(1) == Catch::Approx(0.8281250)); + REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856)); + REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694)); + y_test.clear(); + y_pred.clear(); + make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred); + y_test_tensor = torch::tensor(y_test, torch::kInt32); + y_pred_tensor = torch::tensor(y_pred, torch::kInt32); + platform::Scores scores3(y_test_tensor, y_pred_tensor, 2); + for (int i = 0; i < 2; ++i) { + REQUIRE(scores3.f1_score(i) == scores.f1_score(i)); + REQUIRE(scores3.precision(i) == scores.precision(i)); + REQUIRE(scores3.recall(i) == scores.recall(i)); + } + REQUIRE(scores3.f1_weighted() == scores.f1_weighted()); + REQUIRE(scores3.f1_macro() == scores.f1_macro()); + REQUIRE(scores3.accuracy() == scores.accuracy()); } \ No newline at end of file