Move check valid hyperparameters to Classifier
This commit is contained in:
parent
1c1385b768
commit
97ca8ac084
@ -11,12 +11,8 @@ namespace bayesnet {
|
|||||||
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
||||||
{
|
{
|
||||||
// Check if hyperparameters are valid
|
// Check if hyperparameters are valid
|
||||||
auto validKeys = { "repeatSparent", "maxModels", "ascending" };
|
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
|
||||||
for (const auto& item : hyperparameters.items()) {
|
checkHyperparameters(validKeys, hyperparameters);
|
||||||
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
|
|
||||||
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (hyperparameters.contains("repeatSparent")) {
|
if (hyperparameters.contains("repeatSparent")) {
|
||||||
repeatSparent = hyperparameters["repeatSparent"];
|
repeatSparent = hyperparameters["repeatSparent"];
|
||||||
}
|
}
|
||||||
|
@ -152,4 +152,12 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
model.dump_cpt();
|
model.dump_cpt();
|
||||||
}
|
}
|
||||||
|
void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
|
||||||
|
{
|
||||||
|
for (const auto& item : hyperparameters.items()) {
|
||||||
|
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
|
||||||
|
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -24,6 +24,7 @@ namespace bayesnet {
|
|||||||
void checkFitParameters();
|
void checkFitParameters();
|
||||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
|
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
|
||||||
public:
|
public:
|
||||||
Classifier(Network model);
|
Classifier(Network model);
|
||||||
virtual ~Classifier() = default;
|
virtual ~Classifier() = default;
|
||||||
|
Loading…
Reference in New Issue
Block a user