From cccaa6e0af26571e6205c35bf9068c7089bb899d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?=
Date: Thu, 16 May 2024 13:46:38 +0200
Subject: [PATCH] Complete selectKPairs method & test

---
 bayesnet/utils/BayesMetrics.cc | 65 ++++++++++++++++++----------------
 bayesnet/utils/BayesMetrics.h  |  7 ++--
 tests/TestBayesMetrics.cc      | 54 ++++++++++++++++++++++++++++
 tests/TestBayesModels.cc       |  4 +--
 tests/TestBoostAODE.cc         |  1 +
 5 files changed, 96 insertions(+), 35 deletions(-)

diff --git a/bayesnet/utils/BayesMetrics.cc b/bayesnet/utils/BayesMetrics.cc
index 43b0dec..6083662 100644
--- a/bayesnet/utils/BayesMetrics.cc
+++ b/bayesnet/utils/BayesMetrics.cc
@@ -34,38 +34,40 @@ namespace bayesnet {
     {
         // Return the K Best features
        auto n = features.size();
-        if (k == 0) {
-            k = n;
-        }
         // compute scores
         scoresKPairs.clear();
         pairsKBest.clear();
-        auto label = samples.index({ -1, "..." });
-        // for (int i = 0; i < n; ++i) {
-        //     for (int j = i + 1; j < n; ++j) {
-        //         scoresKBest.push_back(mutualInformation(samples.index({ i, "..." }), samples.index({ j, "..." }), weights));
-        //         featuresKBest.push_back(i);
-        //         featuresKBest.push_back(j);
-        //     }
-        // }
-        // // sort & reduce scores and features
-        // if (ascending) {
-        //     sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
-        //         { return scoresKBest[i] < scoresKBest[j]; });
-        //     sort(scoresKBest.begin(), scoresKBest.end(), std::less<double>());
-        //     if (k < n) {
-        //         for (int i = 0; i < n - k; ++i) {
-        //             featuresKBest.erase(featuresKBest.begin());
-        //             scoresKBest.erase(scoresKBest.begin());
-        //         }
-        //     }
-        // } else {
-        //     sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
-        //         { return scoresKBest[i] > scoresKBest[j]; });
-        //     sort(scoresKBest.begin(), scoresKBest.end(), std::greater<double>());
-        //     featuresKBest.resize(k);
-        //     scoresKBest.resize(k);
-        // }
+        auto labels = samples.index({ -1, "..." });
+        for (int i = 0; i < n - 1; ++i) {
+            for (int j = i + 1; j < n; ++j) {
+                auto key = std::make_pair(i, j);
+                auto value = conditionalMutualInformation(samples.index({ i, "..." }), samples.index({ j, "..." }), labels, weights);
+                scoresKPairs.push_back({ key, value });
+            }
+        }
+        // sort scores
+        if (ascending) {
+            sort(scoresKPairs.begin(), scoresKPairs.end(), [](auto& a, auto& b)
+                { return a.second < b.second; });
+
+        } else {
+            sort(scoresKPairs.begin(), scoresKPairs.end(), [](auto& a, auto& b)
+                { return a.second > b.second; });
+        }
+        for (auto& [pairs, score] : scoresKPairs) {
+            pairsKBest.push_back(pairs);
+        }
+        if (k != 0) {
+            if (ascending) {
+                for (int i = 0; i < n - k; ++i) {
+                    pairsKBest.erase(pairsKBest.begin());
+                    scoresKPairs.erase(scoresKPairs.begin());
+                }
+            } else {
+                pairsKBest.resize(k);
+                scoresKPairs.resize(k);
+            }
+        }
         return pairsKBest;
     }
     std::vector<int> Metrics::SelectKBestWeighted(const torch::Tensor& weights, bool ascending, unsigned k)
@@ -107,7 +109,10 @@ namespace bayesnet {
     {
         return scoresKBest;
     }
-
+    std::vector<std::pair<std::pair<int, int>, double>> Metrics::getScoresKPairs() const
+    {
+        return scoresKPairs;
+    }
     torch::Tensor Metrics::conditionalEdge(const torch::Tensor& weights)
     {
         auto result = std::vector<double>();
diff --git a/bayesnet/utils/BayesMetrics.h b/bayesnet/utils/BayesMetrics.h
index f24a496..0e58b82 100644
--- a/bayesnet/utils/BayesMetrics.h
+++ b/bayesnet/utils/BayesMetrics.h
@@ -18,6 +18,7 @@ namespace bayesnet {
         std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
         std::vector<std::pair<int, int>> SelectKPairs(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
         std::vector<double> getScoresKBest() const;
+        std::vector<std::pair<std::pair<int, int>, double>> getScoresKPairs() const;
         double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
         double conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
         torch::Tensor conditionalEdge(const torch::Tensor& weights);
@@ -34,7 +35,7 @@ namespace bayesnet {
         std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
         {
             std::vector<std::pair<T, T>> result;
-            for (int i = 0; i < source.size(); ++i) {
+            for (int i = 0; i < source.size() - 1; ++i) {
                 T temp = source[i];
                 for (int j = i + 1; j < source.size(); ++j) {
                     result.push_back({ temp, source[j] });
@@ -42,7 +43,7 @@
             }
             return result;
         }
-        template <typename T>
+        template <class T>
         T pop_first(std::vector<T>& v)
         {
             T temp = v[0];
@@ -54,7 +55,7 @@
         std::vector<double> scoresKBest;
         std::vector<int> featuresKBest; // sorted indices of the features
         std::vector<std::pair<int, int>> pairsKBest; // sorted indices of the pairs
-        std::map<std::pair<int, int>, double> scoresKPairs;
+        std::vector<std::pair<std::pair<int, int>, double>> scoresKPairs;
         double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
     };
 }
diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc
index 96fddd9..b40e3f2 100644
--- a/tests/TestBayesMetrics.cc
+++ b/tests/TestBayesMetrics.cc
@@ -136,4 +136,58 @@ TEST_CASE("Conditional Mutual Information", "[Metrics]")
             REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
         }
     }
 }
+TEST_CASE("Select K Pairs descending", "[Metrics]")
+{
+    auto raw = RawDatasets("iris", true);
+    bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
+    auto results = metrics.SelectKPairs(raw.weights, false);
+    auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
+        { { 1, 3 }, 1.31852 },
+        { { 1, 2 }, 1.17112 },
+        { { 0, 3 }, 0.403749 },
+        { { 0, 2 }, 0.287696 },
+        { { 2, 3 }, 0.210068 },
+        { { 0, 1 }, 0.0 },
+    };
+    auto scores = metrics.getScoresKPairs();
+    for (int i = 0; i < results.size(); ++i) {
+        auto result = results[i];
+        auto expect = expected[i];
+        auto score = scores[i];
+        REQUIRE(result.first == expect.first.first);
+        REQUIRE(result.second == expect.first.second);
+        REQUIRE(score.first.first == expect.first.first);
+        REQUIRE(score.first.second == expect.first.second);
+        REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
+    }
+    REQUIRE(results.size() == 6);
+    REQUIRE(scores.size() == 6);
+}
+TEST_CASE("Select K Pairs ascending", "[Metrics]")
+{
+    auto raw = RawDatasets("iris", true);
+    bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
+    auto results = metrics.SelectKPairs(raw.weights, true);
+    auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
+        { { 0, 1 }, 0.0 },
+        { { 2, 3 }, 0.210068 },
+        { { 0, 2 }, 0.287696 },
+        { { 0, 3 }, 0.403749 },
+        { { 1, 2 }, 1.17112 },
+        { { 1, 3 }, 1.31852 },
+    };
+    auto scores = metrics.getScoresKPairs();
+    for (int i = 0; i < results.size(); ++i) {
+        auto result = results[i];
+        auto expect = expected[i];
+        auto score = scores[i];
+        REQUIRE(result.first == expect.first.first);
+        REQUIRE(result.second == expect.first.second);
+        REQUIRE(score.first.first == expect.first.first);
+        REQUIRE(score.first.second == expect.first.second);
+        REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
+    }
+    REQUIRE(results.size() == 6);
+    REQUIRE(scores.size() == 6);
+}
\ No newline at end of file
diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc
index 5c97f60..b5ee426 100644
--- a/tests/TestBayesModels.cc
+++ b/tests/TestBayesModels.cc
@@ -56,14 +56,14 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
             auto raw = RawDatasets(file_name, discretize);
             clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
             auto score = clf->score(raw.Xt, raw.yt);
-            INFO("Classifier: " + name + " File: " + file_name);
+            INFO("Classifier: " << name << " File: " << file_name);
             REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
             REQUIRE(clf->getStatus() == bayesnet::NORMAL);
         }
     }
     SECTION("Library check version")
     {
-        INFO("Checking version of " + name + " classifier");
+        INFO("Checking version of " << name << " classifier");
         REQUIRE(clf->getVersion() == ACTUAL_VERSION);
     }
     delete clf;
diff --git a/tests/TestBoostAODE.cc b/tests/TestBoostAODE.cc
index b434055..66fa7fb 100644
--- a/tests/TestBoostAODE.cc
+++ b/tests/TestBoostAODE.cc
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include "bayesnet/ensembles/BoostAODE.h"
 #include "TestUtils.h"
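
Illustrative usage sketch (not part of the commit): the snippet below shows how the SelectKPairs and getScoresKPairs API added by this patch is expected to be called, following the conventions visible in the tests above (a dataset tensor with one row per feature and the class labels in the last row, uniform per-sample weights, and a Metrics constructor taking the dataset, feature names, class name, and number of class states). The feature names, class name, tensor sizes, and random values here are made-up example data, not taken from the patch.

// Sketch: rank every feature pair by conditional mutual information given the class.
#include <iostream>
#include <string>
#include <vector>
#include <torch/torch.h>
#include "bayesnet/utils/BayesMetrics.h"

int main()
{
    const int n_samples = 150;
    // Hypothetical discretized dataset: 4 feature rows plus a final label row.
    auto dataset = torch::randint(0, 3, { 5, n_samples }, torch::kInt32);
    std::vector<std::string> features = { "f0", "f1", "f2", "f3" };
    bayesnet::Metrics metrics(dataset, features, "class", 3);
    // Uniform sample weights, as in the tests.
    auto weights = torch::full({ n_samples }, 1.0 / n_samples, torch::kDouble);
    // Descending order (ascending = false); the default k = 0 keeps every pair,
    // exactly as the new test cases do.
    auto pairs = metrics.SelectKPairs(weights, false);
    std::cout << pairs.size() << " pairs ranked:" << std::endl;
    for (const auto& [pair, score] : metrics.getScoresKPairs()) {
        std::cout << "(" << pair.first << ", " << pair.second << ") -> " << score << std::endl;
    }
    return 0;
}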