diff --git a/CHANGELOG.md b/CHANGELOG.md index 28e0d13..11ad42a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensmble built so far. Default value is *false*. + ## [1.0.6] 2024-11-23 ### Fixed diff --git a/bayesnet/ensembles/Boost.cc b/bayesnet/ensembles/Boost.cc index a582811..50f3c73 100644 --- a/bayesnet/ensembles/Boost.cc +++ b/bayesnet/ensembles/Boost.cc @@ -12,7 +12,7 @@ namespace bayesnet { Boost::Boost(bool predict_voting) : Ensemble(predict_voting) { - validHyperparameters = { "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance", + validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance", "predict_voting", "select_features", "block_update" }; } void Boost::setHyperparameters(const nlohmann::json& hyperparameters_) @@ -26,6 +26,10 @@ namespace bayesnet { } hyperparameters.erase("order"); } + if (hyperparameters.contains("alpha_block")) { + alpha_block = hyperparameters["alpha_block"]; + hyperparameters.erase("alpha_block"); + } if (hyperparameters.contains("convergence")) { convergence = hyperparameters["convergence"]; hyperparameters.erase("convergence"); @@ -66,6 +70,12 @@ namespace bayesnet { block_update = hyperparameters["block_update"]; hyperparameters.erase("block_update"); } + if (block_update && alpha_block) { + throw std::invalid_argument("alpha_block and block_update cannot be true at the same time"); + } + if (block_update && !bisection) { + throw std::invalid_argument("block_update needs bisection to be true"); + } Classifier::setHyperparameters(hyperparameters); } void Boost::buildModel(const torch::Tensor& weights) diff --git a/bayesnet/ensembles/Boost.h b/bayesnet/ensembles/Boost.h index 2594bcb..82433e0 100644 --- a/bayesnet/ensembles/Boost.h +++ b/bayesnet/ensembles/Boost.h @@ -45,8 +45,8 @@ namespace bayesnet { std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm FeatureSelect* featureSelector = nullptr; double threshold = -1; - bool block_update = false; - + bool block_update = false; // if true, use block update algorithm, only meaningful if bisection is true + bool alpha_block = false; // if true, the alpha is computed with the ensemble built so far and the new model }; } #endif \ No newline at end of file