Complete selectKPairs method & test

This commit is contained in:
2024-05-16 13:46:38 +02:00
parent 2e3e0e0fc2
commit cccaa6e0af
5 changed files with 96 additions and 35 deletions

View File

@@ -136,4 +136,58 @@ TEST_CASE("Conditional Mutual Information", "[Metrics]")
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
}
}
}
TEST_CASE("Select K Pairs descending", "[Metrics]")
{
auto raw = RawDatasets("iris", true);
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
auto results = metrics.SelectKPairs(raw.weights, false);
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
{ { 1, 3 }, 1.31852 },
{ { 1, 2 }, 1.17112 },
{ { 0, 3 }, 0.403749 },
{ { 0, 2 }, 0.287696 },
{ { 2, 3 }, 0.210068 },
{ { 0, 1 }, 0.0 },
};
auto scores = metrics.getScoresKPairs();
for (int i = 0; i < results.size(); ++i) {
auto result = results[i];
auto expect = expected[i];
auto score = scores[i];
REQUIRE(result.first == expect.first.first);
REQUIRE(result.second == expect.first.second);
REQUIRE(score.first.first == expect.first.first);
REQUIRE(score.first.second == expect.first.second);
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
}
REQUIRE(results.size() == 6);
REQUIRE(scores.size() == 6);
}
TEST_CASE("Select K Pairs ascending", "[Metrics]")
{
auto raw = RawDatasets("iris", true);
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
auto results = metrics.SelectKPairs(raw.weights, true);
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
{ { 0, 1 }, 0.0 },
{ { 2, 3 }, 0.210068 },
{ { 0, 2 }, 0.287696 },
{ { 0, 3 }, 0.403749 },
{ { 1, 2 }, 1.17112 },
{ { 1, 3 }, 1.31852 },
};
auto scores = metrics.getScoresKPairs();
for (int i = 0; i < results.size(); ++i) {
auto result = results[i];
auto expect = expected[i];
auto score = scores[i];
REQUIRE(result.first == expect.first.first);
REQUIRE(result.second == expect.first.second);
REQUIRE(score.first.first == expect.first.first);
REQUIRE(score.first.second == expect.first.second);
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
}
REQUIRE(results.size() == 6);
REQUIRE(scores.size() == 6);
}

View File

@@ -56,14 +56,14 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
auto raw = RawDatasets(file_name, discretize);
clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
auto score = clf->score(raw.Xt, raw.yt);
INFO("Classifier: " + name + " File: " + file_name);
INFO("Classifier: " << name << " File: " << file_name);
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
REQUIRE(clf->getStatus() == bayesnet::NORMAL);
}
}
SECTION("Library check version")
{
INFO("Checking version of " + name + " classifier");
INFO("Checking version of " << name << " classifier");
REQUIRE(clf->getVersion() == ACTUAL_VERSION);
}
delete clf;

View File

@@ -8,6 +8,7 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include "bayesnet/ensembles/BoostAODE.h"
#include "TestUtils.h"