Make some boostAODE tests
This commit is contained in:
@@ -102,5 +102,50 @@ TEST_CASE("Order asc, desc & random", "[BoostAODE]")
|
||||
}
|
||||
TEST_CASE("Oddities", "[BoostAODE]")
|
||||
{
|
||||
auto clf = bayesnet::BoostAODE();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto bad_hyper = nlohmann::json{
|
||||
{ { "order", "duck" } },
|
||||
{ { "select_features", "duck" } },
|
||||
{ { "maxTolerance", 0 } },
|
||||
{ { "maxTolerance", 5 } },
|
||||
};
|
||||
for (const auto& hyper : bad_hyper.items()) {
|
||||
INFO("BoostAODE hyper: " + hyper.value().dump());
|
||||
REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
|
||||
}
|
||||
REQUIRE_THROWS_AS(clf.setHyperparameters({ {"maxTolerance", 0 } }), std::invalid_argument);
|
||||
auto bad_hyper_fit = nlohmann::json{
|
||||
{ { "select_features","IWSS" }, { "threshold", -0.01 } },
|
||||
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
||||
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
||||
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
||||
};
|
||||
for (const auto& hyper : bad_hyper_fit.items()) {
|
||||
INFO("BoostAODE hyper: " + hyper.value().dump());
|
||||
clf.setHyperparameters(hyper.value());
|
||||
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("Bisection", "[BoostAODE]")
|
||||
{
|
||||
auto clf = bayesnet::BoostAODE();
|
||||
auto raw = RawDatasets("mfeat-factors", true);
|
||||
clf.setHyperparameters({
|
||||
{"bisection", true},
|
||||
{"maxTolerance", 3},
|
||||
{"convergence", true},
|
||||
});
|
||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
REQUIRE(clf.getNumberOfNodes() == 217);
|
||||
REQUIRE(clf.getNumberOfEdges() == 431);
|
||||
REQUIRE(clf.getNotes().size() == 3);
|
||||
REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated");
|
||||
REQUIRE(clf.getNotes()[1] == "Used features in train: 16 of 216");
|
||||
REQUIRE(clf.getNotes()[2] == "Number of models: 1");
|
||||
auto score = clf.score(raw.Xv, raw.yv);
|
||||
auto scoret = clf.score(raw.Xt, raw.yt);
|
||||
REQUIRE(score == Catch::Approx(1.0f).epsilon(raw.epsilon));
|
||||
REQUIRE(scoret == Catch::Approx(1.0f).epsilon(raw.epsilon));
|
||||
}
|
Reference in New Issue
Block a user