Force mutual information methods to be at least 0
There were cases where a tiny negative number was returned (less than -1e-7) Fix mst glass test that is affected with this change
This commit is contained in:
parent
291ba0fb0e
commit
2584e8294d
@ -155,7 +155,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (pairSelection.size() > 0) {
|
if (pairSelection.size() > 0) {
|
||||||
notes.push_back("Used pairs not used in train: " + std::to_string(pairSelection.size()));
|
notes.push_back("Pairs not used in train: " + std::to_string(pairSelection.size()));
|
||||||
status = WARNING;
|
status = WARNING;
|
||||||
}
|
}
|
||||||
notes.push_back("Number of models: " + std::to_string(n_models));
|
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||||
|
@ -198,24 +198,20 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return entropyValue;
|
return entropyValue;
|
||||||
}
|
}
|
||||||
// H(Y|X,C) = sum_{x in X, c in C} p(x,c) H(Y|X=x,C=c)
|
// H(X|Y,C) = sum_{y in Y, c in C} p(x,c) H(X|Y=y,C=c)
|
||||||
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// Ensure the tensors are of the same length
|
// Ensure the tensors are of the same length
|
||||||
assert(firstFeature.size(0) == secondFeature.size(0) && firstFeature.size(0) == labels.size(0) && firstFeature.size(0) == weights.size(0));
|
assert(firstFeature.size(0) == secondFeature.size(0) && firstFeature.size(0) == labels.size(0) && firstFeature.size(0) == weights.size(0));
|
||||||
|
|
||||||
// Convert tensors to vectors for easier processing
|
// Convert tensors to vectors for easier processing
|
||||||
auto firstFeatureData = firstFeature.accessor<int, 1>();
|
auto firstFeatureData = firstFeature.accessor<int, 1>();
|
||||||
auto secondFeatureData = secondFeature.accessor<int, 1>();
|
auto secondFeatureData = secondFeature.accessor<int, 1>();
|
||||||
auto labelsData = labels.accessor<int, 1>();
|
auto labelsData = labels.accessor<int, 1>();
|
||||||
auto weightsData = weights.accessor<double, 1>();
|
auto weightsData = weights.accessor<double, 1>();
|
||||||
|
|
||||||
int numSamples = firstFeature.size(0);
|
int numSamples = firstFeature.size(0);
|
||||||
|
|
||||||
// Maps for joint and marginal probabilities
|
// Maps for joint and marginal probabilities
|
||||||
std::map<std::tuple<int, int, int>, double> jointCount;
|
std::map<std::tuple<int, int, int>, double> jointCount;
|
||||||
std::map<std::tuple<int, int>, double> marginalCount;
|
std::map<std::tuple<int, int>, double> marginalCount;
|
||||||
|
|
||||||
// Compute joint and marginal counts
|
// Compute joint and marginal counts
|
||||||
for (int i = 0; i < numSamples; ++i) {
|
for (int i = 0; i < numSamples; ++i) {
|
||||||
auto keyJoint = std::make_tuple(firstFeatureData[i], labelsData[i], secondFeatureData[i]);
|
auto keyJoint = std::make_tuple(firstFeatureData[i], labelsData[i], secondFeatureData[i]);
|
||||||
@ -224,34 +220,29 @@ namespace bayesnet {
|
|||||||
jointCount[keyJoint] += weightsData[i];
|
jointCount[keyJoint] += weightsData[i];
|
||||||
marginalCount[keyMarginal] += weightsData[i];
|
marginalCount[keyMarginal] += weightsData[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Total weight sum
|
// Total weight sum
|
||||||
double totalWeight = torch::sum(weights).item<double>();
|
double totalWeight = torch::sum(weights).item<double>();
|
||||||
if (totalWeight == 0)
|
if (totalWeight == 0)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
// Compute the conditional entropy
|
// Compute the conditional entropy
|
||||||
double conditionalEntropy = 0.0;
|
double conditionalEntropy = 0.0;
|
||||||
|
|
||||||
for (const auto& [keyJoint, jointFreq] : jointCount) {
|
for (const auto& [keyJoint, jointFreq] : jointCount) {
|
||||||
auto [x, c, y] = keyJoint;
|
auto [x, c, y] = keyJoint;
|
||||||
auto keyMarginal = std::make_tuple(x, c);
|
auto keyMarginal = std::make_tuple(x, c);
|
||||||
|
|
||||||
//double p_xc = marginalCount[keyMarginal] / totalWeight;
|
//double p_xc = marginalCount[keyMarginal] / totalWeight;
|
||||||
double p_y_given_xc = jointFreq / marginalCount[keyMarginal];
|
double p_y_given_xc = jointFreq / marginalCount[keyMarginal];
|
||||||
|
|
||||||
if (p_y_given_xc > 0) {
|
if (p_y_given_xc > 0) {
|
||||||
conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc);
|
conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return conditionalEntropy;
|
return conditionalEntropy;
|
||||||
}
|
}
|
||||||
// I(X;Y) = H(Y) - H(Y|X)
|
// I(X;Y) = H(Y) - H(Y|X) ; I(X;Y) >= 0
|
||||||
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
|
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights);
|
return std::max(entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights), 0.0);
|
||||||
}
|
}
|
||||||
// I(X;Y|C) = H(Y|C) - H(Y|X,C)
|
// I(X;Y|C) = H(X|C) - H(X|Y,C) >= 0
|
||||||
double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0);
|
return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0);
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
#include "Timer.h"
|
#include "Timer.h"
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE("Metrics Test", "[Metrics]")
|
TEST_CASE("Metrics Test", "[Metrics]")
|
||||||
{
|
{
|
||||||
std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
|
std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
|
||||||
@ -28,8 +27,8 @@ TEST_CASE("Metrics Test", "[Metrics]")
|
|||||||
{"diabetes", 0.0345470614}
|
{"diabetes", 0.0345470614}
|
||||||
};
|
};
|
||||||
map<pair<std::string, int>, std::vector<pair<int, int>>> resultsMST = {
|
map<pair<std::string, int>, std::vector<pair<int, int>>> resultsMST = {
|
||||||
{ {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {5, 1}, {5, 8}, {5, 4}, {6, 2}, {6, 7} } },
|
{ {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {3, 4}, {5, 1}, {5, 8}, {6, 2}, {6, 7} } },
|
||||||
{ {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {5, 4}, {0, 6}, {0, 3}, {6, 2}, {6, 7} } },
|
{ {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {0, 6}, {0, 3}, {3, 4}, {6, 2}, {6, 7} } },
|
||||||
{ {"iris", 0}, { {0, 1}, {0, 2}, {1, 3} } },
|
{ {"iris", 0}, { {0, 1}, {0, 2}, {1, 3} } },
|
||||||
{ {"iris", 1}, { {1, 0}, {1, 3}, {0, 2} } },
|
{ {"iris", 1}, { {1, 0}, {1, 3}, {0, 2} } },
|
||||||
{ {"ecoli", 0}, { {0, 1}, {0, 2}, {1, 5}, {1, 3}, {5, 6}, {5, 4} } },
|
{ {"ecoli", 0}, { {0, 1}, {0, 2}, {1, 5}, {1, 3}, {5, 6}, {5, 4} } },
|
||||||
|
@ -22,7 +22,7 @@ TEST_CASE("Build basic model", "[BoostA2DE]")
|
|||||||
REQUIRE(clf.getNumberOfEdges() == 684);
|
REQUIRE(clf.getNumberOfEdges() == 684);
|
||||||
REQUIRE(clf.getNotes().size() == 3);
|
REQUIRE(clf.getNotes().size() == 3);
|
||||||
REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated");
|
REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated");
|
||||||
REQUIRE(clf.getNotes()[1] == "Used pairs not used in train: 20");
|
REQUIRE(clf.getNotes()[1] == "Pairs not used in train: 20");
|
||||||
REQUIRE(clf.getNotes()[2] == "Number of models: 38");
|
REQUIRE(clf.getNotes()[2] == "Number of models: 38");
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
auto score = clf.score(raw.Xv, raw.yv);
|
||||||
REQUIRE(score == Catch::Approx(0.919271).epsilon(raw.epsilon));
|
REQUIRE(score == Catch::Approx(0.919271).epsilon(raw.epsilon));
|
||||||
|
Loading…
Reference in New Issue
Block a user