diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc index b7d5c76..96fddd9 100644 --- a/tests/TestBayesMetrics.cc +++ b/tests/TestBayesMetrics.cc @@ -100,6 +100,25 @@ TEST_CASE("Entropy Test", "[Metrics]") REQUIRE(result == Catch::Approx(0.693147180559945).epsilon(raw.epsilon)); } TEST_CASE("Conditional Entropy", "[Metrics]") +{ + auto raw = RawDatasets("iris", true); + bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates); + auto expected = std::map, double>{ + { { 0, 1 }, 1.32674 }, + { { 0, 2 }, 0.236253 }, + { { 0, 3 }, 0.1202 }, + { { 1, 2 }, 0.252551 }, + { { 1, 3 }, 0.10515 }, + { { 2, 3 }, 0.108323 }, + }; + for (int i = 0; i < raw.features.size() - 1; ++i) { + for (int j = i + 1; j < raw.features.size(); ++j) { + double result = metrics.conditionalEntropy(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights); + REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon)); + } + } +} +TEST_CASE("Conditional Mutual Information", "[Metrics]") { auto raw = RawDatasets("iris", true); bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);