Complete Conditional Mutual Information and test
This commit is contained in:
parent
521bfd2a8e
commit
0e24135d46
@@ -204,7 +204,7 @@ namespace bayesnet {
|
|||||||
// I(X;Y|C) = H(Y|C) - H(Y|X,C)
|
// I(X;Y|C) = H(Y|C) - H(Y|X,C)
|
||||||
double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
return conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights);
|
return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0);
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
Compute the maximum spanning tree considering the weights as distances
|
Compute the maximum spanning tree considering the weights as distances
|
||||||
|
@@ -101,32 +101,20 @@ TEST_CASE("Entropy Test", "[Metrics]")
|
|||||||
}
|
}
|
||||||
TEST_CASE("Conditional Entropy", "[Metrics]")
|
TEST_CASE("Conditional Entropy", "[Metrics]")
|
||||||
{
|
{
|
||||||
auto raw = RawDatasets("mfeat-factors", true);
|
auto raw = RawDatasets("iris", true);
|
||||||
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
||||||
bayesnet::Metrics metrics2(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
auto expected = std::map<std::pair<int, int>, double>{
|
||||||
auto feature0 = raw.dataset.index({ 0, "..." });
|
{ { 0, 1 }, 0.0 },
|
||||||
auto feature1 = raw.dataset.index({ 1, "..." });
|
{ { 0, 2 }, 0.287696 },
|
||||||
auto feature2 = raw.dataset.index({ 2, "..." });
|
{ { 0, 3 }, 0.403749 },
|
||||||
auto feature3 = raw.dataset.index({ 3, "..." });
|
{ { 1, 2 }, 1.17112 },
|
||||||
platform::Timer timer;
|
{ { 1, 3 }, 1.31852 },
|
||||||
double result, greatest = 0;
|
{ { 2, 3 }, 0.210068 },
|
||||||
int best_i, best_j;
|
};
|
||||||
timer.start();
|
|
||||||
for (int i = 0; i < raw.features.size() - 1; ++i) {
|
for (int i = 0; i < raw.features.size() - 1; ++i) {
|
||||||
if (i % 50 == 0) {
|
|
||||||
std::cout << "i=" << i << " Time=" << timer.getDurationString(true) << std::endl;
|
|
||||||
}
|
|
||||||
for (int j = i + 1; j < raw.features.size(); ++j) {
|
for (int j = i + 1; j < raw.features.size(); ++j) {
|
||||||
result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
|
double result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
|
||||||
if (result > greatest) {
|
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
|
||||||
greatest = result;
|
|
||||||
best_i = i;
|
|
||||||
best_j = j;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
timer.stop();
|
|
||||||
std::cout << "CMI(" << best_i << "," << best_j << ")=" << greatest << "\n";
|
|
||||||
std::cout << "Time=" << timer.getDurationString() << std::endl;
|
|
||||||
// Se pueden precalcular estos valores y utilizarlos en el algoritmo como entrada
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user