BayesNet/tests/TestBayesMetrics.cc

62 lines
2.5 KiB
C++
Raw Normal View History

#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include "BayesMetrics.h"
2023-10-04 23:14:16 +00:00
#include "TestUtils.h"
2023-10-05 13:45:36 +00:00
// Checks bayesnet::Metrics against stored reference results for the
// "glass", "iris", "ecoli" and "diabetes" datasets. GENERATE makes Catch2
// run the whole test case once per dataset name, and each SECTION below is
// evaluated against the tables keyed by that name.
TEST_CASE("Metrics Test", "[BayesNet]")
{
    const std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");

    // Per dataset: the value passed as the third argument of
    // SelectKBestWeighted, paired with the feature indices expected back
    // (order is significant, the vectors are compared element by element).
    const std::map<std::string, std::pair<int, std::vector<int>>> resultsKBest = {
        {"glass", {7, { 0, 1, 7, 6, 3, 5, 2 }}},
        {"iris", {3, { 0, 3, 2 }} },
        {"ecoli", {6, { 2, 4, 1, 0, 6, 5 }}},
        {"diabetes", {2, { 7, 1 }}}
    };

    // Per dataset: expected mutualInformation() value for entries 1 and 2
    // along the first dimension of raw.dataset.
    const std::map<std::string, double> resultsMI = {
        {"glass", 0.12805398},
        {"iris", 0.3158139948},
        {"ecoli", 0.0089431099},
        {"diabetes", 0.0345470614}
    };

    // Keyed by (dataset, third argument given to maximumSpanningTree):
    // the expected list of index pairs returned by that call.
    // NOTE(review): the int key is presumably the root node of the tree - confirm.
    const std::map<std::pair<std::string, int>, std::vector<std::pair<int, int>>> resultsMST = {
        { {"glass", 0}, { {0, 6}, {0, 5}, {0, 3}, {5, 1}, {5, 8}, {5, 4}, {6, 2}, {6, 7} } },
        { {"glass", 1}, { {1, 5}, {5, 0}, {5, 8}, {5, 4}, {0, 6}, {0, 3}, {6, 2}, {6, 7} } },
        { {"iris", 0}, { {0, 1}, {0, 2}, {1, 3} } },
        { {"iris", 1}, { {1, 0}, {1, 3}, {0, 2} } },
        { {"ecoli", 0}, { {0, 1}, {0, 2}, {1, 5}, {1, 3}, {5, 6}, {5, 4} } },
        { {"ecoli", 1}, { {1, 0}, {1, 5}, {1, 3}, {5, 6}, {5, 4}, {0, 2} } },
        { {"diabetes", 0}, { {0, 7}, {0, 2}, {0, 6}, {2, 3}, {3, 4}, {3, 5}, {4, 1} } },
        { {"diabetes", 1}, { {1, 4}, {4, 3}, {3, 2}, {3, 5}, {2, 0}, {0, 7}, {0, 6} } }
    };

    auto raw = RawDatasets(file_name, true);
    bayesnet::Metrics metrics(raw.dataset, raw.featurest, raw.classNamet, raw.classNumStates);

    SECTION("Test Constructor")
    {
        // A freshly constructed Metrics object holds no k-best scores yet.
        REQUIRE(metrics.getScoresKBest().size() == 0);
    }
    SECTION("Test SelectKBestWeighted")
    {
        // Look the expectation up once instead of on every use.
        const auto& expected = resultsKBest.at(file_name);
        const std::vector<int> kBest = metrics.SelectKBestWeighted(raw.weights, true, expected.first);
        // Compare sizes as unsigned values to avoid a signed/unsigned
        // mismatch inside the REQUIRE expression.
        REQUIRE(kBest.size() == static_cast<std::size_t>(expected.first));
        REQUIRE(kBest == expected.second);
    }
    SECTION("Test Mutual Information")
    {
        auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
        // Floating-point result: compare with the tolerance carried by the fixture.
        REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
    }
    SECTION("Test Maximum Spanning Tree")
    {
        auto weights_matrix = metrics.conditionalEdge(raw.weights);
        for (int i = 0; i < 2; ++i) {
            // Scoped message so a failure reports which dataset/iteration broke.
            INFO("dataset: " << file_name << ", i: " << i);
            auto result = metrics.maximumSpanningTree(raw.featurest, weights_matrix, i);
            REQUIRE(result == resultsMST.at({ file_name, i }));
        }
    }
}