Add confusion matrix to json results
Add Aggregate method to Scores
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "reports/ReportConsole.h"
|
||||
#include "common/Paths.h"
|
||||
#include "Models.h"
|
||||
#include "Scores.h"
|
||||
#include "Experiment.h"
|
||||
namespace platform {
|
||||
using json = nlohmann::json;
|
||||
@@ -96,6 +97,7 @@ namespace platform {
|
||||
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||
json confusion_matrices = json::array();
|
||||
std::vector<std::string> notes;
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
@@ -150,10 +152,13 @@ namespace platform {
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
test_timer.start();
|
||||
auto accuracy_test_value = clf->score(X_test, y_test);
|
||||
auto y_predict = clf->predict(X_test);
|
||||
Scores scores(y_test, y_predict, states[className].size());
|
||||
auto accuracy_test_value = scores.accuracy();
|
||||
test_time[item] = test_timer.getDuration();
|
||||
accuracy_train[item] = accuracy_train_value;
|
||||
accuracy_test[item] = accuracy_test_value;
|
||||
confusion_matrices.push_back(scores.get_confusion_matrix_json());
|
||||
if (!quiet)
|
||||
std::cout << "\b\b\b, " << flush;
|
||||
// Store results and times in std::vector
|
||||
@@ -173,6 +178,7 @@ namespace platform {
|
||||
partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
||||
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||
partial_result.setDataset(fileName).setNotes(notes);
|
||||
partial_result.setConfusionMatrices(confusion_matrices);
|
||||
addResult(partial_result);
|
||||
}
|
||||
}
|
@@ -27,6 +27,7 @@ namespace platform {
|
||||
data["notes"].insert(data["notes"].end(), notes_.begin(), notes_.end());
|
||||
return *this;
|
||||
}
|
||||
PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; }
|
||||
PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; }
|
||||
PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; }
|
||||
PartialResult& setFeatures(int features) { data["features"] = features; return *this; }
|
||||
|
@@ -25,6 +25,15 @@ namespace platform {
|
||||
labels.push_back("Class " + std::to_string(i));
|
||||
}
|
||||
}
|
||||
void Scores::aggregate(const Scores& a)
|
||||
{
|
||||
if (a.num_classes != num_classes)
|
||||
throw std::invalid_argument("The number of classes must be the same");
|
||||
confusion_matrix += a.confusion_matrix;
|
||||
total += a.total;
|
||||
accuracy_value += a.accuracy_value;
|
||||
accuracy_value /= 2;
|
||||
}
|
||||
Scores::Scores(json& confusion_matrix_)
|
||||
{
|
||||
json values;
|
||||
@@ -46,7 +55,6 @@ namespace platform {
|
||||
confusion_matrix[i][j] = value_int;
|
||||
total += value_int;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
i++;
|
||||
}
|
||||
// Compute accuracy with the confusion matrix
|
||||
|
@@ -19,6 +19,7 @@ namespace platform {
|
||||
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
||||
std::string classification_report();
|
||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||
void aggregate(const Scores& a);
|
||||
private:
|
||||
std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support);
|
||||
void init_confusion_matrix();
|
||||
|
@@ -36,7 +36,7 @@ void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("TestScores binary", "[Scores]")
|
||||
TEST_CASE("Scores binary", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test;
|
||||
std::vector<int> y_pred;
|
||||
@@ -59,7 +59,7 @@ TEST_CASE("TestScores binary", "[Scores]")
|
||||
REQUIRE(confusion_matrix[1][0].item<int>() == 41);
|
||||
REQUIRE(confusion_matrix[1][1].item<int>() == 197);
|
||||
}
|
||||
TEST_CASE("TestScores multiclass", "[Scores]")
|
||||
TEST_CASE("Scores multiclass", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
@@ -177,3 +177,42 @@ TEST_CASE("JSON constructor", "[Scores]")
|
||||
REQUIRE(scores.f1_weighted() == scores3.f1_weighted());
|
||||
REQUIRE(scores.f1_macro() == scores3.f1_macro());
|
||||
}
|
||||
TEST_CASE("Aggregate", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test;
|
||||
std::vector<int> y_pred;
|
||||
make_test_bin(197, 210, 52, 41, y_test, y_pred);
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
|
||||
y_test.clear();
|
||||
y_pred.clear();
|
||||
make_test_bin(227, 187, 39, 47, y_test, y_pred);
|
||||
auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2);
|
||||
scores.aggregate(scores2);
|
||||
REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon));
|
||||
REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329));
|
||||
REQUIRE(scores.f1_score(1) == Catch::Approx(0.8257059));
|
||||
REQUIRE(scores.precision(0) == Catch::Approx(0.8185567));
|
||||
REQUIRE(scores.precision(1) == Catch::Approx(0.8233010));
|
||||
REQUIRE(scores.recall(0) == Catch::Approx(0.8135246));
|
||||
REQUIRE(scores.recall(1) == Catch::Approx(0.8281250));
|
||||
REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856));
|
||||
REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694));
|
||||
y_test.clear();
|
||||
y_pred.clear();
|
||||
make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred);
|
||||
y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores3(y_test_tensor, y_pred_tensor, 2);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
REQUIRE(scores3.f1_score(i) == scores.f1_score(i));
|
||||
REQUIRE(scores3.precision(i) == scores.precision(i));
|
||||
REQUIRE(scores3.recall(i) == scores.recall(i));
|
||||
}
|
||||
REQUIRE(scores3.f1_weighted() == scores.f1_weighted());
|
||||
REQUIRE(scores3.f1_macro() == scores.f1_macro());
|
||||
REQUIRE(scores3.accuracy() == scores.accuracy());
|
||||
}
|
Reference in New Issue
Block a user