Files
Platform/tests/TestDecisionTree.cpp

311 lines
9.7 KiB
C++

// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>
#include <catch2/matchers/catch_matchers_vector.hpp>
#include <torch/torch.h>
#include <memory>
#include <stdexcept>
#include "experimental_clfs/DecisionTree.h"
#include "TestUtils.h"
using namespace bayesnet;
using namespace Catch::Matchers;
TEST_CASE("DecisionTree Construction", "[DecisionTree]")
{
SECTION("Default constructor")
{
REQUIRE_NOTHROW(DecisionTree());
}
SECTION("Constructor with parameters")
{
REQUIRE_NOTHROW(DecisionTree(5, 10, 3));
}
}
TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
{
DecisionTree dt;
SECTION("Set individual hyperparameters")
{
REQUIRE_NOTHROW(dt.setMaxDepth(10));
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
REQUIRE(dt.getMaxDepth() == 10);
REQUIRE(dt.getMinSamplesSplit() == 5);
REQUIRE(dt.getMinSamplesLeaf() == 2);
}
SECTION("Set hyperparameters via JSON")
{
nlohmann::json params;
params["max_depth"] = 7;
params["min_samples_split"] = 4;
params["min_samples_leaf"] = 2;
REQUIRE_NOTHROW(dt.setHyperparameters(params));
REQUIRE(dt.getMaxDepth() == 7);
REQUIRE(dt.getMinSamplesSplit() == 4);
REQUIRE(dt.getMinSamplesLeaf() == 2);
}
SECTION("Invalid hyperparameters should throw")
{
nlohmann::json params;
// Negative max_depth
params["max_depth"] = -1;
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
// Zero min_samples_split
params["max_depth"] = 5;
params["min_samples_split"] = 0;
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
// Negative min_samples_leaf
params["min_samples_split"] = 2;
params["min_samples_leaf"] = -5;
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
}
}
TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]")
{
// Create a simple dataset
int n_samples = 20;
int n_features = 2;
std::vector<std::vector<int>> X(n_features, std::vector<int>(n_samples));
std::vector<int> y(n_samples);
// Simple pattern: class depends on first feature
for (int i = 0; i < n_samples; i++) {
X[0][i] = i < 10 ? 0 : 1;
X[1][i] = i % 2;
y[i] = X[0][i]; // Class equals first feature
}
std::vector<std::string> features = { "f1", "f2" };
std::string className = "class";
std::map<std::string, std::vector<int>> states;
states["f1"] = { 0, 1 };
states["f2"] = { 0, 1 };
states["class"] = { 0, 1 };
SECTION("Training with vector interface")
{
DecisionTree dt(3, 2, 1);
REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE));
auto predictions = dt.predict(X);
REQUIRE(predictions.size() == static_cast<size_t>(n_samples));
// Should achieve perfect accuracy on this simple dataset
int correct = 0;
for (size_t i = 0; i < predictions.size(); i++) {
if (predictions[i] == y[i]) correct++;
}
REQUIRE(correct == n_samples);
}
SECTION("Prediction before fitting")
{
DecisionTree dt;
REQUIRE_THROWS_WITH(dt.predict(X),
ContainsSubstring("Classifier has not been fitted"));
}
SECTION("Probability predictions")
{
DecisionTree dt(3, 2, 1);
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
auto proba = dt.predict_proba(X);
REQUIRE(proba.size() == static_cast<size_t>(n_samples));
REQUIRE(proba[0].size() == 2); // Two classes
// Check probabilities sum to 1 and probabilities are valid
auto predictions = dt.predict(X);
for (size_t i = 0; i < proba.size(); i++) {
auto p = proba[i];
auto pred = predictions[i];
REQUIRE(p.size() == 2);
REQUIRE(p[0] >= 0.0);
REQUIRE(p[1] >= 0.0);
double sum = p[0] + p[1];
//Check that prodict_proba matches the expected predict value
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
}
}
}
TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
{
auto raw = RawDatasets("iris", true);
SECTION("Training with dataset format")
{
DecisionTree dt(5, 2, 1);
INFO("Dataset shape: " << raw.dataset.sizes());
INFO("Features: " << raw.featurest.size());
INFO("Samples: " << raw.nSamples);
// DecisionTree expects dataset in format: features x samples, with labels as last row
REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE));
// Test prediction
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
// Calculate accuracy
auto correct = torch::sum(predictions == raw.yt).item<int>();
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
double acurracy_computed = dt.score(raw.Xt, raw.yt);
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
REQUIRE(acurracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
}
SECTION("Training with vector interface")
{
DecisionTree dt(5, 2, 1);
REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE));
// std::cout << "Tree structure:\n";
// auto graph_lines = dt.graph("Iris Decision Tree");
// for (const auto& line : graph_lines) {
// std::cout << line << "\n";
// }
auto predictions = dt.predict(raw.Xv);
REQUIRE(predictions.size() == raw.yv.size());
}
SECTION("Different tree depths")
{
std::vector<int> depths = { 1, 3, 5 };
for (int depth : depths) {
DecisionTree dt(depth, 2, 1);
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
}
}
}
TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]")
{
auto raw = RawDatasets("iris", true);
SECTION("Very shallow tree")
{
DecisionTree dt(1, 2, 1); // depth = 1
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
// With depth 1, should have at most 2 unique predictions
auto unique_vals = at::_unique(predictions);
REQUIRE(std::get<0>(unique_vals).size(0) <= 2);
}
SECTION("High min_samples_split")
{
DecisionTree dt(10, 50, 1);
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
}
}
TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]")
{
// Simple dataset
std::vector<std::vector<int>> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern
std::vector<int> y = { 0, 1, 1, 0 }; // XOR pattern
std::vector<std::string> features = { "x1", "x2" };
std::string className = "xor";
std::map<std::string, std::vector<int>> states;
states["x1"] = { 0, 1 };
states["x2"] = { 0, 1 };
states["xor"] = { 0, 1 };
SECTION("Graph generation")
{
DecisionTree dt(2, 1, 1);
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
auto graph_lines = dt.graph();
REQUIRE(graph_lines.size() > 2);
REQUIRE(graph_lines.front() == "digraph DecisionTree {");
REQUIRE(graph_lines.back() == "}");
// Should contain node definitions
bool has_nodes = false;
for (const auto& line : graph_lines) {
if (line.find("node") != std::string::npos) {
has_nodes = true;
break;
}
}
REQUIRE(has_nodes);
}
SECTION("Graph with title")
{
DecisionTree dt(2, 1, 1);
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
auto graph_lines = dt.graph("XOR Tree");
bool has_title = false;
for (const auto& line : graph_lines) {
if (line.find("label=\"XOR Tree\"") != std::string::npos) {
has_title = true;
break;
}
}
REQUIRE(has_title);
}
}
TEST_CASE("DecisionTree with Weights", "[DecisionTree]")
{
auto raw = RawDatasets("iris", true);
SECTION("Uniform weights")
{
DecisionTree dt(5, 2, 1);
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE);
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
}
SECTION("Non-uniform weights")
{
auto weights = torch::ones({ raw.nSamples });
weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class
weights = weights / weights.sum();
DecisionTree dt(5, 2, 1);
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE);
auto predictions = dt.predict(raw.Xt);
REQUIRE(predictions.size(0) == raw.yt.size(0));
}
}