Remove unoptimized implementation of conditionalEntropy
This commit is contained in:
parent
e2e0fb0c40
commit
521bfd2a8e
@ -177,6 +177,8 @@ namespace bayesnet {
|
|||||||
|
|
||||||
// Total weight sum
|
// Total weight sum
|
||||||
double totalWeight = torch::sum(weights).item<double>();
|
double totalWeight = torch::sum(weights).item<double>();
|
||||||
|
if (totalWeight == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
// Compute the conditional entropy
|
// Compute the conditional entropy
|
||||||
double conditionalEntropy = 0.0;
|
double conditionalEntropy = 0.0;
|
||||||
@ -192,63 +194,8 @@ namespace bayesnet {
|
|||||||
conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc);
|
conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return conditionalEntropy;
|
return conditionalEntropy;
|
||||||
}
|
}
|
||||||
/**
 * Weighted conditional entropy H(Y | X, C) computed sample by sample.
 *
 * X is read from firstFeature, Y from secondFeature and C from labels; every
 * value is read with item<int>(), and weights[i] is the weight of sample i.
 *
 *   H(Y | X, C) = sum_{x,c} p(x,c) * ( - sum_y p(y | x,c) * log p(y | x,c) )
 *
 * with p(x,c)     = w(x,c) / totalWeight
 *      p(y | x,c) = w(x,y,c) / w(x,c)
 *
 * @param firstFeature  conditioning feature X, indexed along dimension 0
 * @param secondFeature feature Y whose entropy is measured
 * @param labels        conditioning class variable C
 * @param weights       per-sample weights
 * @return the conditional entropy in nats (natural logarithm), or 0 when the
 *         accumulated weight is 0
 */
