diff --git a/src/main/HyperParameters.cpp b/src/main/HyperParameters.cpp index 719a921..1d3e5a6 100644 --- a/src/main/HyperParameters.cpp +++ b/src/main/HyperParameters.cpp @@ -10,14 +10,7 @@ namespace platform { for (const auto& item : datasets) { hyperparameters[item] = hyperparameters_; } - } - // https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/ - std::string join(std::vector const& strings, std::string delim) - { - std::stringstream ss; - std::copy(strings.begin(), strings.end(), - std::ostream_iterator(ss, delim.c_str())); - return ss.str(); + normalize_nested(datasets); } HyperParameters::HyperParameters(const std::vector& datasets, const std::string& hyperparameters_file, bool best) { @@ -31,7 +24,7 @@ namespace platform { json input_hyperparameters; if (best) { for (const auto& [key, value] : file_hyperparameters.items()) { - input_hyperparameters[key] = value[1]; + input_hyperparameters[key]["hyperparameters"] = value[1]; } } else { input_hyperparameters = file_hyperparameters["results"]; @@ -45,6 +38,24 @@ namespace platform { } hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get(); } + normalize_nested(datasets); + } + void HyperParameters::normalize_nested(const std::vector& datasets) + { + for (const auto& dataset : datasets) { + if (hyperparameters[dataset].contains("be_hyperparams")) { + // Odte has base estimator hyperparameters set this way + hyperparameters[dataset]["be_hyperparams"] = hyperparameters[dataset]["be_hyperparams"].dump(); + } + } + } + // https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/ + std::string join(std::vector const& strings, std::string delim) + { + std::stringstream ss; + std::copy(strings.begin(), strings.end(), + std::ostream_iterator(ss, delim.c_str())); + return ss.str(); } void HyperParameters::check(const std::vector& valid, const std::string& fileName) { diff --git a/src/main/HyperParameters.h b/src/main/HyperParameters.h index ed1b948..02957c7 100644 --- a/src/main/HyperParameters.h +++ b/src/main/HyperParameters.h @@ -19,6 +19,7 @@ namespace platform { void check(const std::vector& valid, const std::string& fileName); json get(const std::string& fileName); private: + void normalize_nested(const std::vector& datasets); std::map hyperparameters; bool best = false; // Used to separate grid/best hyperparameters as the format of those files are different };