From f8007721495a0192fe500f17bc10e09fde5f8f5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 10 Jun 2024 10:16:07 +0200 Subject: [PATCH] Add new hyperparameters validation in b_main --- src/commands/b_main.cpp | 13 +++++++++++-- src/common/DotEnv.h | 26 ++++++++++++++++++++------ src/main/Experiment.h | 4 ++++ 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 087854f..5cc585c 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -12,6 +12,7 @@ using json = nlohmann::ordered_json; + void manageArguments(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); @@ -47,8 +48,16 @@ void manageArguments(argparse::ArgumentParser& program) ); program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); - program.add_argument("--discretize-algo").help("Discretize input dataset").default_value(env.get("discretize_algo")); - program.add_argument("--smooth-strat").help("Discretize input dataset").default_value(env.get("smooth_strat")); + auto valid_choices = env.valid_tokens("discretize_algo"); + auto& disc_arg = program.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); + for (auto choice : valid_choices) { + disc_arg.choices(choice); + } + valid_choices = env.valid_tokens("smooth_strat"); + auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat")); + for (auto choice : valid_choices) { + smooth_arg.choices(choice); + } program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true); program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); diff --git a/src/common/DotEnv.h b/src/common/DotEnv.h index 4fea7bf..78b24f9 100644 --- a/src/common/DotEnv.h +++ b/src/common/DotEnv.h @@ -81,16 +81,30 @@ namespace platform { exit(1); } } + std::vector valid_tokens(const std::string& key) + { + if (valid.find(key) == valid.end()) { + return {}; + } + return valid.at(key); + } + std::string valid_values(const std::string& key) + { + std::string valid_values = "{", sep = ""; + if (valid.find(key) == valid.end()) { + return "{}"; + } + for (const auto& value : valid.at(key)) { + valid_values += sep + value; + sep = ", "; + } + return valid_values + "}"; + } void parseEnv() { for (auto& [key, values] : valid) { if (env.find(key) == env.end()) { - std::string valid_values = "", sep = ""; - for (const auto& value : values) { - valid_values += sep + value; - sep = ", "; - } - std::cerr << "Key not found in .env: " << key << ", valid values: " << valid_values << std::endl; + std::cerr << "Key not found in .env: " << key << ", valid values: " << valid_values(key) << std::endl; exit(1); } } diff --git a/src/main/Experiment.h b/src/main/Experiment.h index 29d208b..2546603 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -34,6 +34,10 @@ namespace platform { smooth_type = bayesnet::Smoothing_t::LAPLACE; else if (smooth_strategy == "CESTNIK") smooth_type = bayesnet::Smoothing_t::CESTNIK; + else { + std::cerr << "Experiment: Unknown smoothing strategy: " << smooth_strategy << std::endl; + exit(1); + } return *this; } Experiment& setLanguageVersion(const std::string& language_version) { this->result.setLanguageVersion(language_version); return *this; }