Implement the functionality of the hyperparameter alpha_block with test

This commit is contained in:
Ricardo Montañana Gómez 2024-12-14 14:02:45 +01:00
parent 18844c7da7
commit 3d6b4f0614
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 36 additions and 3 deletions

View File

@ -92,7 +92,25 @@ namespace bayesnet {
model->fit(dataset, features, className, states, weights_, smoothing);
alpha_t = 0.0;
if (!block_update) {
auto ypred = model->predict(X_train);
torch::Tensor ypred;
if (alpha_block) {
//
// Compute the prediction with the current ensemble + model
//
// Add the model to the ensemble
n_models++;
models.push_back(std::move(model));
significanceModels.push_back(1);
// Compute the prediction
ypred = predict(X_train);
// Remove the model from the ensemble
model = std::move(models.back());
models.pop_back();
significanceModels.pop_back();
n_models--;
} else {
ypred = model->predict(X_train);
}
// Step 3.1: Compute the classifier amout of say
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
}

View File

@ -130,6 +130,8 @@ TEST_CASE("Oddities", "[BoostAODE]")
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
{ { "alpha_block", true }, { "block_update", true } },
{ { "bisection", false }, { "block_update", true } },
};
for (const auto& hyper : bad_hyper_fit.items()) {
INFO("BoostAODE hyper: " << hyper.value().dump());
@ -137,7 +139,6 @@ TEST_CASE("Oddities", "[BoostAODE]")
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
}
}
TEST_CASE("Bisection Best", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
@ -180,7 +181,6 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
auto score_last = clf.score(raw.X_test, raw.y_test);
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
}
TEST_CASE("Block Update", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
@ -210,4 +210,19 @@ TEST_CASE("Block Update", "[BoostAODE]")
// std::cout << note << std::endl;
// }
// std::cout << "Score " << score << std::endl;
}
TEST_CASE("Alphablock", "[BoostAODE]")
{
auto clf_alpha = bayesnet::BoostAODE();
auto clf_no_alpha = bayesnet::BoostAODE();
auto raw = RawDatasets("diabetes", true);
clf_alpha.setHyperparameters({
{"alpha_block", true},
});
clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test);
auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test);
REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon));
REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon));
}