From eb97a5a14b38b057c3a3982722730c018379f6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 19 Mar 2024 09:42:03 +0100 Subject: [PATCH] Remove repeatSparent hyperparameter --- bayesnet/ensembles/BoostAODE.cc | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/bayesnet/ensembles/BoostAODE.cc b/bayesnet/ensembles/BoostAODE.cc index fba10f3..f9c2205 100644 --- a/bayesnet/ensembles/BoostAODE.cc +++ b/bayesnet/ensembles/BoostAODE.cc @@ -22,7 +22,7 @@ namespace bayesnet { BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting) { validHyperparameters = { - "repeatSparent", "maxModels", "order", "convergence", "threshold", + "maxModels", "order", "convergence", "threshold", "select_features", "tolerance", "predict_voting", "predict_single" }; @@ -63,10 +63,6 @@ namespace bayesnet { void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_) { auto hyperparameters = hyperparameters_; - if (hyperparameters.contains("repeatSparent")) { - repeatSparent = hyperparameters["repeatSparent"]; - hyperparameters.erase("repeatSparent"); - } if (hyperparameters.contains("maxModels")) { maxModels = hyperparameters["maxModels"]; hyperparameters.erase("maxModels"); @@ -230,22 +226,15 @@ namespace bayesnet { if (order_algorithm == Orders.RAND) { std::shuffle(featureSelection.begin(), featureSelection.end(), g); } - auto feature = featureSelection[0]; - if (!repeatSparent || featuresUsed.size() < featureSelection.size()) { - bool used = true; - for (const auto& feat : featureSelection) { - if (std::find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) { - continue; - } - used = false; - feature = feat; - break; - } - if (used) { - exitCondition = true; - continue; - } + // Remove used features + featureSelection.erase(remove_if(begin(featureSelection), end(featureSelection), [&](auto x) + { return find(begin(featuresUsed), end(featuresUsed), x) != end(featuresUsed);}), + end(featureSelection) + ); + if (featureSelection.empty()) { + break; } + auto feature = featureSelection[0]; std::unique_ptr model; model = std::make_unique(feature); model->fit(dataset, features, className, states, weights_);