Add BoostA2DE model and fix some report errors

This commit is contained in:
2024-05-17 01:25:27 +02:00
parent 30a6d5e60d
commit 696c0564a7
9 changed files with 25 additions and 22 deletions

View File

@@ -7,6 +7,7 @@
#include "common/DotEnv.h"
#include "common/Datasets.h"
#include "common/Paths.h"
#include "common/Colors.h"
#include "main/Scores.h"
#include "config.h"
@@ -127,7 +128,7 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
REQUIRE(res_json_str["Car"][1] == 2);
REQUIRE(res_json_str["Car"][2] == 3);
}
TEST_CASE("Classification Report", "[Scores]")
TEST_CASE("Classification Report", "[Scores]") -
{
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
@@ -135,19 +136,7 @@ TEST_CASE("Classification Report", "[Scores]")
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
std::string expected = R"(Classification Report
=====================
precision recall f1-score support
========= ========= ========= =========
Aeroplane 0.6666667 0.6666667 0.6666667 3
Boat 0.2500000 1.0000000 0.4000000 1
Car 1.0000000 0.5000000 0.6666667 6
accuracy 0.6000000 10
macro avg 0.6388889 0.7222223 0.5777778 10
weighted avg 0.8250000 0.6000000 0.6400000 10
)";
REQUIRE(scores.classification_report() == expected);
auto report = scores.classification_report(Colors::BLUE(), "train");
auto json_matrix = scores.get_confusion_matrix_json(true);
platform::Scores scores2(json_matrix);
REQUIRE(scores.classification_report() == scores2.classification_report());