Accept nested hyperparameters in b_main

This commit is contained in:
2024-08-04 17:19:31 +02:00
parent 0ea967dd9d
commit 800246acd2
2 changed files with 21 additions and 9 deletions

View File

@@ -10,14 +10,7 @@ namespace platform {
for (const auto& item : datasets) {
hyperparameters[item] = hyperparameters_;
}
}
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
std::string join(std::vector<std::string> const& strings, std::string delim)
{
std::stringstream ss;
std::copy(strings.begin(), strings.end(),
std::ostream_iterator<std::string>(ss, delim.c_str()));
return ss.str();
normalize_nested(datasets);
}
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best)
{
@@ -31,7 +24,7 @@ namespace platform {
json input_hyperparameters;
if (best) {
for (const auto& [key, value] : file_hyperparameters.items()) {
input_hyperparameters[key] = value[1];
input_hyperparameters[key]["hyperparameters"] = value[1];
}
} else {
input_hyperparameters = file_hyperparameters["results"];
@@ -45,6 +38,24 @@ namespace platform {
}
hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get<json>();
}
normalize_nested(datasets);
}
void HyperParameters::normalize_nested(const std::vector<std::string>& datasets)
{
for (const auto& dataset : datasets) {
if (hyperparameters[dataset].contains("be_hyperparams")) {
// Odte has base estimator hyperparameters set this way
hyperparameters[dataset]["be_hyperparams"] = hyperparameters[dataset]["be_hyperparams"].dump();
}
}
}
// https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
std::string join(std::vector<std::string> const& strings, std::string delim)
{
std::stringstream ss;
std::copy(strings.begin(), strings.end(),
std::ostream_iterator<std::string>(ss, delim.c_str()));
return ss.str();
}
void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
{

View File

@@ -19,6 +19,7 @@ namespace platform {
void check(const std::vector<std::string>& valid, const std::string& fileName);
json get(const std::string& fileName);
private:
void normalize_nested(const std::vector<std::string>& datasets);
std::map<std::string, json> hyperparameters;
bool best = false; // Used to separate grid/best hyperparameters as the format of those files are different
};