From 90555489ffd43db2207868a2bac3cadbffa65911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 9 Jun 2024 11:35:50 +0200 Subject: [PATCH] Add discretiz_algo to b_main as hyperparameter --- src/commands/b_main.cpp | 9 +++++---- src/common/DotEnv.h | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 89021ea..6288ebc 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -47,6 +47,7 @@ void manageArguments(argparse::ArgumentParser& program) ); program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--discretize-algo").help("Discretize input dataset").default_value(env.get("discretize_algo")); program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true); program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); @@ -74,7 +75,7 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string file_name, model_name, title, hyperparameters_file, datasets_file; + std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files; std::vector seeds; @@ -88,6 +89,7 @@ int main(int argc, char** argv) datasets_file = program.get("datasets-file"); model_name = program.get("model"); discretize_dataset = program.get("discretize"); + discretize_algo = program.get("discretize-algo"); stratified = program.get("stratified"); quiet = program.get("quiet"); n_folds = program.get("folds"); @@ -177,9 +179,8 @@ int main(int argc, char** argv) */ auto env = platform::DotEnv(); auto experiment = platform::Experiment(); - std::string discretiz_algo = env.get("discretiz_algo"); - experiment.setTitle(title).setLanguage("c++").setLanguageVersion("13.2.1"); - experiment.setDiscretizationAlgorithm(discretiz_algo); + experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1"); + experiment.setDiscretizationAlgorithm(discretize_algo); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy"); experiment.setHyperparameters(test_hyperparams); diff --git a/src/common/DotEnv.h b/src/common/DotEnv.h index d246f36..a50ad5e 100644 --- a/src/common/DotEnv.h +++ b/src/common/DotEnv.h @@ -29,7 +29,7 @@ namespace platform { {"framework", {"bulma", "bootstrap"}}, {"margin", {"0.1", "0.2", "0.3"}}, {"n_folds", {"5", "10"}}, - {"discretiz_algo", {"mdlp", "bin3u", "bin3q", "bin4u", "bin4q"}}, + {"discretize_algo", {"mdlp", "bin3u", "bin3q", "bin4u", "bin4q"}}, {"platform", {"any"}}, {"model", {"any"}}, {"seeds", {"any"}}, @@ -96,6 +96,10 @@ namespace platform { } std::string get(const std::string& key) { + if (env.find(key) == env.end()) { + std::cerr << "Key not found in .env: " << key << std::endl; + exit(1); + } return env.at(key); } std::vector getSeeds()