Fix CFS merit computation error

This commit is contained in:
2025-06-01 13:54:18 +02:00
parent da357ac5ba
commit ad72bb355b
6 changed files with 164 additions and 137 deletions

View File

@@ -33,13 +33,11 @@ TEST_CASE("Feature_select IWSS", "[BoostA2DE]")
auto clf = bayesnet::BoostA2DE();
clf.setHyperparameters({ {"select_features", "IWSS"}, {"threshold", 0.5 } });
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
REQUIRE(clf.getNumberOfNodes() == 140);
REQUIRE(clf.getNumberOfEdges() == 294);
REQUIRE(clf.getNotes().size() == 4);
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 4 of 9 with IWSS");
REQUIRE(clf.getNotes()[1] == "Convergence threshold reached & 15 models eliminated");
REQUIRE(clf.getNotes()[2] == "Pairs not used in train: 2");
REQUIRE(clf.getNotes()[3] == "Number of models: 14");
REQUIRE(clf.getNumberOfNodes() == 360);
REQUIRE(clf.getNumberOfEdges() == 756);
REQUIRE(clf.getNotes().size() == 2);
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 9 of 9 with IWSS");
REQUIRE(clf.getNotes()[1] == "Number of models: 36");
}
TEST_CASE("Feature_select FCBF", "[BoostA2DE]")
{
@@ -64,15 +62,15 @@ TEST_CASE("Test used features in train note and score", "[BoostA2DE]")
{"select_features","CFS"},
});
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
REQUIRE(clf.getNumberOfNodes() == 144);
REQUIRE(clf.getNumberOfEdges() == 288);
REQUIRE(clf.getNumberOfNodes() == 189);
REQUIRE(clf.getNumberOfEdges() == 378);
REQUIRE(clf.getNotes().size() == 2);
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 8 with CFS");
REQUIRE(clf.getNotes()[1] == "Number of models: 16");
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 7 of 8 with CFS");
REQUIRE(clf.getNotes()[1] == "Number of models: 21");
auto score = clf.score(raw.Xv, raw.yv);
auto scoret = clf.score(raw.Xt, raw.yt);
REQUIRE(score == Catch::Approx(0.856771).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.856771).epsilon(raw.epsilon));
REQUIRE(score == Catch::Approx(0.85546875f).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.85546875f).epsilon(raw.epsilon));
}
TEST_CASE("Voting vs proba", "[BoostA2DE]")
{