Add json constructor to Scores
This commit is contained in:
1
.vscode/launch.json
vendored
1
.vscode/launch.json
vendored
@@ -110,6 +110,7 @@
|
|||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${workspaceFolder}/build_debug/tests/unit_tests_platform",
|
"program": "${workspaceFolder}/build_debug/tests/unit_tests_platform",
|
||||||
"args": [
|
"args": [
|
||||||
|
"[Scores]",
|
||||||
// "-c=\"Metrics Test\"",
|
// "-c=\"Metrics Test\"",
|
||||||
// "-s",
|
// "-s",
|
||||||
],
|
],
|
||||||
|
@@ -4,19 +4,57 @@ namespace platform {
|
|||||||
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
|
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
|
||||||
{
|
{
|
||||||
if (labels.size() == 0) {
|
if (labels.size() == 0) {
|
||||||
for (int i = 0; i < num_classes; i++) {
|
init_default_labels();
|
||||||
this->labels.push_back("Class " + std::to_string(i));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
total = y_test.size(0);
|
total = y_test.size(0);
|
||||||
accuracy_value = (y_pred == y_test).sum().item<float>() / total;
|
accuracy_value = (y_pred == y_test).sum().item<float>() / total;
|
||||||
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
init_confusion_matrix();
|
||||||
for (int i = 0; i < total; i++) {
|
for (int i = 0; i < total; i++) {
|
||||||
int actual = y_test[i].item<int>();
|
int actual = y_test[i].item<int>();
|
||||||
int predicted = y_pred[i].item<int>();
|
int predicted = y_pred[i].item<int>();
|
||||||
confusion_matrix[actual][predicted] += 1;
|
confusion_matrix[actual][predicted] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void Scores::init_confusion_matrix()
|
||||||
|
{
|
||||||
|
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
||||||
|
}
|
||||||
|
void Scores::init_default_labels()
|
||||||
|
{
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
labels.push_back("Class " + std::to_string(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Scores::Scores(json& confusion_matrix_)
|
||||||
|
{
|
||||||
|
json values;
|
||||||
|
total = 0;
|
||||||
|
num_classes = confusion_matrix_.size();
|
||||||
|
init_confusion_matrix();
|
||||||
|
init_default_labels();
|
||||||
|
int i = 0;
|
||||||
|
for (const auto& item : confusion_matrix_) {
|
||||||
|
if (item.is_array()) {
|
||||||
|
values = item;
|
||||||
|
} else {
|
||||||
|
auto it = item.begin();
|
||||||
|
values = it.value();
|
||||||
|
labels.push_back(it.key());
|
||||||
|
}
|
||||||
|
for (int j = 0; j < num_classes; ++j) {
|
||||||
|
int value_int = values[j].get<int>();
|
||||||
|
confusion_matrix[i][j] = value_int;
|
||||||
|
total += value_int;
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
// Compute accuracy with the confusion matrix
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
accuracy_value += confusion_matrix[i][i].item<int>();
|
||||||
|
}
|
||||||
|
accuracy_value /= total;
|
||||||
|
}
|
||||||
float Scores::accuracy()
|
float Scores::accuracy()
|
||||||
{
|
{
|
||||||
return accuracy_value;
|
return accuracy_value;
|
||||||
|
@@ -9,6 +9,7 @@ namespace platform {
|
|||||||
class Scores {
|
class Scores {
|
||||||
public:
|
public:
|
||||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||||
|
explicit Scores(json& confusion_matrix_);
|
||||||
float accuracy();
|
float accuracy();
|
||||||
float f1_score(int num_class);
|
float f1_score(int num_class);
|
||||||
float f1_weighted();
|
float f1_weighted();
|
||||||
@@ -20,6 +21,8 @@ namespace platform {
|
|||||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||||
private:
|
private:
|
||||||
std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support);
|
std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support);
|
||||||
|
void init_confusion_matrix();
|
||||||
|
void init_default_labels();
|
||||||
int num_classes;
|
int num_classes;
|
||||||
float accuracy_value;
|
float accuracy_value;
|
||||||
int total;
|
int total;
|
||||||
|
@@ -147,4 +147,33 @@ TEST_CASE("Classification Report", "[Scores]")
|
|||||||
weighted avg 0.8250000 0.6000000 0.6400000 10
|
weighted avg 0.8250000 0.6000000 0.6400000 10
|
||||||
)";
|
)";
|
||||||
REQUIRE(scores.classification_report() == expected);
|
REQUIRE(scores.classification_report() == expected);
|
||||||
|
}
|
||||||
|
TEST_CASE("JSON constructor", "[Scores]")
|
||||||
|
{
|
||||||
|
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||||
|
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||||
|
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||||
|
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||||
|
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
|
||||||
|
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||||
|
auto res_json_int = scores.get_confusion_matrix_json();
|
||||||
|
platform::Scores scores2(res_json_int);
|
||||||
|
REQUIRE(scores.accuracy() == scores2.accuracy());
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
REQUIRE(scores.f1_score(i) == scores2.f1_score(i));
|
||||||
|
REQUIRE(scores.precision(i) == scores2.precision(i));
|
||||||
|
REQUIRE(scores.recall(i) == scores2.recall(i));
|
||||||
|
}
|
||||||
|
REQUIRE(scores.f1_weighted() == scores2.f1_weighted());
|
||||||
|
REQUIRE(scores.f1_macro() == scores2.f1_macro());
|
||||||
|
auto res_json_key = scores.get_confusion_matrix_json(true);
|
||||||
|
platform::Scores scores3(res_json_key);
|
||||||
|
REQUIRE(scores.accuracy() == scores3.accuracy());
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
REQUIRE(scores.f1_score(i) == scores3.f1_score(i));
|
||||||
|
REQUIRE(scores.precision(i) == scores3.precision(i));
|
||||||
|
REQUIRE(scores.recall(i) == scores3.recall(i));
|
||||||
|
}
|
||||||
|
REQUIRE(scores.f1_weighted() == scores3.f1_weighted());
|
||||||
|
REQUIRE(scores.f1_macro() == scores3.f1_macro());
|
||||||
}
|
}
|
Reference in New Issue
Block a user