Move check valid hyperparameters to Classifier

This commit is contained in:
Ricardo Montañana Gómez 2023-08-22 22:12:20 +02:00
parent 1c1385b768
commit 97ca8ac084
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 11 additions and 6 deletions

View File

@ -11,12 +11,8 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
auto validKeys = { "repeatSparent", "maxModels", "ascending" }; const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
for (const auto& item : hyperparameters.items()) { checkHyperparameters(validKeys, hyperparameters);
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
if (hyperparameters.contains("repeatSparent")) { if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"]; repeatSparent = hyperparameters["repeatSparent"];
} }

View File

@ -152,4 +152,12 @@ namespace bayesnet {
{ {
model.dump_cpt(); model.dump_cpt();
} }
void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
{
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
}
} }

View File

@ -24,6 +24,7 @@ namespace bayesnet {
void checkFitParameters(); void checkFitParameters();
virtual void buildModel(const torch::Tensor& weights) = 0; virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
public: public:
Classifier(Network model); Classifier(Network model);
virtual ~Classifier() = default; virtual ~Classifier() = default;