Add DecisionTree with tests
This commit is contained in:
@@ -12,11 +12,11 @@ if(ENABLE_TESTING)
|
||||
${Bayesnet_INCLUDE_DIRS}
|
||||
)
|
||||
set(TEST_SOURCES_PLATFORM
|
||||
TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp
|
||||
TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp
|
||||
${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp
|
||||
${Platform_SOURCE_DIR}/src/main/Scores.cpp
|
||||
${Platform_SOURCE_DIR}/src/main/Scores.cpp ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp
|
||||
)
|
||||
add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM})
|
||||
target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" mdlp Catch2::Catch2WithMain BayesNet)
|
||||
target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet)
|
||||
add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM})
|
||||
endif(ENABLE_TESTING)
|
||||
|
303
tests/TestDecisionTree.cpp
Normal file
303
tests/TestDecisionTree.cpp
Normal file
@@ -0,0 +1,303 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/matchers/catch_matchers_string.hpp>
|
||||
#include <catch2/matchers/catch_matchers_vector.hpp>
|
||||
#include <torch/torch.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include "experimental_clfs/DecisionTree.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
using namespace bayesnet;
|
||||
using namespace Catch::Matchers;
|
||||
|
||||
TEST_CASE("DecisionTree Construction", "[DecisionTree]")
|
||||
{
|
||||
SECTION("Default constructor")
|
||||
{
|
||||
REQUIRE_NOTHROW(DecisionTree());
|
||||
}
|
||||
|
||||
SECTION("Constructor with parameters")
|
||||
{
|
||||
REQUIRE_NOTHROW(DecisionTree(5, 10, 3));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
|
||||
{
|
||||
DecisionTree dt;
|
||||
|
||||
SECTION("Set individual hyperparameters")
|
||||
{
|
||||
REQUIRE_NOTHROW(dt.setMaxDepth(10));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
|
||||
}
|
||||
|
||||
SECTION("Set hyperparameters via JSON")
|
||||
{
|
||||
nlohmann::json params;
|
||||
params["max_depth"] = 7;
|
||||
params["min_samples_split"] = 4;
|
||||
params["min_samples_leaf"] = 2;
|
||||
|
||||
REQUIRE_NOTHROW(dt.setHyperparameters(params));
|
||||
}
|
||||
|
||||
SECTION("Invalid hyperparameters should throw")
|
||||
{
|
||||
nlohmann::json params;
|
||||
|
||||
// Negative max_depth
|
||||
params["max_depth"] = -1;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
|
||||
// Zero min_samples_split
|
||||
params["max_depth"] = 5;
|
||||
params["min_samples_split"] = 0;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
|
||||
// Negative min_samples_leaf
|
||||
params["min_samples_split"] = 2;
|
||||
params["min_samples_leaf"] = -5;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]")
|
||||
{
|
||||
// Create a simple dataset
|
||||
int n_samples = 20;
|
||||
int n_features = 2;
|
||||
|
||||
std::vector<std::vector<int>> X(n_features, std::vector<int>(n_samples));
|
||||
std::vector<int> y(n_samples);
|
||||
|
||||
// Simple pattern: class depends on first feature
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
X[0][i] = i < 10 ? 0 : 1;
|
||||
X[1][i] = i % 2;
|
||||
y[i] = X[0][i]; // Class equals first feature
|
||||
}
|
||||
|
||||
std::vector<std::string> features = { "f1", "f2" };
|
||||
std::string className = "class";
|
||||
std::map<std::string, std::vector<int>> states;
|
||||
states["f1"] = { 0, 1 };
|
||||
states["f2"] = { 0, 1 };
|
||||
states["class"] = { 0, 1 };
|
||||
|
||||
SECTION("Training with vector interface")
|
||||
{
|
||||
DecisionTree dt(3, 2, 1);
|
||||
REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE));
|
||||
|
||||
auto predictions = dt.predict(X);
|
||||
REQUIRE(predictions.size() == static_cast<size_t>(n_samples));
|
||||
|
||||
// Should achieve perfect accuracy on this simple dataset
|
||||
int correct = 0;
|
||||
for (size_t i = 0; i < predictions.size(); i++) {
|
||||
if (predictions[i] == y[i]) correct++;
|
||||
}
|
||||
REQUIRE(correct == n_samples);
|
||||
}
|
||||
|
||||
SECTION("Prediction before fitting")
|
||||
{
|
||||
DecisionTree dt;
|
||||
REQUIRE_THROWS_WITH(dt.predict(X),
|
||||
ContainsSubstring("Classifier has not been fitted"));
|
||||
}
|
||||
|
||||
SECTION("Probability predictions")
|
||||
{
|
||||
DecisionTree dt(3, 2, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto proba = dt.predict_proba(X);
|
||||
REQUIRE(proba.size() == static_cast<size_t>(n_samples));
|
||||
REQUIRE(proba[0].size() == 2); // Two classes
|
||||
|
||||
// Check probabilities sum to 1 and probabilities are valid
|
||||
auto predictions = dt.predict(X);
|
||||
for (size_t i = 0; i < proba.size(); i++) {
|
||||
auto p = proba[i];
|
||||
auto pred = predictions[i];
|
||||
REQUIRE(p.size() == 2);
|
||||
REQUIRE(p[0] >= 0.0);
|
||||
REQUIRE(p[1] >= 0.0);
|
||||
double sum = p[0] + p[1];
|
||||
//Check that prodict_proba matches the expected predict value
|
||||
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
|
||||
REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Training with dataset format")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
|
||||
INFO("Dataset shape: " << raw.dataset.sizes());
|
||||
INFO("Features: " << raw.featurest.size());
|
||||
INFO("Samples: " << raw.nSamples);
|
||||
|
||||
// DecisionTree expects dataset in format: features x samples, with labels as last row
|
||||
REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE));
|
||||
|
||||
// Test prediction
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
|
||||
// Calculate accuracy
|
||||
auto correct = torch::sum(predictions == raw.yt).item<int>();
|
||||
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
|
||||
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
|
||||
}
|
||||
|
||||
SECTION("Training with vector interface")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
|
||||
REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE));
|
||||
|
||||
// std::cout << "Tree structure:\n";
|
||||
// auto graph_lines = dt.graph("Iris Decision Tree");
|
||||
// for (const auto& line : graph_lines) {
|
||||
// std::cout << line << "\n";
|
||||
// }
|
||||
auto predictions = dt.predict(raw.Xv);
|
||||
REQUIRE(predictions.size() == raw.yv.size());
|
||||
}
|
||||
|
||||
SECTION("Different tree depths")
|
||||
{
|
||||
std::vector<int> depths = { 1, 3, 5 };
|
||||
|
||||
for (int depth : depths) {
|
||||
DecisionTree dt(depth, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Very shallow tree")
|
||||
{
|
||||
DecisionTree dt(1, 2, 1); // depth = 1
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
|
||||
// With depth 1, should have at most 2 unique predictions
|
||||
auto unique_vals = at::_unique(predictions);
|
||||
REQUIRE(std::get<0>(unique_vals).size(0) <= 2);
|
||||
}
|
||||
|
||||
SECTION("High min_samples_split")
|
||||
{
|
||||
DecisionTree dt(10, 50, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]")
|
||||
{
|
||||
// Simple dataset
|
||||
std::vector<std::vector<int>> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern
|
||||
std::vector<int> y = { 0, 1, 1, 0 }; // XOR pattern
|
||||
std::vector<std::string> features = { "x1", "x2" };
|
||||
std::string className = "xor";
|
||||
std::map<std::string, std::vector<int>> states;
|
||||
states["x1"] = { 0, 1 };
|
||||
states["x2"] = { 0, 1 };
|
||||
states["xor"] = { 0, 1 };
|
||||
|
||||
SECTION("Graph generation")
|
||||
{
|
||||
DecisionTree dt(2, 1, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto graph_lines = dt.graph();
|
||||
|
||||
REQUIRE(graph_lines.size() > 2);
|
||||
REQUIRE(graph_lines.front() == "digraph DecisionTree {");
|
||||
REQUIRE(graph_lines.back() == "}");
|
||||
|
||||
// Should contain node definitions
|
||||
bool has_nodes = false;
|
||||
for (const auto& line : graph_lines) {
|
||||
if (line.find("node") != std::string::npos) {
|
||||
has_nodes = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
REQUIRE(has_nodes);
|
||||
}
|
||||
|
||||
SECTION("Graph with title")
|
||||
{
|
||||
DecisionTree dt(2, 1, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto graph_lines = dt.graph("XOR Tree");
|
||||
|
||||
bool has_title = false;
|
||||
for (const auto& line : graph_lines) {
|
||||
if (line.find("label=\"XOR Tree\"") != std::string::npos) {
|
||||
has_title = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
REQUIRE(has_title);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree with Weights", "[DecisionTree]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Uniform weights")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
|
||||
SECTION("Non-uniform weights")
|
||||
{
|
||||
auto weights = torch::ones({ raw.nSamples });
|
||||
weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class
|
||||
weights = weights / weights.sum();
|
||||
|
||||
DecisionTree dt(5, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
@@ -7,7 +7,7 @@
|
||||
#include <string>
|
||||
#include "TestUtils.h"
|
||||
#include "folding.hpp"
|
||||
#include <ArffFiles.hpp>
|
||||
#include <ArffFiles/ArffFiles.hpp>
|
||||
#include <bayesnet/classifiers/TAN.h>
|
||||
#include "config_platform.h"
|
||||
|
||||
@@ -20,17 +20,17 @@ TEST_CASE("Test Platform version", "[Platform]")
|
||||
TEST_CASE("Test Folding library version", "[Folding]")
|
||||
{
|
||||
std::string version = folding::KFold(5, 100).version();
|
||||
REQUIRE(version == "1.1.0");
|
||||
REQUIRE(version == "1.1.1");
|
||||
}
|
||||
TEST_CASE("Test BayesNet version", "[BayesNet]")
|
||||
{
|
||||
std::string version = bayesnet::TAN().getVersion();
|
||||
REQUIRE(version == "1.0.6");
|
||||
REQUIRE(version == "1.1.2");
|
||||
}
|
||||
TEST_CASE("Test mdlp version", "[mdlp]")
|
||||
{
|
||||
std::string version = mdlp::CPPFImdlp::version();
|
||||
REQUIRE(version == "2.0.0");
|
||||
REQUIRE(version == "2.0.1");
|
||||
}
|
||||
TEST_CASE("Test Arff version", "[Arff]")
|
||||
{
|
||||
|
@@ -14,38 +14,40 @@
|
||||
using json = nlohmann::ordered_json;
|
||||
auto epsilon = 1e-4;
|
||||
|
||||
void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std::vector<int>& y_pred)
|
||||
void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, torch::Tensor& y_pred)
|
||||
{
|
||||
// TP
|
||||
std::vector<std::array<double, 2>> probs;
|
||||
// TP: true positive (label 1, predicted 1)
|
||||
for (int i = 0; i < TP; i++) {
|
||||
y_test.push_back(1);
|
||||
y_pred.push_back(1);
|
||||
probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
|
||||
}
|
||||
// TN
|
||||
// TN: true negative (label 0, predicted 0)
|
||||
for (int i = 0; i < TN; i++) {
|
||||
y_test.push_back(0);
|
||||
y_pred.push_back(0);
|
||||
probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
|
||||
}
|
||||
// FP
|
||||
// FP: false positive (label 0, predicted 1)
|
||||
for (int i = 0; i < FP; i++) {
|
||||
y_test.push_back(0);
|
||||
y_pred.push_back(1);
|
||||
probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
|
||||
}
|
||||
// FN
|
||||
// FN: false negative (label 1, predicted 0)
|
||||
for (int i = 0; i < FN; i++) {
|
||||
y_test.push_back(1);
|
||||
y_pred.push_back(0);
|
||||
probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
|
||||
}
|
||||
// Convert to torch::Tensor of double, shape [N,2]
|
||||
y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 2 }, torch::kFloat64).clone();
|
||||
}
|
||||
|
||||
TEST_CASE("Scores binary", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test;
|
||||
std::vector<int> y_pred;
|
||||
torch::Tensor y_pred;
|
||||
make_test_bin(197, 210, 52, 41, y_test, y_pred);
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 2);
|
||||
REQUIRE(scores.accuracy() == Catch::Approx(0.814).epsilon(epsilon));
|
||||
REQUIRE(scores.f1_score(0) == Catch::Approx(0.818713));
|
||||
REQUIRE(scores.f1_score(1) == Catch::Approx(0.809035));
|
||||
@@ -64,10 +66,23 @@ TEST_CASE("Scores binary", "[Scores]")
|
||||
TEST_CASE("Scores multiclass", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
// Refactor y_pred to a tensor of shape [10, 3] with probabilities
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
// Convert y_test to a tensor
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3);
|
||||
REQUIRE(scores.accuracy() == Catch::Approx(0.6).epsilon(epsilon));
|
||||
REQUIRE(scores.f1_score(0) == Catch::Approx(0.666667));
|
||||
REQUIRE(scores.f1_score(1) == Catch::Approx(0.4));
|
||||
@@ -84,10 +99,21 @@ TEST_CASE("Scores multiclass", "[Scores]")
|
||||
TEST_CASE("Test Confusion Matrix Values", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3);
|
||||
auto confusion_matrix = scores.get_confusion_matrix();
|
||||
REQUIRE(confusion_matrix[0][0].item<int>() == 2);
|
||||
REQUIRE(confusion_matrix[0][1].item<int>() == 1);
|
||||
@@ -102,11 +128,22 @@ TEST_CASE("Test Confusion Matrix Values", "[Scores]")
|
||||
TEST_CASE("Confusion Matrix JSON", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3, labels);
|
||||
auto res_json_int = scores.get_confusion_matrix_json();
|
||||
REQUIRE(res_json_int[0][0] == 2);
|
||||
REQUIRE(res_json_int[0][1] == 1);
|
||||
@@ -131,11 +168,22 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
|
||||
TEST_CASE("Classification Report", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3, labels);
|
||||
auto report = scores.classification_report(Colors::BLUE(), "train");
|
||||
auto json_matrix = scores.get_confusion_matrix_json(true);
|
||||
platform::Scores scores2(json_matrix);
|
||||
@@ -144,11 +192,22 @@ TEST_CASE("Classification Report", "[Scores]")
|
||||
TEST_CASE("JSON constructor", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3, labels);
|
||||
auto res_json_int = scores.get_confusion_matrix_json();
|
||||
platform::Scores scores2(res_json_int);
|
||||
REQUIRE(scores.accuracy() == scores2.accuracy());
|
||||
@@ -173,17 +232,14 @@ TEST_CASE("JSON constructor", "[Scores]")
|
||||
TEST_CASE("Aggregate", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test;
|
||||
std::vector<int> y_pred;
|
||||
torch::Tensor y_pred;
|
||||
make_test_bin(197, 210, 52, 41, y_test, y_pred);
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 2);
|
||||
y_test.clear();
|
||||
y_pred.clear();
|
||||
make_test_bin(227, 187, 39, 47, y_test, y_pred);
|
||||
auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2);
|
||||
platform::Scores scores2(y_test_tensor2, y_pred, 2);
|
||||
scores.aggregate(scores2);
|
||||
REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon));
|
||||
REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329));
|
||||
@@ -195,11 +251,9 @@ TEST_CASE("Aggregate", "[Scores]")
|
||||
REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856));
|
||||
REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694));
|
||||
y_test.clear();
|
||||
y_pred.clear();
|
||||
make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred);
|
||||
y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores3(y_test_tensor, y_pred_tensor, 2);
|
||||
platform::Scores scores3(y_test_tensor, y_pred, 2);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
REQUIRE(scores3.f1_score(i) == scores.f1_score(i));
|
||||
REQUIRE(scores3.precision(i) == scores.precision(i));
|
||||
@@ -212,11 +266,22 @@ TEST_CASE("Aggregate", "[Scores]")
|
||||
TEST_CASE("Order of keys", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
std::vector<std::array<double, 3>> probs = {
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
|
||||
{ 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1
|
||||
};
|
||||
torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 3, labels);
|
||||
auto res_json_int = scores.get_confusion_matrix_json(true);
|
||||
// Make a temp file and store the json
|
||||
std::string filename = "temp.json";
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <ArffFiles.hpp>
|
||||
#include <ArffFiles/ArffFiles.hpp>
|
||||
#include <fimdlp/CPPFImdlp.h>
|
||||
|
||||
bool file_exists(const std::string& name);
|
||||
|
Reference in New Issue
Block a user