Add KDB hyperparameters K and theta

This commit is contained in:
Ricardo Montañana Gómez 2023-08-23 00:44:10 +02:00
parent 067430fd1b
commit d82148079d
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 13 additions and 1 deletions

View File

@ -4,6 +4,18 @@ namespace bayesnet {
using namespace torch;
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
void KDB::setHyperparameters(nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
const vector<string> validKeys = { "k", "theta" };
checkHyperparameters(validKeys, hyperparameters);
if (hyperparameters.contains("k")) {
k = hyperparameters["k"];
}
if (hyperparameters.contains("theta")) {
theta = hyperparameters["theta"];
}
}
void KDB::buildModel(const torch::Tensor& weights)
{
/*

View File

@ -16,7 +16,7 @@ namespace bayesnet {
public:
explicit KDB(int k, float theta = 0.03);
virtual ~KDB() {};
void setHyperparameters(nlohmann::json& hyperparameters) override {};
void setHyperparameters(nlohmann::json& hyperparameters) override;
vector<string> graph(const string& name = "KDB") const override;
};
}