diff --git a/bayesnet/utils/BayesMetrics.cc b/bayesnet/utils/BayesMetrics.cc index 61e5845..f3fa20c 100644 --- a/bayesnet/utils/BayesMetrics.cc +++ b/bayesnet/utils/BayesMetrics.cc @@ -177,6 +177,8 @@ namespace bayesnet { // Total weight sum double totalWeight = torch::sum(weights).item(); + if (totalWeight == 0) + return 0; // Compute the conditional entropy double conditionalEntropy = 0.0; @@ -192,63 +194,8 @@ namespace bayesnet { conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc); } } - return conditionalEntropy; } - double Metrics::conditionalEntropy2(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights) - { - int numSamples = firstFeature.size(0); - // Get unique values for each variable - auto [uniqueX, countsX] = at::_unique(firstFeature); - auto [uniqueC, countsC] = at::_unique(labels); - - // Compute p(x,c) for each unique value of X and C - std::map, double>> jointCounts; - double totalWeight = 0; - for (auto i = 0; i < numSamples; i++) { - int x = firstFeature[i].item(); - int y = secondFeature[i].item(); - int c = labels[i].item(); - const auto key = std::make_pair(x, c); - jointCounts[y][key] += weights[i].item(); - totalWeight += weights[i].item(); - } - if (totalWeight == 0) - return 0; - double entropyValue = 0; - - // Iterate over unique values of X and C - for (int i = 0; i < uniqueX.size(0); i++) { - int x_val = uniqueX[i].item(); - for (int j = 0; j < uniqueC.size(0); j++) { - int c_val = uniqueC[j].item(); - double p_xc = 0; // Probability of (X=x, C=c) - double entropy_f = 0; - // Find joint counts for this specific (X,C) combination - for (auto& [y, jointCount] : jointCounts) { - auto joint_count_xc = jointCount.find({ x_val, c_val }); - if (joint_count_xc != jointCount.end()) { - p_xc += joint_count_xc->second; - } - } - // Only calculate conditional entropy if p(X=x, C=c) > 0 - if (p_xc > 0) { - p_xc /= totalWeight; - for (auto& [y, jointCount] : jointCounts) { - auto key = std::make_pair(x_val, c_val); - double p_y_xc = jointCount[key] / p_xc; - - if (p_y_xc > 0) { - entropy_f -= p_y_xc * log(p_y_xc); - } - } - } - entropyValue += p_xc * entropy_f; - } - } - return entropyValue; - return 0; - } // I(X;Y) = H(Y) - H(Y|X) double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights) { diff --git a/bayesnet/utils/BayesMetrics.h b/bayesnet/utils/BayesMetrics.h index 538e574..6c20852 100644 --- a/bayesnet/utils/BayesMetrics.h +++ b/bayesnet/utils/BayesMetrics.h @@ -25,7 +25,6 @@ namespace bayesnet { // Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14 double entropy(const torch::Tensor& feature, const torch::Tensor& weights); double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights); - double conditionalEntropy2(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights); protected: torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector std::string className; diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc index c05b491..47ddf11 100644 --- a/tests/TestBayesMetrics.cc +++ b/tests/TestBayesMetrics.cc @@ -9,6 +9,7 @@ #include #include "bayesnet/utils/BayesMetrics.h" #include "TestUtils.h" +#include "Timer.h" TEST_CASE("Metrics Test", "[Metrics]") @@ -100,15 +101,32 @@ TEST_CASE("Entropy Test", "[Metrics]") } TEST_CASE("Conditional Entropy", "[Metrics]") { - auto raw = RawDatasets("iris", true); + auto raw = RawDatasets("mfeat-factors", true); bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates); + bayesnet::Metrics metrics2(raw.dataset, raw.features, raw.className, raw.classNumStates); auto feature0 = raw.dataset.index({ 0, "..." }); auto feature1 = raw.dataset.index({ 1, "..." }); auto feature2 = raw.dataset.index({ 2, "..." }); auto feature3 = raw.dataset.index({ 3, "..." }); - auto labels = raw.dataset.index({ 4, "..." }); - auto result = metrics.conditionalEntropy(feature0, feature1, labels, raw.weights); - auto result2 = metrics.conditionalEntropy2(feature0, feature1, labels, raw.weights); - std::cout << "Result=" << result << "\n"; - std::cout << "Result2=" << result2 << "\n"; + platform::Timer timer; + double result, greatest = 0; + int best_i, best_j; + timer.start(); + for (int i = 0; i < raw.features.size() - 1; ++i) { + if (i % 50 == 0) { + std::cout << "i=" << i << " Time=" << timer.getDurationString(true) << std::endl; + } + for (int j = i + 1; j < raw.features.size(); ++j) { + result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights); + if (result > greatest) { + greatest = result; + best_i = i; + best_j = j; + } + } + } + timer.stop(); + std::cout << "CMI(" << best_i << "," << best_j << ")=" << greatest << "\n"; + std::cout << "Time=" << timer.getDurationString() << std::endl; + // Se pueden precalcular estos valores y utilizarlos en el algoritmo como entrada } \ No newline at end of file diff --git a/tests/Timer.h b/tests/Timer.h new file mode 100644 index 0000000..5deff5b --- /dev/null +++ b/tests/Timer.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include +#include + +namespace platform { + class Timer { + private: + std::chrono::high_resolution_clock::time_point begin; + std::chrono::high_resolution_clock::time_point end; + public: + Timer() = default; + ~Timer() = default; + void start() { begin = std::chrono::high_resolution_clock::now(); } + void stop() { end = std::chrono::high_resolution_clock::now(); } + double getDuration() + { + stop(); + std::chrono::duration time_span = std::chrono::duration_cast> (end - begin); + return time_span.count(); + } + double getLapse() + { + std::chrono::duration time_span = std::chrono::duration_cast> (std::chrono::high_resolution_clock::now() - begin); + return time_span.count(); + } + std::string getDurationString(bool lapse = false) + { + double duration = lapse ? getLapse() : getDuration(); + return translate2String(duration); + } + std::string translate2String(double duration) + { + double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration; + std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s"; + std::stringstream ss; + ss << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit; + return ss.str(); + } + }; +} /* namespace platform */ \ No newline at end of file