double Metrics::conditionalEntropy2(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
{
    const auto numSamples = firstFeature.size(0);
    // Accumulated weight of every Y value, grouped by its (X, C) context.
    // Only combinations that actually occur are stored, so no scan over the
    // full cartesian product of unique values is needed afterwards.
    std::map<std::pair<int, int>, std::map<int, double>> weightByContext;
    double totalWeight = 0.0;
    for (int64_t i = 0; i < numSamples; ++i) {
        const int x = firstFeature[i].item<int>();
        const int y = secondFeature[i].item<int>();
        const int c = labels[i].item<int>();
        // Read the weight once, in double precision, for both accumulators
        const double w = weights[i].item<double>();
        weightByContext[{ x, c }][y] += w;
        totalWeight += w;
    }
    if (totalWeight == 0)
        return 0;
    double entropyValue = 0.0;
    for (const auto& [context, weightByY] : weightByContext) {
        // w(x,c): marginal weight of this (X, C) combination
        double contextWeight = 0.0;
        for (const auto& [y, w] : weightByY) {
            contextWeight += w;
        }
        // Only contexts with positive mass contribute
        if (contextWeight <= 0)
            continue;
        double contextEntropy = 0.0;
        for (const auto& [y, w] : weightByY) {
            // p(y | x,c) uses the un-normalized context weight so that the
            // result does not depend on the weights adding up to 1
            const double p_y_xc = w / contextWeight;
            if (p_y_xc > 0) {
                contextEntropy -= p_y_xc * std::log(p_y_xc);
            }
        }
        entropyValue += (contextWeight / totalWeight) * contextEntropy;
    }
    return entropyValue;
}
|
|
||||||
// I(X;Y) = H(Y) - H(Y|X)
|
// I(X;Y) = H(Y) - H(Y|X)
|
||||||
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
|
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
@ -25,7 +25,6 @@ namespace bayesnet {
|
|||||||
// Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14
|
// Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14
|
||||||
double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
|
double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
|
||||||
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
|
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
|
||||||
double conditionalEntropy2(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
|
|
||||||
protected:
|
protected:
|
||||||
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
|
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
|
||||||
std::string className;
|
std::string className;
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include <catch2/generators/catch_generators.hpp>
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
#include "bayesnet/utils/BayesMetrics.h"
|
#include "bayesnet/utils/BayesMetrics.h"
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
|
#include "Timer.h"
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE("Metrics Test", "[Metrics]")
|
TEST_CASE("Metrics Test", "[Metrics]")
|
||||||
@ -100,15 +101,32 @@ TEST_CASE("Entropy Test", "[Metrics]")
|
|||||||
}
|
}
|
||||||
// NOTE(review): this span appears to be a side-by-side diff rendering of the test case:
// in each row the entry before the first `|` looks like the pre-commit line and the
// entry after it the post-commit line. The pre-commit side compares conditionalEntropy
// with conditionalEntropy2 on two features; the post-commit side times
// conditionalMutualInformation over every pair of features and prints the largest value.
TEST_CASE("Conditional Entropy", "[Metrics]")
|
TEST_CASE("Conditional Entropy", "[Metrics]")
|
||||||
{
|
{
|
||||||
auto raw = RawDatasets("iris", true);
|
auto raw = RawDatasets("mfeat-factors", true);
|
||||||
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
||||||
|
bayesnet::Metrics metrics2(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
||||||
auto feature0 = raw.dataset.index({ 0, "..." });
|
auto feature0 = raw.dataset.index({ 0, "..." });
|
||||||
auto feature1 = raw.dataset.index({ 1, "..." });
|
auto feature1 = raw.dataset.index({ 1, "..." });
|
||||||
auto feature2 = raw.dataset.index({ 2, "..." });
|
auto feature2 = raw.dataset.index({ 2, "..." });
|
||||||
auto feature3 = raw.dataset.index({ 3, "..." });
|
auto feature3 = raw.dataset.index({ 3, "..." });
|
||||||
auto labels = raw.dataset.index({ 4, "..." });
|
platform::Timer timer;
|
||||||
auto result = metrics.conditionalEntropy(feature0, feature1, labels, raw.weights);
|
double result, greatest = 0;
|
||||||
auto result2 = metrics.conditionalEntropy2(feature0, feature1, labels, raw.weights);
|
// NOTE(review): best_i and best_j are declared without initial values and only assigned
// when some result exceeds `greatest`; the final "CMI(" print reads them regardless —
// consider initializing them.
int best_i, best_j;
|
||||||
std::cout << "Result=" << result << "\n";
|
timer.start();
|
||||||
std::cout << "Result2=" << result2 << "\n";
|
// NOTE(review): presumably raw.features.size() is unsigned, so `size() - 1` would wrap
// around for an empty feature list and the comparison mixes signed and unsigned — confirm.
for (int i = 0; i < raw.features.size() - 1; ++i) {
|
||||||
|
if (i % 50 == 0) {
|
||||||
|
std::cout << "i=" << i << " Time=" << timer.getDurationString(true) << std::endl;
|
||||||
|
}
|
||||||
|
for (int j = i + 1; j < raw.features.size(); ++j) {
|
||||||
|
result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
|
||||||
|
if (result > greatest) {
|
||||||
|
greatest = result;
|
||||||
|
best_i = i;
|
||||||
|
best_j = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.stop();
|
||||||
|
std::cout << "CMI(" << best_i << "," << best_j << ")=" << greatest << "\n";
|
||||||
|
std::cout << "Time=" << timer.getDurationString() << std::endl;
|
||||||
|
// These values can be precomputed and used as input to the algorithm
|
||||||
}
|
}
|
41
tests/Timer.h
Normal file
41
tests/Timer.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <chrono>
#include <iomanip>
#include <sstream>
#include <string>
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
/**
 * Minimal stopwatch used to time test code.
 *
 * start() records the starting instant and stop() the ending one. Durations
 * are reported in seconds as double, or as a short human readable string.
 */
class Timer {
private:
    // steady_clock is monotonic, so measured intervals cannot go backwards or
    // jump when the system time is adjusted.
    using Clock = std::chrono::steady_clock;
    Clock::time_point begin{};
    Clock::time_point end{};
public:
    Timer() = default;
    ~Timer() = default;
    /// Record the starting instant.
    void start() { begin = Clock::now(); }
    /// Record the ending instant.
    void stop() { end = Clock::now(); }
    /**
     * Stop the timer and return the seconds elapsed since start().
     * Note that this always calls stop(), so the ending instant is refreshed
     * on every call.
     */
    double getDuration()
    {
        stop();
        return std::chrono::duration<double>(end - begin).count();
    }
    /// Seconds elapsed since start(), without touching the ending instant.
    double getLapse() const
    {
        return std::chrono::duration<double>(Clock::now() - begin).count();
    }
    /**
     * Elapsed time formatted by translate2String().
     * @param lapse when true uses getLapse() (timer keeps running); otherwise
     *              uses getDuration(), which stops the timer.
     */
    std::string getDurationString(bool lapse = false)
    {
        return translate2String(lapse ? getLapse() : getDuration());
    }
    /**
     * Format a duration given in seconds with two decimals and a unit:
     * hours ("h") above 3600 s, minutes ("m") above 60 s, seconds ("s") otherwise.
     */
    static std::string translate2String(double duration)
    {
        double durationShow = duration;
        const char* durationUnit = "s";
        if (duration > 3600) {
            durationShow = duration / 3600;
            durationUnit = "h";
        } else if (duration > 60) {
            durationShow = duration / 60;
            durationUnit = "m";
        }
        std::stringstream ss;
        ss << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit;
        return ss.str();
    }
};
|
||||||
|
} /* namespace platform */
|
Loading…
Reference in New Issue
Block a user