From 291ba0fb0e329eb12297d1c7dce9f3fc5928fcdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Thu, 16 May 2024 16:33:33 +0200 Subject: [PATCH] First functional BoostA2DE with its 1st test --- tests/TestBoostA2DE.cc | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/tests/TestBoostA2DE.cc b/tests/TestBoostA2DE.cc index 4b870f7..b0f6b4a 100644 --- a/tests/TestBoostA2DE.cc +++ b/tests/TestBoostA2DE.cc @@ -16,27 +16,16 @@ TEST_CASE("Build basic model", "[BoostA2DE]") { auto raw = RawDatasets("diabetes", true); - bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates); - auto expected = std::map, double>{ - { { 0, 1 }, 0.0 }, - { { 0, 2 }, 0.287696 }, - { { 0, 3 }, 0.403749 }, - { { 1, 2 }, 1.17112 }, - { { 1, 3 }, 1.31852 }, - { { 2, 3 }, 0.210068 }, - }; - for (int i = 0; i < raw.features.size() - 1; ++i) { - for (int j = i + 1; j < raw.features.size(); ++j) { - double result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights); - // REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon)); - auto clf = bayesnet::SPnDE({ i, j }); - clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states); - auto score = clf.score(raw.Xt, raw.yt); - std::cout << " i " << i << " j " << j << " cmi " - << std::setw(8) << std::setprecision(6) << fixed << result - << " score = " << score << std::endl; - } - } + auto clf = bayesnet::BoostA2DE(); + clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states); + REQUIRE(clf.getNumberOfNodes() == 342); + REQUIRE(clf.getNumberOfEdges() == 684); + REQUIRE(clf.getNotes().size() == 3); + REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated"); + REQUIRE(clf.getNotes()[1] == "Used pairs not used in train: 20"); + REQUIRE(clf.getNotes()[2] == "Number of models: 38"); + auto score = clf.score(raw.Xv, raw.yv); + REQUIRE(score == Catch::Approx(0.919271).epsilon(raw.epsilon)); } // TEST_CASE("Feature_select IWSS", "[BoostAODE]") // {