Compare commits

..

2 Commits

Author SHA1 Message Date
b90e558238
Hyperparameter *maxTolerance* in the BoostAODE class is now in [1, 6] range (it was in [1, 4] range before) 2025-01-23 00:56:18 +01:00
64970cf7f7 Merge pull request 'alphablock' (#32) from alphablock into main
Reviewed-on: #32
Added

- Add a new hyperparameter to the BoostAODE class, alphablock, to control the way α is computed, with the last model or with the ensmble built so far. Default value is false.
- Add a new hyperparameter to the SPODE class, parent, to set the root node of the model. If no value is set the root parameter of the constructor is used.
- Add a new hyperparameter to the TAN class, parent, to set the root node of the model. If not set the first feature is used as root.
2025-01-22 11:48:09 +00:00
4 changed files with 8 additions and 4 deletions

View File

@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If no value is set the root parameter of the constructor is used. - Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If no value is set the root parameter of the constructor is used.
- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root. - Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root.
### Changed
- Hyperparameter *maxTolerance* in the BoostAODE class is now in [1, 6] range (it was in [1, 4] range before).
## [1.0.6] 2024-11-23 ## [1.0.6] 2024-11-23
### Fixed ### Fixed

View File

@ -48,8 +48,8 @@ namespace bayesnet {
} }
if (hyperparameters.contains("maxTolerance")) { if (hyperparameters.contains("maxTolerance")) {
maxTolerance = hyperparameters["maxTolerance"]; maxTolerance = hyperparameters["maxTolerance"];
if (maxTolerance < 1 || maxTolerance > 4) if (maxTolerance < 1 || maxTolerance > 6)
throw std::invalid_argument("Invalid maxTolerance value, must be greater in [1, 4]"); throw std::invalid_argument("Invalid maxTolerance value, must be greater in [1, 6]");
hyperparameters.erase("maxTolerance"); hyperparameters.erase("maxTolerance");
} }
if (hyperparameters.contains("predict_voting")) { if (hyperparameters.contains("predict_voting")) {

View File

@ -123,7 +123,7 @@ TEST_CASE("Oddities2", "[BoostA2DE]")
{ { "order", "duck" } }, { { "order", "duck" } },
{ { "select_features", "duck" } }, { { "select_features", "duck" } },
{ { "maxTolerance", 0 } }, { { "maxTolerance", 0 } },
{ { "maxTolerance", 5 } }, { { "maxTolerance", 7 } },
}; };
for (const auto& hyper : bad_hyper.items()) { for (const auto& hyper : bad_hyper.items()) {
INFO("BoostA2DE hyper: " + hyper.value().dump()); INFO("BoostA2DE hyper: " + hyper.value().dump());

View File

@ -118,7 +118,7 @@ TEST_CASE("Oddities", "[BoostAODE]")
{ { "order", "duck" } }, { { "order", "duck" } },
{ { "select_features", "duck" } }, { { "select_features", "duck" } },
{ { "maxTolerance", 0 } }, { { "maxTolerance", 0 } },
{ { "maxTolerance", 5 } }, { { "maxTolerance", 7 } },
}; };
for (const auto& hyper : bad_hyper.items()) { for (const auto& hyper : bad_hyper.items()) {
INFO("BoostAODE hyper: " << hyper.value().dump()); INFO("BoostAODE hyper: " << hyper.value().dump());