Add confusion matrix to json results

Add Aggregate method to Scores
This commit is contained in:
2024-05-10 13:42:38 +02:00
parent dd94fd51f7
commit ec0268c514
5 changed files with 59 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std
}
}
TEST_CASE("TestScores binary", "[Scores]")
TEST_CASE("Scores binary", "[Scores]")
{
std::vector<int> y_test;
std::vector<int> y_pred;
@@ -59,7 +59,7 @@ TEST_CASE("TestScores binary", "[Scores]")
REQUIRE(confusion_matrix[1][0].item<int>() == 41);
REQUIRE(confusion_matrix[1][1].item<int>() == 197);
}
TEST_CASE("TestScores multiclass", "[Scores]")
TEST_CASE("Scores multiclass", "[Scores]")
{
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
@@ -176,4 +176,43 @@ TEST_CASE("JSON constructor", "[Scores]")
}
REQUIRE(scores.f1_weighted() == scores3.f1_weighted());
REQUIRE(scores.f1_macro() == scores3.f1_macro());
}
TEST_CASE("Aggregate", "[Scores]")
{
std::vector<int> y_test;
std::vector<int> y_pred;
make_test_bin(197, 210, 52, 41, y_test, y_pred);
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
y_test.clear();
y_pred.clear();
make_test_bin(227, 187, 39, 47, y_test, y_pred);
auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32);
auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32);
platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2);
scores.aggregate(scores2);
REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon));
REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329));
REQUIRE(scores.f1_score(1) == Catch::Approx(0.8257059));
REQUIRE(scores.precision(0) == Catch::Approx(0.8185567));
REQUIRE(scores.precision(1) == Catch::Approx(0.8233010));
REQUIRE(scores.recall(0) == Catch::Approx(0.8135246));
REQUIRE(scores.recall(1) == Catch::Approx(0.8281250));
REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856));
REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694));
y_test.clear();
y_pred.clear();
make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred);
y_test_tensor = torch::tensor(y_test, torch::kInt32);
y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
platform::Scores scores3(y_test_tensor, y_pred_tensor, 2);
for (int i = 0; i < 2; ++i) {
REQUIRE(scores3.f1_score(i) == scores.f1_score(i));
REQUIRE(scores3.precision(i) == scores.precision(i));
REQUIRE(scores3.recall(i) == scores.recall(i));
}
REQUIRE(scores3.f1_weighted() == scores.f1_weighted());
REQUIRE(scores3.f1_macro() == scores.f1_macro());
REQUIRE(scores3.accuracy() == scores.accuracy());
}