Remove repeatSparent hyperparameter

This commit is contained in:
Ricardo Montañana Gómez 2024-03-19 09:42:03 +01:00
parent eb72f13bf0
commit eb97a5a14b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -22,7 +22,7 @@ namespace bayesnet {
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting) BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{ {
validHyperparameters = { validHyperparameters = {
"repeatSparent", "maxModels", "order", "convergence", "threshold", "maxModels", "order", "convergence", "threshold",
"select_features", "tolerance", "predict_voting", "predict_single" "select_features", "tolerance", "predict_voting", "predict_single"
}; };
@ -63,10 +63,6 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_) void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{ {
auto hyperparameters = hyperparameters_; auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"];
hyperparameters.erase("repeatSparent");
}
if (hyperparameters.contains("maxModels")) { if (hyperparameters.contains("maxModels")) {
maxModels = hyperparameters["maxModels"]; maxModels = hyperparameters["maxModels"];
hyperparameters.erase("maxModels"); hyperparameters.erase("maxModels");
@ -230,22 +226,15 @@ namespace bayesnet {
if (order_algorithm == Orders.RAND) { if (order_algorithm == Orders.RAND) {
std::shuffle(featureSelection.begin(), featureSelection.end(), g); std::shuffle(featureSelection.begin(), featureSelection.end(), g);
} }
auto feature = featureSelection[0]; // Remove used features
if (!repeatSparent || featuresUsed.size() < featureSelection.size()) { featureSelection.erase(remove_if(begin(featureSelection), end(featureSelection), [&](auto x)
bool used = true; { return find(begin(featuresUsed), end(featuresUsed), x) != end(featuresUsed);}),
for (const auto& feat : featureSelection) { end(featureSelection)
if (std::find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) { );
continue; if (featureSelection.empty()) {
}
used = false;
feature = feat;
break; break;
} }
if (used) { auto feature = featureSelection[0];
exitCondition = true;
continue;
}
}
std::unique_ptr<Classifier> model; std::unique_ptr<Classifier> model;
model = std::make_unique<SPODE>(feature); model = std::make_unique<SPODE>(feature);
model->fit(dataset, features, className, states, weights_); model->fit(dataset, features, className, states, weights_);