2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
2023-10-04 21:19:23 +00:00
|
|
|
#include <catch2/catch_test_macros.hpp>
|
|
|
|
#include <catch2/catch_approx.hpp>
|
|
|
|
#include <catch2/generators/catch_generators.hpp>
|
2024-03-08 21:20:54 +00:00
|
|
|
#include "bayesnet/utils/BayesMetrics.h"
|
2023-10-04 23:14:16 +00:00
|
|
|
#include "TestUtils.h"
|
2024-05-14 23:24:27 +00:00
|
|
|
#include "Timer.h"
|
2023-10-04 21:19:23 +00:00
|
|
|
|
2024-04-02 15:38:48 +00:00
|
|
|
TEST_CASE("Metrics Test", "[Metrics]")
|
2023-10-04 21:19:23 +00:00
|
|
|
{
|
2023-11-08 17:45:35 +00:00
|
|
|
std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
|
|
|
|
map<std::string, pair<int, std::vector<int>>> resultsKBest = {
|
2023-10-05 13:45:36 +00:00
|
|
|
{"glass", {7, { 0, 1, 7, 6, 3, 5, 2 }}},
|
|
|
|
{"iris", {3, { 0, 3, 2 }} },
|
|
|
|
{"ecoli", {6, { 2, 4, 1, 0, 6, 5 }}},
|
|
|
|
{"diabetes", {2, { 7, 1 }}}
|
|
|
|
};
|
2023-11-08 17:45:35 +00:00
|
|
|
map<std::string, double> resultsMI = {
|
2023-10-05 13:45:36 +00:00
|
|
|
{"glass", 0.12805398},
|
|
|
|
{"iris", 0.3158139948},
|
|
|
|
{"ecoli", 0.0089431099},
|
|
|
|
{"diabetes", 0.0345470614}
|
|
|
|
};
|
2023-11-08 17:45:35 +00:00
|
|
|
map<pair<std::string, int>, std::vector<pair<int, int>>> resultsMST = {
|
2024-05-17 09:15:45 +00:00
|
|
|
{ {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {3, 4}, {5, 1}, {5, 8}, {6, 2}, {6, 7} } },
|
|
|
|
{ {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {0, 6}, {0, 3}, {3, 4}, {6, 2}, {6, 7} } },
|
2023-10-07 17:08:13 +00:00
|
|
|
{ {"iris", 0}, { {0, 1}, {0, 2}, {1, 3} } },
|
|
|
|
{ {"iris", 1}, { {1, 0}, {1, 3}, {0, 2} } },
|
|
|
|
{ {"ecoli", 0}, { {0, 1}, {0, 2}, {1, 5}, {1, 3}, {5, 6}, {5, 4} } },
|
|
|
|
{ {"ecoli", 1}, { {1, 0}, {1, 5}, {1, 3}, {5, 6}, {5, 4}, {0, 2} } },
|
|
|
|
{ {"diabetes", 0}, { {0, 7}, {0, 2}, {0, 6}, {2, 3}, {3, 4}, {3, 5}, {4, 1} } },
|
|
|
|
{ {"diabetes", 1}, { {1, 4}, {4, 3}, {3, 2}, {3, 5}, {2, 0}, {0, 7}, {0, 6} } }
|
2023-10-04 23:14:16 +00:00
|
|
|
};
|
2023-10-06 15:08:54 +00:00
|
|
|
auto raw = RawDatasets(file_name, true);
|
2024-04-29 22:52:09 +00:00
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
bayesnet::Metrics metricsv(raw.Xv, raw.yv, raw.features, raw.className, raw.classNumStates);
|
2023-10-04 23:14:16 +00:00
|
|
|
|
2023-10-04 21:19:23 +00:00
|
|
|
SECTION("Test Constructor")
|
|
|
|
{
|
2023-10-04 23:14:16 +00:00
|
|
|
REQUIRE(metrics.getScoresKBest().size() == 0);
|
2024-04-07 22:13:59 +00:00
|
|
|
REQUIRE(metricsv.getScoresKBest().size() == 0);
|
2023-10-04 21:19:23 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
SECTION("Test SelectKBestWeighted")
|
|
|
|
{
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<int> kBest = metrics.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
2024-04-07 22:13:59 +00:00
|
|
|
std::vector<int> kBestv = metricsv.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
2023-10-05 09:45:00 +00:00
|
|
|
REQUIRE(kBest.size() == resultsKBest.at(file_name).first);
|
2024-04-07 22:13:59 +00:00
|
|
|
REQUIRE(kBestv.size() == resultsKBest.at(file_name).first);
|
2023-10-05 09:45:00 +00:00
|
|
|
REQUIRE(kBest == resultsKBest.at(file_name).second);
|
2024-04-07 22:13:59 +00:00
|
|
|
REQUIRE(kBestv == resultsKBest.at(file_name).second);
|
2023-10-04 21:19:23 +00:00
|
|
|
}
|
|
|
|
|
2023-10-05 13:45:36 +00:00
|
|
|
SECTION("Test Mutual Information")
|
2023-10-04 21:19:23 +00:00
|
|
|
{
|
2023-10-06 15:08:54 +00:00
|
|
|
auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
2024-04-07 22:13:59 +00:00
|
|
|
auto resultv = metricsv.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
2023-10-06 15:08:54 +00:00
|
|
|
REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
2024-04-07 22:13:59 +00:00
|
|
|
REQUIRE(resultv == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
2023-10-05 13:45:36 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
SECTION("Test Maximum Spanning Tree")
|
|
|
|
{
|
2023-10-06 15:08:54 +00:00
|
|
|
auto weights_matrix = metrics.conditionalEdge(raw.weights);
|
2024-04-07 22:13:59 +00:00
|
|
|
auto weights_matrixv = metricsv.conditionalEdge(raw.weights);
|
2023-10-07 17:08:13 +00:00
|
|
|
for (int i = 0; i < 2; ++i) {
|
2024-04-29 22:52:09 +00:00
|
|
|
auto result = metrics.maximumSpanningTree(raw.features, weights_matrix, i);
|
|
|
|
auto resultv = metricsv.maximumSpanningTree(raw.features, weights_matrixv, i);
|
2023-10-07 17:08:13 +00:00
|
|
|
REQUIRE(result == resultsMST.at({ file_name, i }));
|
2024-04-07 22:13:59 +00:00
|
|
|
REQUIRE(resultv == resultsMST.at({ file_name, i }));
|
2023-10-06 23:43:26 +00:00
|
|
|
}
|
2023-10-04 21:19:23 +00:00
|
|
|
}
|
2024-05-06 15:56:00 +00:00
|
|
|
}
|
|
|
|
TEST_CASE("Select all features ordered by Mutual Information", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
auto kBest = metrics.SelectKBestWeighted(raw.weights, true, 0);
|
|
|
|
REQUIRE(kBest.size() == raw.features.size());
|
|
|
|
REQUIRE(kBest == std::vector<int>({ 1, 0, 3, 2 }));
|
2024-05-14 22:48:02 +00:00
|
|
|
}
|
|
|
|
TEST_CASE("Entropy Test", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
auto result = metrics.entropy(raw.dataset.index({ 0, "..." }), raw.weights);
|
|
|
|
REQUIRE(result == Catch::Approx(0.9848175048828125).epsilon(raw.epsilon));
|
|
|
|
auto data = torch::tensor({ 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 }, torch::kInt32);
|
|
|
|
auto weights = torch::tensor({ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, torch::kFloat32);
|
|
|
|
result = metrics.entropy(data, weights);
|
|
|
|
REQUIRE(result == Catch::Approx(0.61086434125900269).epsilon(raw.epsilon));
|
|
|
|
data = torch::tensor({ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 }, torch::kInt32);
|
|
|
|
result = metrics.entropy(data, weights);
|
|
|
|
REQUIRE(result == Catch::Approx(0.693147180559945).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
TEST_CASE("Conditional Entropy", "[Metrics]")
|
2024-05-15 09:28:09 +00:00
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
auto expected = std::map<std::pair<int, int>, double>{
|
|
|
|
{ { 0, 1 }, 1.32674 },
|
|
|
|
{ { 0, 2 }, 0.236253 },
|
|
|
|
{ { 0, 3 }, 0.1202 },
|
|
|
|
{ { 1, 2 }, 0.252551 },
|
|
|
|
{ { 1, 3 }, 0.10515 },
|
|
|
|
{ { 2, 3 }, 0.108323 },
|
|
|
|
};
|
|
|
|
for (int i = 0; i < raw.features.size() - 1; ++i) {
|
|
|
|
for (int j = i + 1; j < raw.features.size(); ++j) {
|
|
|
|
double result = metrics.conditionalEntropy(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
|
|
|
|
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
TEST_CASE("Conditional Mutual Information", "[Metrics]")
|
2024-05-14 22:48:02 +00:00
|
|
|
{
|
2024-05-15 09:09:23 +00:00
|
|
|
auto raw = RawDatasets("iris", true);
|
2024-05-14 22:48:02 +00:00
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
2024-05-15 09:09:23 +00:00
|
|
|
auto expected = std::map<std::pair<int, int>, double>{
|
|
|
|
{ { 0, 1 }, 0.0 },
|
|
|
|
{ { 0, 2 }, 0.287696 },
|
|
|
|
{ { 0, 3 }, 0.403749 },
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
{ { 1, 3 }, 1.31852 },
|
|
|
|
{ { 2, 3 }, 0.210068 },
|
|
|
|
};
|
2024-05-14 23:24:27 +00:00
|
|
|
for (int i = 0; i < raw.features.size() - 1; ++i) {
|
|
|
|
for (int j = i + 1; j < raw.features.size(); ++j) {
|
2024-05-15 09:09:23 +00:00
|
|
|
double result = metrics.conditionalMutualInformation(raw.dataset.index({ i, "..." }), raw.dataset.index({ j, "..." }), raw.yt, raw.weights);
|
|
|
|
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
|
2024-05-14 23:24:27 +00:00
|
|
|
}
|
|
|
|
}
|
2024-05-16 11:46:38 +00:00
|
|
|
}
|
|
|
|
TEST_CASE("Select K Pairs descending", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
2024-05-16 12:18:45 +00:00
|
|
|
std::vector<int> empty;
|
|
|
|
auto results = metrics.SelectKPairs(raw.weights, empty, false);
|
2024-05-16 11:46:38 +00:00
|
|
|
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
|
|
|
{ { 1, 3 }, 1.31852 },
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
{ { 0, 3 }, 0.403749 },
|
|
|
|
{ { 0, 2 }, 0.287696 },
|
|
|
|
{ { 2, 3 }, 0.210068 },
|
|
|
|
{ { 0, 1 }, 0.0 },
|
|
|
|
};
|
|
|
|
auto scores = metrics.getScoresKPairs();
|
|
|
|
for (int i = 0; i < results.size(); ++i) {
|
|
|
|
auto result = results[i];
|
|
|
|
auto expect = expected[i];
|
|
|
|
auto score = scores[i];
|
|
|
|
REQUIRE(result.first == expect.first.first);
|
|
|
|
REQUIRE(result.second == expect.first.second);
|
|
|
|
REQUIRE(score.first.first == expect.first.first);
|
|
|
|
REQUIRE(score.first.second == expect.first.second);
|
|
|
|
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
REQUIRE(results.size() == 6);
|
|
|
|
REQUIRE(scores.size() == 6);
|
|
|
|
}
|
|
|
|
TEST_CASE("Select K Pairs ascending", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
2024-05-16 12:18:45 +00:00
|
|
|
std::vector<int> empty;
|
|
|
|
auto results = metrics.SelectKPairs(raw.weights, empty, true);
|
2024-05-16 11:46:38 +00:00
|
|
|
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
|
|
|
{ { 0, 1 }, 0.0 },
|
|
|
|
{ { 2, 3 }, 0.210068 },
|
|
|
|
{ { 0, 2 }, 0.287696 },
|
|
|
|
{ { 0, 3 }, 0.403749 },
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
{ { 1, 3 }, 1.31852 },
|
|
|
|
};
|
|
|
|
auto scores = metrics.getScoresKPairs();
|
|
|
|
for (int i = 0; i < results.size(); ++i) {
|
|
|
|
auto result = results[i];
|
|
|
|
auto expect = expected[i];
|
|
|
|
auto score = scores[i];
|
|
|
|
REQUIRE(result.first == expect.first.first);
|
|
|
|
REQUIRE(result.second == expect.first.second);
|
|
|
|
REQUIRE(score.first.first == expect.first.first);
|
|
|
|
REQUIRE(score.first.second == expect.first.second);
|
|
|
|
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
REQUIRE(results.size() == 6);
|
|
|
|
REQUIRE(scores.size() == 6);
|
2024-05-16 12:18:45 +00:00
|
|
|
}
|
|
|
|
TEST_CASE("Select K Pairs with features excluded", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
std::vector<int> excluded = { 0, 3 };
|
|
|
|
auto results = metrics.SelectKPairs(raw.weights, excluded, true);
|
|
|
|
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
};
|
|
|
|
auto scores = metrics.getScoresKPairs();
|
|
|
|
for (int i = 0; i < results.size(); ++i) {
|
|
|
|
auto result = results[i];
|
|
|
|
auto expect = expected[i];
|
|
|
|
auto score = scores[i];
|
|
|
|
REQUIRE(result.first == expect.first.first);
|
|
|
|
REQUIRE(result.second == expect.first.second);
|
|
|
|
REQUIRE(score.first.first == expect.first.first);
|
|
|
|
REQUIRE(score.first.second == expect.first.second);
|
|
|
|
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
REQUIRE(results.size() == 1);
|
|
|
|
REQUIRE(scores.size() == 1);
|
|
|
|
}
|
|
|
|
TEST_CASE("Select K Pairs with number of pairs descending", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
std::vector<int> empty;
|
|
|
|
auto results = metrics.SelectKPairs(raw.weights, empty, false, 3);
|
|
|
|
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
|
|
|
{ { 1, 3 }, 1.31852 },
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
{ { 0, 3 }, 0.403749 }
|
|
|
|
};
|
|
|
|
auto scores = metrics.getScoresKPairs();
|
|
|
|
REQUIRE(results.size() == 3);
|
|
|
|
REQUIRE(scores.size() == 3);
|
|
|
|
for (int i = 0; i < results.size(); ++i) {
|
|
|
|
auto result = results[i];
|
|
|
|
auto expect = expected[i];
|
|
|
|
auto score = scores[i];
|
|
|
|
REQUIRE(result.first == expect.first.first);
|
|
|
|
REQUIRE(result.second == expect.first.second);
|
|
|
|
REQUIRE(score.first.first == expect.first.first);
|
|
|
|
REQUIRE(score.first.second == expect.first.second);
|
|
|
|
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
TEST_CASE("Select K Pairs with number of pairs ascending", "[Metrics]")
|
|
|
|
{
|
|
|
|
auto raw = RawDatasets("iris", true);
|
|
|
|
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
|
|
|
std::vector<int> empty;
|
|
|
|
auto results = metrics.SelectKPairs(raw.weights, empty, true, 3);
|
|
|
|
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
|
|
|
{ { 0, 3 }, 0.403749 },
|
|
|
|
{ { 1, 2 }, 1.17112 },
|
|
|
|
{ { 1, 3 }, 1.31852 }
|
|
|
|
};
|
|
|
|
auto scores = metrics.getScoresKPairs();
|
|
|
|
REQUIRE(results.size() == 3);
|
|
|
|
REQUIRE(scores.size() == 3);
|
|
|
|
for (int i = 0; i < results.size(); ++i) {
|
|
|
|
auto result = results[i];
|
|
|
|
auto expect = expected[i];
|
|
|
|
auto score = scores[i];
|
|
|
|
REQUIRE(result.first == expect.first.first);
|
|
|
|
REQUIRE(result.second == expect.first.second);
|
|
|
|
REQUIRE(score.first.first == expect.first.first);
|
|
|
|
REQUIRE(score.first.second == expect.first.second);
|
|
|
|
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
|
|
|
}
|
2023-10-04 21:19:23 +00:00
|
|
|
}
|