Fix json key automatic ordering error when creating Score from json
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "main/Scores.h"
|
||||
#include "config.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
auto epsilon = 1e-4;
|
||||
|
||||
void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std::vector<int>& y_pred)
|
||||
@@ -157,7 +158,7 @@ TEST_CASE("JSON constructor", "[Scores]")
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
|
||||
std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
auto res_json_int = scores.get_confusion_matrix_json();
|
||||
platform::Scores scores2(res_json_int);
|
||||
@@ -218,4 +219,33 @@ TEST_CASE("Aggregate", "[Scores]")
|
||||
REQUIRE(scores3.f1_weighted() == scores.f1_weighted());
|
||||
REQUIRE(scores3.f1_macro() == scores.f1_macro());
|
||||
REQUIRE(scores3.accuracy() == scores.accuracy());
|
||||
}
|
||||
TEST_CASE("Order of keys", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
auto res_json_int = scores.get_confusion_matrix_json(true);
|
||||
// Make a temp file and store the json
|
||||
std::string filename = "temp.json";
|
||||
std::ofstream file(filename);
|
||||
file << res_json_int;
|
||||
file.close();
|
||||
// Read the json from the file
|
||||
std::ifstream file2(filename);
|
||||
json res_json_int2;
|
||||
file2 >> res_json_int2;
|
||||
file2.close();
|
||||
// Remove the temp file
|
||||
std::remove(filename.c_str());
|
||||
platform::Scores scores2(res_json_int2);
|
||||
REQUIRE(scores.accuracy() == scores2.accuracy());
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
REQUIRE(scores.f1_score(i) == scores2.f1_score(i));
|
||||
REQUIRE(scores.precision(i) == scores2.precision(i));
|
||||
REQUIRE(scores.recall(i) == scores2.recall(i));
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user