diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index 80fd20c..fa6dabb 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -10,6 +10,13 @@ namespace bayesnet { } void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { + // Check if hyperparameters are valid + auto validKeys = { "repeatSparent", "maxModels", "ascending" }; + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } @@ -74,7 +81,7 @@ namespace bayesnet { // Step 3.4: Store classifier and its accuracy to weigh its future vote models.push_back(std::move(model)); significanceModels.push_back(significance); - exitCondition = n_models == maxModels; + exitCondition = n_models == maxModels && repeatSparent; } if (featuresUsed.size() != features.size()) { cout << "Warning: BoostAODE did not use all the features" << endl;