From 97ca8ac0849cd781497673e2ed66b0ea5936b224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 22 Aug 2023 22:12:20 +0200 Subject: [PATCH] Move check valid hyperparameters to Classifier --- src/BayesNet/BoostAODE.cc | 8 ++------ src/BayesNet/Classifier.cc | 8 ++++++++ src/BayesNet/Classifier.h | 1 + 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index fa6dabb..a7d5e5c 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -11,12 +11,8 @@ namespace bayesnet { void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { // Check if hyperparameters are valid - auto validKeys = { "repeatSparent", "maxModels", "ascending" }; - for (const auto& item : hyperparameters.items()) { - if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { - throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); - } - } + const vector validKeys = { "repeatSparent", "maxModels", "ascending" }; + checkHyperparameters(validKeys, hyperparameters); if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index ff25657..db4a63f 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -152,4 +152,12 @@ namespace bayesnet { { model.dump_cpt(); } + void Classifier::checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters) + { + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } + } } \ No newline at end of file diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 0c2940b..d27e486 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -24,6 +24,7 @@ namespace bayesnet { void checkFitParameters(); virtual void buildModel(const torch::Tensor& weights) = 0; void trainModel(const torch::Tensor& weights) override; + void checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters); public: Classifier(Network model); virtual ~Classifier() = default;