// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>
#include <torch/torch.h>
#include <vector>
#include <map>
#include <string>
#include <nlohmann/json.hpp>
#include "experimental_clfs/DecisionTree.h"
#include "TestUtils.h"

using namespace bayesnet;
using namespace Catch::Matchers;

TEST_CASE("DecisionTree Construction", "[DecisionTree]")
{
    SECTION("Default constructor")
    {
        REQUIRE_NOTHROW(DecisionTree());
    }
    SECTION("Constructor with parameters")
    {
        REQUIRE_NOTHROW(DecisionTree(5, 10, 3));
    }
}

TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
{
    DecisionTree dt;

    SECTION("Set individual hyperparameters")
    {
        REQUIRE_NOTHROW(dt.setMaxDepth(10));
        REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
        REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
        REQUIRE(dt.getMaxDepth() == 10);
        REQUIRE(dt.getMinSamplesSplit() == 5);
        REQUIRE(dt.getMinSamplesLeaf() == 2);
    }
    SECTION("Set hyperparameters via JSON")
    {
        nlohmann::json params;
        params["max_depth"] = 7;
        params["min_samples_split"] = 4;
        params["min_samples_leaf"] = 2;
        REQUIRE_NOTHROW(dt.setHyperparameters(params));
        REQUIRE(dt.getMaxDepth() == 7);
        REQUIRE(dt.getMinSamplesSplit() == 4);
        REQUIRE(dt.getMinSamplesLeaf() == 2);
    }
    SECTION("Invalid hyperparameters should throw")
    {
        nlohmann::json params;

        // Negative max_depth
        params["max_depth"] = -1;
        REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);

        // Zero min_samples_split
        params["max_depth"] = 5;
        params["min_samples_split"] = 0;
        REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);

        // Negative min_samples_leaf
        params["min_samples_split"] = 2;
        params["min_samples_leaf"] = -5;
        REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
    }
}
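// Sketch of an additional check: it assumes the three positional constructor
// arguments map to max_depth, min_samples_split and min_samples_leaf, in that
// order. That ordering is inferred from how DecisionTree(1, 2, 1) is used with
// a "depth = 1" comment below; it is an assumption, not documented behaviour.
TEST_CASE("DecisionTree Constructor Parameters Match Getters", "[DecisionTree]")
{
    DecisionTree dt(5, 10, 3);
    REQUIRE(dt.getMaxDepth() == 5);
    REQUIRE(dt.getMinSamplesSplit() == 10);
    REQUIRE(dt.getMinSamplesLeaf() == 3);
}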
TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]")
{
    // Create a simple dataset
    int n_samples = 20;
    int n_features = 2;
    std::vector<std::vector<int>> X(n_features, std::vector<int>(n_samples));
    std::vector<int> y(n_samples);

    // Simple pattern: class depends on first feature
    for (int i = 0; i < n_samples; i++) {
        X[0][i] = i < 10 ? 0 : 1;
        X[1][i] = i % 2;
        y[i] = X[0][i]; // Class equals first feature
    }

    std::vector<std::string> features = { "f1", "f2" };
    std::string className = "class";
    std::map<std::string, std::vector<int>> states;
    states["f1"] = { 0, 1 };
    states["f2"] = { 0, 1 };
    states["class"] = { 0, 1 };

    SECTION("Training with vector interface")
    {
        DecisionTree dt(3, 2, 1);
        REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE));

        auto predictions = dt.predict(X);
        REQUIRE(predictions.size() == static_cast<size_t>(n_samples));

        // Should achieve perfect accuracy on this simple dataset
        int correct = 0;
        for (size_t i = 0; i < predictions.size(); i++) {
            if (predictions[i] == y[i]) correct++;
        }
        REQUIRE(correct == n_samples);
    }
    SECTION("Prediction before fitting")
    {
        DecisionTree dt;
        REQUIRE_THROWS_WITH(dt.predict(X), ContainsSubstring("Classifier has not been fitted"));
    }
    SECTION("Probability predictions")
    {
        DecisionTree dt(3, 2, 1);
        dt.fit(X, y, features, className, states, Smoothing_t::NONE);

        auto proba = dt.predict_proba(X);
        REQUIRE(proba.size() == static_cast<size_t>(n_samples));
        REQUIRE(proba[0].size() == 2); // Two classes

        // Check that probabilities are valid and sum to 1
        auto predictions = dt.predict(X);
        for (size_t i = 0; i < proba.size(); i++) {
            auto p = proba[i];
            auto pred = predictions[i];
            REQUIRE(p.size() == 2);
            REQUIRE(p[0] >= 0.0);
            REQUIRE(p[1] >= 0.0);
            double sum = p[0] + p[1];
            // Check that predict_proba matches the class returned by predict
            REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
            REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
        }
    }
}

TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
{
    auto raw = RawDatasets("iris", true);

    SECTION("Training with dataset format")
    {
        DecisionTree dt(5, 2, 1);

        INFO("Dataset shape: " << raw.dataset.sizes());
        INFO("Features: " << raw.featurest.size());
        INFO("Samples: " << raw.nSamples);

        // DecisionTree expects the dataset in features x samples format, with labels as the last row
        REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE));

        // Test prediction
        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));

        // Calculate accuracy
        auto correct = torch::sum(predictions == raw.yt).item<int>();
        double accuracy = static_cast<double>(correct) / raw.yt.size(0);
        double accuracy_computed = dt.score(raw.Xt, raw.yt);
        REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
        REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
    }
    SECTION("Training with vector interface")
    {
        DecisionTree dt(5, 2, 1);
        REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE));
        // std::cout << "Tree structure:\n";
        // auto graph_lines = dt.graph("Iris Decision Tree");
        // for (const auto& line : graph_lines) {
        //     std::cout << line << "\n";
        // }
        auto predictions = dt.predict(raw.Xv);
        REQUIRE(predictions.size() == raw.yv.size());
    }
    SECTION("Different tree depths")
    {
        std::vector<int> depths = { 1, 3, 5 };
        for (int depth : depths) {
            DecisionTree dt(depth, 2, 1);
            dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
            auto predictions = dt.predict(raw.Xt);
            REQUIRE(predictions.size(0) == raw.yt.size(0));
        }
    }
}
TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]")
{
    auto raw = RawDatasets("iris", true);

    SECTION("Very shallow tree")
    {
        DecisionTree dt(1, 2, 1); // depth = 1
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));

        // With depth 1, should have at most 2 unique predictions
        auto unique_vals = at::_unique(predictions);
        REQUIRE(std::get<0>(unique_vals).size(0) <= 2);
    }
    SECTION("High min_samples_split")
    {
        DecisionTree dt(10, 50, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }
}

TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]")
{
    // Simple dataset in features x samples format
    std::vector<std::vector<int>> X = { {0, 0, 0, 1}, {0, 1, 1, 1} }; // XOR pattern
    std::vector<int> y = { 0, 1, 1, 0 }; // XOR pattern
    std::vector<std::string> features = { "x1", "x2" };
    std::string className = "xor";
    std::map<std::string, std::vector<int>> states;
    states["x1"] = { 0, 1 };
    states["x2"] = { 0, 1 };
    states["xor"] = { 0, 1 };

    SECTION("Graph generation")
    {
        DecisionTree dt(2, 1, 1);
        dt.fit(X, y, features, className, states, Smoothing_t::NONE);

        auto graph_lines = dt.graph();
        REQUIRE(graph_lines.size() > 2);
        REQUIRE(graph_lines.front() == "digraph DecisionTree {");
        REQUIRE(graph_lines.back() == "}");

        // Should contain node definitions
        bool has_nodes = false;
        for (const auto& line : graph_lines) {
            if (line.find("node") != std::string::npos) {
                has_nodes = true;
                break;
            }
        }
        REQUIRE(has_nodes);
    }
    SECTION("Graph with title")
    {
        DecisionTree dt(2, 1, 1);
        dt.fit(X, y, features, className, states, Smoothing_t::NONE);

        auto graph_lines = dt.graph("XOR Tree");
        bool has_title = false;
        for (const auto& line : graph_lines) {
            if (line.find("label=\"XOR Tree\"") != std::string::npos) {
                has_title = true;
                break;
            }
        }
        REQUIRE(has_title);
    }
}

TEST_CASE("DecisionTree with Weights", "[DecisionTree]")
{
    auto raw = RawDatasets("iris", true);

    SECTION("Uniform weights")
    {
        DecisionTree dt(5, 2, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE);
        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }
    SECTION("Non-uniform weights")
    {
        auto weights = torch::ones({ raw.nSamples });
        weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class
        weights = weights / weights.sum();

        DecisionTree dt(5, 2, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE);
        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }
}
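// Sketch of one more consistency check, using only interfaces exercised above
// (vector fit/predict/predict_proba on the Iris data). It assumes predict()
// returns the argmax of predict_proba(), extrapolating the two-class check in
// "Probability predictions" to three classes; tie-breaking is not specified,
// so this is an illustrative sketch rather than a documented contract.
TEST_CASE("DecisionTree predict/predict_proba Consistency", "[DecisionTree][iris]")
{
    auto raw = RawDatasets("iris", true);
    DecisionTree dt(5, 2, 1);
    dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE);

    auto predictions = dt.predict(raw.Xv);
    auto proba = dt.predict_proba(raw.Xv);
    REQUIRE(predictions.size() == proba.size());

    for (size_t i = 0; i < proba.size(); i++) {
        // Manual argmax over the class probabilities for sample i
        size_t argmax = 0;
        for (size_t c = 1; c < proba[i].size(); c++) {
            if (proba[i][c] > proba[i][argmax]) argmax = c;
        }
        REQUIRE(predictions[i] == static_cast<int>(argmax));
    }
}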