Refactor Hyperparameters management
This commit is contained in:
@@ -170,9 +170,9 @@ namespace platform {
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
auto clf = Models::instance()->create(model);
|
||||
setModelVersion(clf->getVersion());
|
||||
if (hyperparameters.notEmpty(fileName)) {
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
}
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, fileName);
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
// Split train - test dataset
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include "HyperParameters.h"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
namespace platform {
|
||||
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_)
|
||||
@@ -21,13 +22,24 @@ namespace platform {
|
||||
// Check if hyperparameters are valid
|
||||
for (const auto& dataset : datasets) {
|
||||
if (!input_hyperparameters.contains(dataset)) {
|
||||
throw std::runtime_error("Dataset " + dataset + " not found in hyperparameters file");
|
||||
std::cerr << "*Warning: Dataset " << dataset << " not found in hyperparameters file" << " assuming default hyperparameters" << std::endl;
|
||||
hyperparameters[dataset] = json({});
|
||||
continue;
|
||||
}
|
||||
hyperparameters[dataset] = input_hyperparameters[dataset];
|
||||
hyperparameters[dataset] = input_hyperparameters[dataset].get<json>();
|
||||
}
|
||||
}
|
||||
json HyperParameters::get(const std::string& key)
|
||||
void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
|
||||
{
|
||||
return hyperparameters.at(key);
|
||||
json result = hyperparameters.at(fileName);
|
||||
for (const auto& item : result.items()) {
|
||||
if (find(valid.begin(), valid.end(), item.key()) == valid.end()) {
|
||||
throw std::invalid_argument("Hyperparameter " + item.key() + " is not valid. Passed Hyperparameters are: " + result.dump(4));
|
||||
}
|
||||
}
|
||||
}
|
||||
json HyperParameters::get(const std::string& fileName)
|
||||
{
|
||||
return hyperparameters.at(fileName);
|
||||
}
|
||||
} /* namespace platform */
|
@@ -13,8 +13,9 @@ namespace platform {
|
||||
explicit HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_);
|
||||
explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file);
|
||||
~HyperParameters() = default;
|
||||
bool notEmpty(const std::string& key) const { return hyperparameters.at(key) != json(); }
|
||||
json get(const std::string& key);
|
||||
bool notEmpty(const std::string& key) const { return !hyperparameters.at(key).empty(); }
|
||||
void check(const std::vector<std::string>& valid, const std::string& fileName);
|
||||
json get(const std::string& fileName);
|
||||
private:
|
||||
std::map<std::string, json> hyperparameters;
|
||||
};
|
||||
|
Reference in New Issue
Block a user