From 3d6b4f0614f7a701a05f174cb6a4ed4d63abb3e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 14 Dec 2024 14:02:45 +0100 Subject: [PATCH] Implement the functionality of the hyperparameter alpha_block with test --- bayesnet/ensembles/BoostAODE.cc | 20 +++++++++++++++++++- tests/TestBoostAODE.cc | 19 +++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/bayesnet/ensembles/BoostAODE.cc b/bayesnet/ensembles/BoostAODE.cc index 0638d78..b2ba9b6 100644 --- a/bayesnet/ensembles/BoostAODE.cc +++ b/bayesnet/ensembles/BoostAODE.cc @@ -92,7 +92,25 @@ namespace bayesnet { model->fit(dataset, features, className, states, weights_, smoothing); alpha_t = 0.0; if (!block_update) { - auto ypred = model->predict(X_train); + torch::Tensor ypred; + if (alpha_block) { + // + // Compute the prediction with the current ensemble + model + // + // Add the model to the ensemble + n_models++; + models.push_back(std::move(model)); + significanceModels.push_back(1); + // Compute the prediction + ypred = predict(X_train); + // Remove the model from the ensemble + model = std::move(models.back()); + models.pop_back(); + significanceModels.pop_back(); + n_models--; + } else { + ypred = model->predict(X_train); + } // Step 3.1: Compute the classifier amout of say std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_); } diff --git a/tests/TestBoostAODE.cc b/tests/TestBoostAODE.cc index 1a8a0f0..73d9048 100644 --- a/tests/TestBoostAODE.cc +++ b/tests/TestBoostAODE.cc @@ -130,6 +130,8 @@ TEST_CASE("Oddities", "[BoostAODE]") { { "select_features","IWSS" }, { "threshold", 0.51 } }, { { "select_features","FCBF" }, { "threshold", 1e-8 } }, { { "select_features","FCBF" }, { "threshold", 1.01 } }, + { { "alpha_block", true }, { "block_update", true } }, + { { "bisection", false }, { "block_update", true } }, }; for (const auto& hyper : bad_hyper_fit.items()) { INFO("BoostAODE hyper: " << hyper.value().dump()); @@ -137,7 +139,6 @@ TEST_CASE("Oddities", "[BoostAODE]") REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument); } } - TEST_CASE("Bisection Best", "[BoostAODE]") { auto clf = bayesnet::BoostAODE(); @@ -180,7 +181,6 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]") auto score_last = clf.score(raw.X_test, raw.y_test); REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon)); } - TEST_CASE("Block Update", "[BoostAODE]") { auto clf = bayesnet::BoostAODE(); @@ -210,4 +210,19 @@ TEST_CASE("Block Update", "[BoostAODE]") // std::cout << note << std::endl; // } // std::cout << "Score " << score << std::endl; +} +TEST_CASE("Alphablock", "[BoostAODE]") +{ + auto clf_alpha = bayesnet::BoostAODE(); + auto clf_no_alpha = bayesnet::BoostAODE(); + auto raw = RawDatasets("diabetes", true); + clf_alpha.setHyperparameters({ + {"alpha_block", true}, + }); + clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing); + clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing); + auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test); + auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test); + REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon)); + REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon)); } \ No newline at end of file