diff --git a/bayesnet/ensembles/BoostA2DE.cc b/bayesnet/ensembles/BoostA2DE.cc index 209173b..4738358 100644 --- a/bayesnet/ensembles/BoostA2DE.cc +++ b/bayesnet/ensembles/BoostA2DE.cc @@ -155,7 +155,7 @@ namespace bayesnet { } } if (pairSelection.size() > 0) { - notes.push_back("Used pairs not used in train: " + std::to_string(pairSelection.size())); + notes.push_back("Pairs not used in train: " + std::to_string(pairSelection.size())); status = WARNING; } notes.push_back("Number of models: " + std::to_string(n_models)); diff --git a/bayesnet/utils/BayesMetrics.cc b/bayesnet/utils/BayesMetrics.cc index 85645c7..b4b12cc 100644 --- a/bayesnet/utils/BayesMetrics.cc +++ b/bayesnet/utils/BayesMetrics.cc @@ -198,24 +198,20 @@ namespace bayesnet { } return entropyValue; } - // H(Y|X,C) = sum_{x in X, c in C} p(x,c) H(Y|X=x,C=c) + // H(X|Y,C) = sum_{y in Y, c in C} p(x,c) H(X|Y=y,C=c) double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights) { // Ensure the tensors are of the same length assert(firstFeature.size(0) == secondFeature.size(0) && firstFeature.size(0) == labels.size(0) && firstFeature.size(0) == weights.size(0)); - // Convert tensors to vectors for easier processing auto firstFeatureData = firstFeature.accessor(); auto secondFeatureData = secondFeature.accessor(); auto labelsData = labels.accessor(); auto weightsData = weights.accessor(); - int numSamples = firstFeature.size(0); - // Maps for joint and marginal probabilities std::map, double> jointCount; std::map, double> marginalCount; - // Compute joint and marginal counts for (int i = 0; i < numSamples; ++i) { auto keyJoint = std::make_tuple(firstFeatureData[i], labelsData[i], secondFeatureData[i]); @@ -224,34 +220,29 @@ namespace bayesnet { jointCount[keyJoint] += weightsData[i]; marginalCount[keyMarginal] += weightsData[i]; } - // Total weight sum double totalWeight = torch::sum(weights).item(); if (totalWeight == 0) return 0; - // Compute the conditional entropy double conditionalEntropy = 0.0; - for (const auto& [keyJoint, jointFreq] : jointCount) { auto [x, c, y] = keyJoint; auto keyMarginal = std::make_tuple(x, c); - //double p_xc = marginalCount[keyMarginal] / totalWeight; double p_y_given_xc = jointFreq / marginalCount[keyMarginal]; - if (p_y_given_xc > 0) { conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc); } } return conditionalEntropy; } - // I(X;Y) = H(Y) - H(Y|X) + // I(X;Y) = H(Y) - H(Y|X) ; I(X;Y) >= 0 double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights) { - return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights); + return std::max(entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights), 0.0); } - // I(X;Y|C) = H(Y|C) - H(Y|X,C) + // I(X;Y|C) = H(X|C) - H(X|Y,C) >= 0 double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights) { return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0); diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc index 755a4a0..f2eeebf 100644 --- a/tests/TestBayesMetrics.cc +++ b/tests/TestBayesMetrics.cc @@ -11,7 +11,6 @@ #include "TestUtils.h" #include "Timer.h" - TEST_CASE("Metrics Test", "[Metrics]") { std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes"); @@ -28,8 +27,8 @@ TEST_CASE("Metrics Test", "[Metrics]") {"diabetes", 0.0345470614} }; map, std::vector>> resultsMST = { - { {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {5, 1}, {5, 8}, {5, 4}, {6, 2}, {6, 7} } }, - { {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {5, 4}, {0, 6}, {0, 3}, {6, 2}, {6, 7} } }, + { {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {3, 4}, {5, 1}, {5, 8}, {6, 2}, {6, 7} } }, + { {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {0, 6}, {0, 3}, {3, 4}, {6, 2}, {6, 7} } }, { {"iris", 0}, { {0, 1}, {0, 2}, {1, 3} } }, { {"iris", 1}, { {1, 0}, {1, 3}, {0, 2} } }, { {"ecoli", 0}, { {0, 1}, {0, 2}, {1, 5}, {1, 3}, {5, 6}, {5, 4} } }, diff --git a/tests/TestBoostA2DE.cc b/tests/TestBoostA2DE.cc index b0f6b4a..b841bc3 100644 --- a/tests/TestBoostA2DE.cc +++ b/tests/TestBoostA2DE.cc @@ -22,7 +22,7 @@ TEST_CASE("Build basic model", "[BoostA2DE]") REQUIRE(clf.getNumberOfEdges() == 684); REQUIRE(clf.getNotes().size() == 3); REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated"); - REQUIRE(clf.getNotes()[1] == "Used pairs not used in train: 20"); + REQUIRE(clf.getNotes()[1] == "Pairs not used in train: 20"); REQUIRE(clf.getNotes()[2] == "Number of models: 38"); auto score = clf.score(raw.Xv, raw.yv); REQUIRE(score == Catch::Approx(0.919271).epsilon(raw.epsilon));