Compare commits

...

5 Commits

6 changed files with 154 additions and 13 deletions

5
.gitmodules vendored
View File

@@ -13,3 +13,8 @@
url = https://github.com/rmontanana/folding
main = main
update = merge
[submodule "tests/lib/catch2"]
path = tests/lib/catch2
url = https://github.com/catchorg/Catch2.git
main = main
update = merge

View File

@@ -4,6 +4,9 @@
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <map>
#include <unordered_map>
#include <tuple>
#include "Mst.h"
#include "BayesMetrics.h"
namespace bayesnet {
@@ -105,6 +108,8 @@ namespace bayesnet {
}
return matrix;
}
// Measured in nats (natural logarithm (log) base e)
// Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14
double Metrics::entropy(const torch::Tensor& feature, const torch::Tensor& weights)
{
torch::Tensor counts = feature.bincount(weights);
@@ -143,11 +148,64 @@ namespace bayesnet {
}
return entropyValue;
}
// H(Y|X,C) = sum_{x in X, c in C} p(x,c) H(Y|X=x,C=c)
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
{
// Ensure the tensors are of the same length
assert(firstFeature.size(0) == secondFeature.size(0) && firstFeature.size(0) == labels.size(0) && firstFeature.size(0) == weights.size(0));
// Convert tensors to vectors for easier processing
auto firstFeatureData = firstFeature.accessor<int, 1>();
auto secondFeatureData = secondFeature.accessor<int, 1>();
auto labelsData = labels.accessor<int, 1>();
auto weightsData = weights.accessor<double, 1>();
int numSamples = firstFeature.size(0);
// Maps for joint and marginal probabilities
std::map<std::tuple<int, int, int>, double> jointCount;
std::map<std::tuple<int, int>, double> marginalCount;
// Compute joint and marginal counts
for (int i = 0; i < numSamples; ++i) {
auto keyJoint = std::make_tuple(firstFeatureData[i], labelsData[i], secondFeatureData[i]);
auto keyMarginal = std::make_tuple(firstFeatureData[i], labelsData[i]);
jointCount[keyJoint] += weightsData[i];
marginalCount[keyMarginal] += weightsData[i];
}
// Total weight sum
double totalWeight = torch::sum(weights).item<double>();
if (totalWeight == 0)
return 0;
// Compute the conditional entropy
double conditionalEntropy = 0.0;
for (const auto& [keyJoint, jointFreq] : jointCount) {
auto [x, c, y] = keyJoint;
auto keyMarginal = std::make_tuple(x, c);
double p_xc = marginalCount[keyMarginal] / totalWeight;
double p_y_given_xc = jointFreq / marginalCount[keyMarginal];
if (p_y_given_xc > 0) {
conditionalEntropy -= (jointFreq / totalWeight) * std::log(p_y_given_xc);
}
}
return conditionalEntropy;
}
// I(X;Y) = H(Y) - H(Y|X)
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
{
return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights);
}
// I(X;Y|C) = H(Y|C) - H(Y|X,C)
double Metrics::conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights)
{
return std::max(conditionalEntropy(firstFeature, labels, weights) - conditionalEntropy(firstFeature, secondFeature, labels, weights), 0.0);
}
/*
Compute the maximum spanning tree considering the weights as distances
and the indices of the weights as nodes of this square matrix using

View File

@@ -18,12 +18,16 @@ namespace bayesnet {
std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
std::vector<double> getScoresKBest() const;
double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
double conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
torch::Tensor conditionalEdge(const torch::Tensor& weights);
std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
// Measured in nats (natural logarithm (log) base e)
// Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14
double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
protected:
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
std::string className;
double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
std::vector<std::string> features;
template <class T>
std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)

View File

@@ -9,6 +9,7 @@
#include <catch2/generators/catch_generators.hpp>
#include "bayesnet/utils/BayesMetrics.h"
#include "TestUtils.h"
#include "Timer.h"
TEST_CASE("Metrics Test", "[Metrics]")
@@ -83,4 +84,37 @@ TEST_CASE("Select all features ordered by Mutual Information", "[Metrics]")
auto kBest = metrics.SelectKBestWeighted(raw.weights, true, 0);
REQUIRE(kBest.size() == raw.features.size());
REQUIRE(kBest == std::vector<int>({ 1, 0, 3, 2 }));
}
TEST_CASE("Entropy Test", "[Metrics]")
{
auto raw = RawDatasets("iris", true);
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
auto result = metrics.entropy(raw.dataset.index({ 0, "..." }), raw.weights);
REQUIRE(result == Catch::Approx(0.9848175048828125).epsilon(raw.epsilon));
auto data = torch::tensor({ 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 }, torch::kInt32);
auto weights = torch::tensor({ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, torch::kFloat32);
result = metrics.entropy(data, weights);
REQUIRE(result == Catch::Approx(0.61086434125900269).epsilon(raw.epsilon));
data = torch::tensor({ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 }, torch::kInt32);
result = metrics.entropy(data, weights);
REQUIRE(result == Catch::Approx(0.693147180559945).epsilon(raw.epsilon));
}
TEST_CASE("Conditional Entropy", "[Metrics]")
{
auto raw = RawDatasets("iris", true);
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
auto expected = std::map<std::pair<int, int>, double>{
{ { 0, 1 }, 0.0 },
{ { 0, 2 }, 0.287696 },
{ { 0, 3 }, 0.403749 },
{ { 1, 2 }, 1.17112 },
{ { 1, 3 }, 1.31852 },
{ { 2, 3 }, 0.210068 },
};
for (int i = 0; i < raw.features.size() - 1; ++i) {
for (int j = i + 1; j < raw.features.size(); ++j) {
double result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
}
}
}

View File

@@ -45,7 +45,7 @@ TEST_CASE("Feature_select FCBF", "[BoostAODE]")
REQUIRE(clf.getNumberOfNodes() == 90);
REQUIRE(clf.getNumberOfEdges() == 153);
REQUIRE(clf.getNotes().size() == 2);
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 5 of 9 with FCBF");
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 4 of 9 with FCBF");
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
}
TEST_CASE("Test used features in train note and score", "[BoostAODE]")
@@ -65,8 +65,8 @@ TEST_CASE("Test used features in train note and score", "[BoostAODE]")
REQUIRE(clf.getNotes()[1] == "Number of models: 8");
auto score = clf.score(raw.Xv, raw.yv);
auto scoret = clf.score(raw.Xt, raw.yt);
REQUIRE(score == Catch::Approx(0.80078).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.80078).epsilon(raw.epsilon));
REQUIRE(score == Catch::Approx(0.809895813).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.809895813).epsilon(raw.epsilon));
}
TEST_CASE("Voting vs proba", "[BoostAODE]")
{
@@ -149,15 +149,14 @@ TEST_CASE("Bisection Best", "[BoostAODE]")
{"convergence_best", false},
});
clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states);
REQUIRE(clf.getNumberOfNodes() == 75);
REQUIRE(clf.getNumberOfEdges() == 135);
REQUIRE(clf.getNotes().size() == 2);
REQUIRE(clf.getNotes().at(0) == "Convergence threshold reached & 9 models eliminated");
REQUIRE(clf.getNotes().at(1) == "Number of models: 5");
REQUIRE(clf.getNumberOfNodes() == 210);
REQUIRE(clf.getNumberOfEdges() == 378);
REQUIRE(clf.getNotes().size() == 1);
REQUIRE(clf.getNotes().at(0) == "Number of models: 14");
auto score = clf.score(raw.X_test, raw.y_test);
auto scoret = clf.score(raw.X_test, raw.y_test);
REQUIRE(score == Catch::Approx(1.0f).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(1.0f).epsilon(raw.epsilon));
REQUIRE(score == Catch::Approx(0.991666675f).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.991666675f).epsilon(raw.epsilon));
}
TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
{
@@ -172,13 +171,13 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
clf.setHyperparameters(hyperparameters);
clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states);
auto score_best = clf.score(raw.X_test, raw.y_test);
REQUIRE(score_best == Catch::Approx(0.993355f).epsilon(raw.epsilon));
REQUIRE(score_best == Catch::Approx(0.980000019f).epsilon(raw.epsilon));
// Now we will set the hyperparameter to use the last accuracy
hyperparameters["convergence_best"] = false;
clf.setHyperparameters(hyperparameters);
clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states);
auto score_last = clf.score(raw.X_test, raw.y_test);
REQUIRE(score_last == Catch::Approx(0.996678f).epsilon(raw.epsilon));
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
}
TEST_CASE("Block Update", "[BoostAODE]")

41
tests/Timer.h Normal file
View File

@@ -0,0 +1,41 @@
#pragma once
#include <chrono>
#include <string>
#include <sstream>
namespace platform {
class Timer {
private:
std::chrono::high_resolution_clock::time_point begin;
std::chrono::high_resolution_clock::time_point end;
public:
Timer() = default;
~Timer() = default;
void start() { begin = std::chrono::high_resolution_clock::now(); }
void stop() { end = std::chrono::high_resolution_clock::now(); }
double getDuration()
{
stop();
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
return time_span.count();
}
double getLapse()
{
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (std::chrono::high_resolution_clock::now() - begin);
return time_span.count();
}
std::string getDurationString(bool lapse = false)
{
double duration = lapse ? getLapse() : getDuration();
return translate2String(duration);
}
std::string translate2String(double duration)
{
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
std::stringstream ss;
ss << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit;
return ss.str();
}
};
} /* namespace platform */