From d82148079d778b887615a371cd9e73aa258e54d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 23 Aug 2023 00:44:10 +0200 Subject: [PATCH] Add KDB hyperparameters K and theta --- src/BayesNet/KDB.cc | 12 ++++++++++++ src/BayesNet/KDB.h | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/BayesNet/KDB.cc b/src/BayesNet/KDB.cc index cfbbca1..d511354 100644 --- a/src/BayesNet/KDB.cc +++ b/src/BayesNet/KDB.cc @@ -4,6 +4,18 @@ namespace bayesnet { using namespace torch; KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {} + void KDB::setHyperparameters(nlohmann::json& hyperparameters) + { + // Check if hyperparameters are valid + const vector validKeys = { "k", "theta" }; + checkHyperparameters(validKeys, hyperparameters); + if (hyperparameters.contains("k")) { + k = hyperparameters["k"]; + } + if (hyperparameters.contains("theta")) { + theta = hyperparameters["theta"]; + } + } void KDB::buildModel(const torch::Tensor& weights) { /* diff --git a/src/BayesNet/KDB.h b/src/BayesNet/KDB.h index 713b415..992d061 100644 --- a/src/BayesNet/KDB.h +++ b/src/BayesNet/KDB.h @@ -16,7 +16,7 @@ namespace bayesnet { public: explicit KDB(int k, float theta = 0.03); virtual ~KDB() {}; - void setHyperparameters(nlohmann::json& hyperparameters) override {}; + void setHyperparameters(nlohmann::json& hyperparameters) override; vector graph(const string& name = "KDB") const override; }; }