Move check valid hyperparameters to Classifier

This commit is contained in:
Ricardo Montañana Gómez 2023-08-22 22:12:20 +02:00
parent 1c1385b768
commit 97ca8ac084
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 11 additions and 6 deletions

View File

@ -11,12 +11,8 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
auto validKeys = { "repeatSparent", "maxModels", "ascending" };
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
checkHyperparameters(validKeys, hyperparameters);
if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"];
}

View File

@ -152,4 +152,12 @@ namespace bayesnet {
{
model.dump_cpt();
}
void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
{
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
}
}

View File

@ -24,6 +24,7 @@ namespace bayesnet {
void checkFitParameters();
virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
public:
Classifier(Network model);
virtual ~Classifier() = default;