From dd94fd51f7c503777bc4124371b28662d940652c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 10 May 2024 11:35:07 +0200 Subject: [PATCH] Add json constructor to Scores --- .vscode/launch.json | 1 + src/main/Scores.cpp | 46 ++++++++++++++++++++++++++++++++++++++++---- src/main/Scores.h | 3 +++ tests/TestScores.cpp | 29 ++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index a587132..46ef243 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -110,6 +110,7 @@ "request": "launch", "program": "${workspaceFolder}/build_debug/tests/unit_tests_platform", "args": [ + "[Scores]", // "-c=\"Metrics Test\"", // "-s", ], diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 9417e54..2c013d1 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -4,19 +4,57 @@ namespace platform { Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels) : num_classes(num_classes), labels(labels) { if (labels.size() == 0) { - for (int i = 0; i < num_classes; i++) { - this->labels.push_back("Class " + std::to_string(i)); - } + init_default_labels(); } total = y_test.size(0); accuracy_value = (y_pred == y_test).sum().item() / total; - confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32); + init_confusion_matrix(); for (int i = 0; i < total; i++) { int actual = y_test[i].item(); int predicted = y_pred[i].item(); confusion_matrix[actual][predicted] += 1; } } + void Scores::init_confusion_matrix() + { + confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32); + } + void Scores::init_default_labels() + { + for (int i = 0; i < num_classes; i++) { + labels.push_back("Class " + std::to_string(i)); + } + } + Scores::Scores(json& confusion_matrix_) + { + json values; + total = 0; + num_classes = confusion_matrix_.size(); + init_confusion_matrix(); + init_default_labels(); + int i = 0; + for (const auto& item : confusion_matrix_) { + if (item.is_array()) { + values = item; + } else { + auto it = item.begin(); + values = it.value(); + labels.push_back(it.key()); + } + for (int j = 0; j < num_classes; ++j) { + int value_int = values[j].get(); + confusion_matrix[i][j] = value_int; + total += value_int; + } + std::cout << std::endl; + i++; + } + // Compute accuracy with the confusion matrix + for (int i = 0; i < num_classes; i++) { + accuracy_value += confusion_matrix[i][i].item(); + } + accuracy_value /= total; + } float Scores::accuracy() { return accuracy_value; diff --git a/src/main/Scores.h b/src/main/Scores.h index 9097fd0..e47ffad 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -9,6 +9,7 @@ namespace platform { class Scores { public: Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels = {}); + explicit Scores(json& confusion_matrix_); float accuracy(); float f1_score(int num_class); float f1_weighted(); @@ -20,6 +21,8 @@ namespace platform { json get_confusion_matrix_json(bool labels_as_keys = false); private: std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support); + void init_confusion_matrix(); + void init_default_labels(); int num_classes; float accuracy_value; int total; diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index 03654f9..2bde2cf 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -147,4 +147,33 @@ TEST_CASE("Classification Report", "[Scores]") weighted avg 0.8250000 0.6000000 0.6400000 10 )"; REQUIRE(scores.classification_report() == expected); +} +TEST_CASE("JSON constructor", "[Scores]") +{ + std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; + std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + auto y_test_tensor = torch::tensor(y_test, torch::kInt32); + auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); + std::vector labels = { "Aeroplane", "Boat", "Car" }; + platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + auto res_json_int = scores.get_confusion_matrix_json(); + platform::Scores scores2(res_json_int); + REQUIRE(scores.accuracy() == scores2.accuracy()); + for (int i = 0; i < 2; ++i) { + REQUIRE(scores.f1_score(i) == scores2.f1_score(i)); + REQUIRE(scores.precision(i) == scores2.precision(i)); + REQUIRE(scores.recall(i) == scores2.recall(i)); + } + REQUIRE(scores.f1_weighted() == scores2.f1_weighted()); + REQUIRE(scores.f1_macro() == scores2.f1_macro()); + auto res_json_key = scores.get_confusion_matrix_json(true); + platform::Scores scores3(res_json_key); + REQUIRE(scores.accuracy() == scores3.accuracy()); + for (int i = 0; i < 2; ++i) { + REQUIRE(scores.f1_score(i) == scores3.f1_score(i)); + REQUIRE(scores.precision(i) == scores3.precision(i)); + REQUIRE(scores.recall(i) == scores3.recall(i)); + } + REQUIRE(scores.f1_weighted() == scores3.f1_weighted()); + REQUIRE(scores.f1_macro() == scores3.f1_macro()); } \ No newline at end of file