diff --git a/bayesnet/utils/BayesMetrics.cc b/bayesnet/utils/BayesMetrics.cc index f3fa20c..3e63038 100644 --- a/bayesnet/utils/BayesMetrics.cc +++ b/bayesnet/utils/BayesMetrics.cc @@ -204,7 +204,7 @@ namespace bayesnet { // I(X;Y|C) = H(Y|C) - H(Y|X,C) double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights) { - return conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights); + return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0); } /* Compute the maximum spanning tree considering the weights as distances diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc index 47ddf11..b7d5c76 100644 --- a/tests/TestBayesMetrics.cc +++ b/tests/TestBayesMetrics.cc @@ -101,32 +101,20 @@ TEST_CASE("Entropy Test", "[Metrics]") } TEST_CASE("Conditional Entropy", "[Metrics]") { - auto raw = RawDatasets("mfeat-factors", true); + auto raw = RawDatasets("iris", true); bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates); - bayesnet::Metrics metrics2(raw.dataset, raw.features, raw.className, raw.classNumStates); - auto feature0 = raw.dataset.index({ 0, "..." }); - auto feature1 = raw.dataset.index({ 1, "..." }); - auto feature2 = raw.dataset.index({ 2, "..." }); - auto feature3 = raw.dataset.index({ 3, "..." }); - platform::Timer timer; - double result, greatest = 0; - int best_i, best_j; - timer.start(); + auto expected = std::map, double>{ + { { 0, 1 }, 0.0 }, + { { 0, 2 }, 0.287696 }, + { { 0, 3 }, 0.403749 }, + { { 1, 2 }, 1.17112 }, + { { 1, 3 }, 1.31852 }, + { { 2, 3 }, 0.210068 }, + }; for (int i = 0; i < raw.features.size() - 1; ++i) { - if (i % 50 == 0) { - std::cout << "i=" << i << " Time=" << timer.getDurationString(true) << std::endl; - } for (int j = i + 1; j < raw.features.size(); ++j) { - result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights); - if (result > greatest) { - greatest = result; - best_i = i; - best_j = j; - } + double result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights); + REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon)); } } - timer.stop(); - std::cout << "CMI(" << best_i << "," << best_j << ")=" << greatest << "\n"; - std::cout << "Time=" << timer.getDurationString() << std::endl; - // Se pueden precalcular estos valores y utilizarlos en el algoritmo como entrada } \ No newline at end of file