diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index fa6dabb..a7d5e5c 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -11,12 +11,8 @@ namespace bayesnet { void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { // Check if hyperparameters are valid - auto validKeys = { "repeatSparent", "maxModels", "ascending" }; - for (const auto& item : hyperparameters.items()) { - if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { - throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); - } - } + const vector validKeys = { "repeatSparent", "maxModels", "ascending" }; + checkHyperparameters(validKeys, hyperparameters); if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index ff25657..db4a63f 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -152,4 +152,12 @@ namespace bayesnet { { model.dump_cpt(); } + void Classifier::checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters) + { + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } + } } \ No newline at end of file diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 0c2940b..d27e486 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -24,6 +24,7 @@ namespace bayesnet { void checkFitParameters(); virtual void buildModel(const torch::Tensor& weights) = 0; void trainModel(const torch::Tensor& weights) override; + void checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters); public: Classifier(Network model); virtual ~Classifier() = default;