Accept nested hyperparameters in b_main
This commit is contained in:
@@ -10,14 +10,7 @@ namespace platform {
|
||||
for (const auto& item : datasets) {
|
||||
hyperparameters[item] = hyperparameters_;
|
||||
}
|
||||
}
|
||||
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
|
||||
std::string join(std::vector<std::string> const& strings, std::string delim)
|
||||
{
|
||||
std::stringstream ss;
|
||||
std::copy(strings.begin(), strings.end(),
|
||||
std::ostream_iterator<std::string>(ss, delim.c_str()));
|
||||
return ss.str();
|
||||
normalize_nested(datasets);
|
||||
}
|
||||
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best)
|
||||
{
|
||||
@@ -31,7 +24,7 @@ namespace platform {
|
||||
json input_hyperparameters;
|
||||
if (best) {
|
||||
for (const auto& [key, value] : file_hyperparameters.items()) {
|
||||
input_hyperparameters[key] = value[1];
|
||||
input_hyperparameters[key]["hyperparameters"] = value[1];
|
||||
}
|
||||
} else {
|
||||
input_hyperparameters = file_hyperparameters["results"];
|
||||
@@ -45,6 +38,24 @@ namespace platform {
|
||||
}
|
||||
hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get<json>();
|
||||
}
|
||||
normalize_nested(datasets);
|
||||
}
|
||||
void HyperParameters::normalize_nested(const std::vector<std::string>& datasets)
|
||||
{
|
||||
for (const auto& dataset : datasets) {
|
||||
if (hyperparameters[dataset].contains("be_hyperparams")) {
|
||||
// Odte has base estimator hyperparameters set this way
|
||||
hyperparameters[dataset]["be_hyperparams"] = hyperparameters[dataset]["be_hyperparams"].dump();
|
||||
}
|
||||
}
|
||||
}
|
||||
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
|
||||
std::string join(std::vector<std::string> const& strings, std::string delim)
|
||||
{
|
||||
std::stringstream ss;
|
||||
std::copy(strings.begin(), strings.end(),
|
||||
std::ostream_iterator<std::string>(ss, delim.c_str()));
|
||||
return ss.str();
|
||||
}
|
||||
void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
|
||||
{
|
||||
|
@@ -19,6 +19,7 @@ namespace platform {
|
||||
void check(const std::vector<std::string>& valid, const std::string& fileName);
|
||||
json get(const std::string& fileName);
|
||||
private:
|
||||
void normalize_nested(const std::vector<std::string>& datasets);
|
||||
std::map<std::string, json> hyperparameters;
|
||||
bool best = false; // Used to separate grid/best hyperparameters as the format of those files are different
|
||||
};
|
||||
|
Reference in New Issue
Block a user