Accept nested hyperparameters in b_main

This commit is contained in:
2024-08-04 17:19:31 +02:00
parent 0ea967dd9d
commit 800246acd2
2 changed files with 21 additions and 9 deletions

View File

@@ -10,14 +10,7 @@ namespace platform {
for (const auto& item : datasets) { for (const auto& item : datasets) {
hyperparameters[item] = hyperparameters_; hyperparameters[item] = hyperparameters_;
} }
} normalize_nested(datasets);
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
std::string join(std::vector<std::string> const& strings, std::string delim)
{
std::stringstream ss;
std::copy(strings.begin(), strings.end(),
std::ostream_iterator<std::string>(ss, delim.c_str()));
return ss.str();
} }
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best) HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best)
{ {
@@ -31,7 +24,7 @@ namespace platform {
json input_hyperparameters; json input_hyperparameters;
if (best) { if (best) {
for (const auto& [key, value] : file_hyperparameters.items()) { for (const auto& [key, value] : file_hyperparameters.items()) {
input_hyperparameters[key] = value[1]; input_hyperparameters[key]["hyperparameters"] = value[1];
} }
} else { } else {
input_hyperparameters = file_hyperparameters["results"]; input_hyperparameters = file_hyperparameters["results"];
@@ -45,6 +38,24 @@ namespace platform {
} }
hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get<json>(); hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get<json>();
} }
normalize_nested(datasets);
}
void HyperParameters::normalize_nested(const std::vector<std::string>& datasets)
{
for (const auto& dataset : datasets) {
if (hyperparameters[dataset].contains("be_hyperparams")) {
// Odte has base estimator hyperparameters set this way
hyperparameters[dataset]["be_hyperparams"] = hyperparameters[dataset]["be_hyperparams"].dump();
}
}
}
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
std::string join(std::vector<std::string> const& strings, std::string delim)
{
std::stringstream ss;
std::copy(strings.begin(), strings.end(),
std::ostream_iterator<std::string>(ss, delim.c_str()));
return ss.str();
} }
void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName) void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
{ {

View File

@@ -19,6 +19,7 @@ namespace platform {
void check(const std::vector<std::string>& valid, const std::string& fileName); void check(const std::vector<std::string>& valid, const std::string& fileName);
json get(const std::string& fileName); json get(const std::string& fileName);
private: private:
void normalize_nested(const std::vector<std::string>& datasets);
std::map<std::string, json> hyperparameters; std::map<std::string, json> hyperparameters;
bool best = false; // Used to separate grid/best hyperparameters as the format of those files are different bool best = false; // Used to separate grid/best hyperparameters as the format of those files are different
}; };