Begin AdaBoost integration
This commit is contained in:
@@ -39,6 +39,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
|
||||
REQUIRE_NOTHROW(dt.setMaxDepth(10));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
|
||||
REQUIRE(dt.getMaxDepth() == 10);
|
||||
REQUIRE(dt.getMinSamplesSplit() == 5);
|
||||
REQUIRE(dt.getMinSamplesLeaf() == 2);
|
||||
}
|
||||
|
||||
SECTION("Set hyperparameters via JSON")
|
||||
@@ -49,6 +52,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
|
||||
params["min_samples_leaf"] = 2;
|
||||
|
||||
REQUIRE_NOTHROW(dt.setHyperparameters(params));
|
||||
REQUIRE(dt.getMaxDepth() == 7);
|
||||
REQUIRE(dt.getMinSamplesSplit() == 4);
|
||||
REQUIRE(dt.getMinSamplesLeaf() == 2);
|
||||
}
|
||||
|
||||
SECTION("Invalid hyperparameters should throw")
|
||||
@@ -164,7 +170,9 @@ TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
|
||||
// Calculate accuracy
|
||||
auto correct = torch::sum(predictions == raw.yt).item<int>();
|
||||
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
|
||||
double acurracy_computed = dt.score(raw.Xt, raw.yt);
|
||||
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
|
||||
REQUIRE(acurracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
|
||||
}
|
||||
|
||||
SECTION("Training with vector interface")
|
||||
|
Reference in New Issue
Block a user