Add parent hyperparameter to TAN & SPODE

This commit is contained in:
2024-12-17 10:14:14 +01:00
parent 56a2d3ead0
commit e2781ee525
7 changed files with 78 additions and 11 deletions

View File

@@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
}
TEST_CASE("TAN & SPODE with hyperparameters", "[Models]")
{
auto raw = RawDatasets("iris", true);
auto clf = bayesnet::TAN();
clf.setHyperparameters({
{"parent", 1},
});
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto score = clf.score(raw.Xv, raw.yv);
REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon));
auto clf2 = bayesnet::SPODE(0);
clf2.setHyperparameters({
{"parent", 1},
});
clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto score2 = clf2.score(raw.Xv, raw.yv);
REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon));
}
TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]")
{
auto raw = RawDatasets("iris", true);
auto clf = bayesnet::TAN();
clf.setHyperparameters({
{"parent", 5},
});
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
auto clf2 = bayesnet::SPODE(0);
clf2.setHyperparameters({
{"parent", 5},
});
REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
}

View File

@@ -7,7 +7,7 @@
#include <type_traits>
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include "bayesnet/ensembles/BoostAODE.h"
#include "TestUtils.h"