Complete selectKPairs method & test
This commit is contained in:
@@ -136,4 +136,58 @@ TEST_CASE("Conditional Mutual Information", "[Metrics]")
|
||||
REQUIRE(result == Catch::Approx(expected.at({ i, j })).epsilon(raw.epsilon));
|
||||
}
|
||||
}
|
||||
}
|
||||
TEST_CASE("Select K Pairs descending", "[Metrics]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
||||
auto results = metrics.SelectKPairs(raw.weights, false);
|
||||
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
||||
{ { 1, 3 }, 1.31852 },
|
||||
{ { 1, 2 }, 1.17112 },
|
||||
{ { 0, 3 }, 0.403749 },
|
||||
{ { 0, 2 }, 0.287696 },
|
||||
{ { 2, 3 }, 0.210068 },
|
||||
{ { 0, 1 }, 0.0 },
|
||||
};
|
||||
auto scores = metrics.getScoresKPairs();
|
||||
for (int i = 0; i < results.size(); ++i) {
|
||||
auto result = results[i];
|
||||
auto expect = expected[i];
|
||||
auto score = scores[i];
|
||||
REQUIRE(result.first == expect.first.first);
|
||||
REQUIRE(result.second == expect.first.second);
|
||||
REQUIRE(score.first.first == expect.first.first);
|
||||
REQUIRE(score.first.second == expect.first.second);
|
||||
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
||||
}
|
||||
REQUIRE(results.size() == 6);
|
||||
REQUIRE(scores.size() == 6);
|
||||
}
|
||||
TEST_CASE("Select K Pairs ascending", "[Metrics]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
bayesnet::Metrics metrics(raw.dataset, raw.features, raw.className, raw.classNumStates);
|
||||
auto results = metrics.SelectKPairs(raw.weights, true);
|
||||
auto expected = std::vector<std::pair<std::pair<int, int>, double>>{
|
||||
{ { 0, 1 }, 0.0 },
|
||||
{ { 2, 3 }, 0.210068 },
|
||||
{ { 0, 2 }, 0.287696 },
|
||||
{ { 0, 3 }, 0.403749 },
|
||||
{ { 1, 2 }, 1.17112 },
|
||||
{ { 1, 3 }, 1.31852 },
|
||||
};
|
||||
auto scores = metrics.getScoresKPairs();
|
||||
for (int i = 0; i < results.size(); ++i) {
|
||||
auto result = results[i];
|
||||
auto expect = expected[i];
|
||||
auto score = scores[i];
|
||||
REQUIRE(result.first == expect.first.first);
|
||||
REQUIRE(result.second == expect.first.second);
|
||||
REQUIRE(score.first.first == expect.first.first);
|
||||
REQUIRE(score.first.second == expect.first.second);
|
||||
REQUIRE(score.second == Catch::Approx(expect.second).epsilon(raw.epsilon));
|
||||
}
|
||||
REQUIRE(results.size() == 6);
|
||||
REQUIRE(scores.size() == 6);
|
||||
}
|
@@ -56,14 +56,14 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
auto raw = RawDatasets(file_name, discretize);
|
||||
clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
|
||||
auto score = clf->score(raw.Xt, raw.yt);
|
||||
INFO("Classifier: " + name + " File: " + file_name);
|
||||
INFO("Classifier: " << name << " File: " << file_name);
|
||||
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
|
||||
REQUIRE(clf->getStatus() == bayesnet::NORMAL);
|
||||
}
|
||||
}
|
||||
SECTION("Library check version")
|
||||
{
|
||||
INFO("Checking version of " + name + " classifier");
|
||||
INFO("Checking version of " << name << " classifier");
|
||||
REQUIRE(clf->getVersion() == ACTUAL_VERSION);
|
||||
}
|
||||
delete clf;
|
||||
|
@@ -8,6 +8,7 @@
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <catch2/matchers/catch_matchers.hpp>
|
||||
#include "bayesnet/ensembles/BoostAODE.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
|
Reference in New Issue
Block a user