Remove repeatSparent hyperparameter

This commit is contained in:
Ricardo Montañana Gómez 2024-03-19 09:42:03 +01:00
parent eb72f13bf0
commit eb97a5a14b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -22,7 +22,7 @@ namespace bayesnet {
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = {
"repeatSparent", "maxModels", "order", "convergence", "threshold",
"maxModels", "order", "convergence", "threshold",
"select_features", "tolerance", "predict_voting", "predict_single"
};
@ -63,10 +63,6 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"];
hyperparameters.erase("repeatSparent");
}
if (hyperparameters.contains("maxModels")) {
maxModels = hyperparameters["maxModels"];
hyperparameters.erase("maxModels");
@ -230,22 +226,15 @@ namespace bayesnet {
if (order_algorithm == Orders.RAND) {
std::shuffle(featureSelection.begin(), featureSelection.end(), g);
}
auto feature = featureSelection[0];
if (!repeatSparent || featuresUsed.size() < featureSelection.size()) {
bool used = true;
for (const auto& feat : featureSelection) {
if (std::find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) {
continue;
}
used = false;
feature = feat;
break;
}
if (used) {
exitCondition = true;
continue;
}
// Remove used features
featureSelection.erase(remove_if(begin(featureSelection), end(featureSelection), [&](auto x)
{ return find(begin(featuresUsed), end(featuresUsed), x) != end(featuresUsed);}),
end(featureSelection)
);
if (featureSelection.empty()) {
break;
}
auto feature = featureSelection[0];
std::unique_ptr<Classifier> model;
model = std::make_unique<SPODE>(feature);
model->fit(dataset, features, className, states, weights_);