diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index de6ebb8..c67424c 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -12,7 +12,7 @@ namespace bayesnet { BoostAODE::BoostAODE() : Ensemble() { - validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features" }; + validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" }; } void BoostAODE::buildModel(const torch::Tensor& weights) @@ -47,22 +47,32 @@ namespace bayesnet { y_train = y_; } } - void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters) + void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_) { + auto hyperparameters = hyperparameters_; if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; + hyperparameters.erase("repeatSparent"); } if (hyperparameters.contains("maxModels")) { maxModels = hyperparameters["maxModels"]; + hyperparameters.erase("maxModels"); } if (hyperparameters.contains("ascending")) { ascending = hyperparameters["ascending"]; + hyperparameters.erase("ascending"); } if (hyperparameters.contains("convergence")) { convergence = hyperparameters["convergence"]; + hyperparameters.erase("convergence"); } if (hyperparameters.contains("threshold")) { threshold = hyperparameters["threshold"]; + hyperparameters.erase("threshold"); + } + if (hyperparameters.contains("tolerance")) { + tolerance = hyperparameters["tolerance"]; + hyperparameters.erase("tolerance"); } if (hyperparameters.contains("select_features")) { auto selectedAlgorithm = hyperparameters["select_features"]; @@ -72,6 +82,10 @@ namespace bayesnet { if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) { throw std::invalid_argument("Invalid selectFeatures value [IWSS, FCBF, CFS]"); } + hyperparameters.erase("select_features"); + } + if (!hyperparameters.empty()) { + throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); } } std::unordered_set BoostAODE::initializeModels() @@ -109,10 +123,8 @@ namespace bayesnet { void BoostAODE::trainModel(const torch::Tensor& weights) { std::unordered_set featuresUsed; - int tolerance = 5; // number of times the accuracy can be lower than the threshold if (selectFeatures) { featuresUsed = initializeModels(); - tolerance = 0; // Remove tolerance if features are selected } if (maxModels == 0) maxModels = .1 * n > 10 ? .1 * n : n; diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h index 670c696..4b0b063 100644 --- a/src/BayesNet/BoostAODE.h +++ b/src/BayesNet/BoostAODE.h @@ -21,6 +21,7 @@ namespace bayesnet { // Hyperparameters bool repeatSparent = false; // if true, a feature can be selected more than once int maxModels = 0; + int tolerance = 0; bool ascending = false; //Process KBest features ascending or descending order bool convergence = false; //if true, stop when the model does not improve bool selectFeatures = false; // if true, use feature selection diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 4379178..c056272 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -89,6 +89,8 @@ namespace platform { double bestScore = 0.0; for (int nfold = 0; nfold < config.n_folds; nfold++) { auto clf = Models::instance()->create(config.model); + auto valid = clf->getValidHyperparameters(); + hyperparameters.check(valid, fileName); clf->setHyperparameters(hyperparameters.get(fileName)); auto [train, test] = fold->getFold(nfold); auto train_t = torch::tensor(train);