Add KDB hyperparameters K and theta
This commit is contained in:
parent
067430fd1b
commit
d82148079d
@ -4,6 +4,18 @@ namespace bayesnet {
|
||||
using namespace torch;
|
||||
|
||||
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
||||
void KDB::setHyperparameters(nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const vector<string> validKeys = { "k", "theta" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
if (hyperparameters.contains("k")) {
|
||||
k = hyperparameters["k"];
|
||||
}
|
||||
if (hyperparameters.contains("theta")) {
|
||||
theta = hyperparameters["theta"];
|
||||
}
|
||||
}
|
||||
void KDB::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
/*
|
||||
|
@ -16,7 +16,7 @@ namespace bayesnet {
|
||||
public:
|
||||
explicit KDB(int k, float theta = 0.03);
|
||||
virtual ~KDB() {};
|
||||
void setHyperparameters(nlohmann::json& hyperparameters) override {};
|
||||
void setHyperparameters(nlohmann::json& hyperparameters) override;
|
||||
vector<string> graph(const string& name = "KDB") const override;
|
||||
};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user