Begin AdaBoost integration

This commit is contained in:
2025-06-18 11:27:11 +02:00
parent 023d5613b4
commit 415a7ae608
10 changed files with 1001 additions and 56 deletions

View File

@@ -39,6 +39,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
REQUIRE_NOTHROW(dt.setMaxDepth(10));
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
REQUIRE(dt.getMaxDepth() == 10);
REQUIRE(dt.getMinSamplesSplit() == 5);
REQUIRE(dt.getMinSamplesLeaf() == 2);
}
SECTION("Set hyperparameters via JSON")
@@ -49,6 +52,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
params["min_samples_leaf"] = 2;
REQUIRE_NOTHROW(dt.setHyperparameters(params));
REQUIRE(dt.getMaxDepth() == 7);
REQUIRE(dt.getMinSamplesSplit() == 4);
REQUIRE(dt.getMinSamplesLeaf() == 2);
}
SECTION("Invalid hyperparameters should throw")
@@ -164,7 +170,9 @@ TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
// Calculate accuracy
auto correct = torch::sum(predictions == raw.yt).item<int>();
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
double acurracy_computed = dt.score(raw.Xt, raw.yt);
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
REQUIRE(acurracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
}
SECTION("Training with vector interface")