From 07e3cc959913e99513898773c273c4679e516940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 19 Jan 2025 13:51:51 +0100 Subject: [PATCH] Fix errors in grid Experiment --- src/commands/b_grid.cpp | 1 + src/commands/b_main.cpp | 1 + src/grid/GridExperiment.cpp | 2 ++ src/grid/GridExperiment.h | 2 +- src/main/ArgumentsExperiment.cpp | 5 +++-- src/main/ArgumentsExperiment.h | 1 + 6 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index d932ea5..cce453d 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -284,6 +284,7 @@ int main(int argc, char** argv) argparse::ArgumentParser experiment_command("experiment"); experiment_command.add_description("Experiment like b_main using mpi."); auto arguments = platform::ArgumentsExperiment(experiment_command, platform::experiment_t::GRID); + arguments.add_arguments(); program.add_subparser(dump_command); program.add_subparser(report_command); program.add_subparser(search_command); diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 98dca58..f04a79f 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -11,6 +11,7 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() }); auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::NORMAL); + arguments.add_arguments(); arguments.parse_args(argc, argv); /* * Begin Processing diff --git a/src/grid/GridExperiment.cpp b/src/grid/GridExperiment.cpp index 62e9699..63d4083 100644 --- a/src/grid/GridExperiment.cpp +++ b/src/grid/GridExperiment.cpp @@ -12,6 +12,8 @@ namespace platform { GridExperiment::GridExperiment(ArgumentsExperiment& program, struct ConfigGrid& config) : arguments(program), GridBase(config) { experiment = arguments.initializedExperiment(); + filesToTest = arguments.getFilesToTest(); + saveResults = arguments.haveToSaveResults(); this->config.model = experiment.getModel(); this->config.score = experiment.getScore(); this->config.discretize = experiment.isDiscretized(); diff --git a/src/grid/GridExperiment.h b/src/grid/GridExperiment.h index 8bcb27c..81503e3 100644 --- a/src/grid/GridExperiment.h +++ b/src/grid/GridExperiment.h @@ -29,7 +29,7 @@ namespace platform { ArgumentsExperiment& arguments; Experiment experiment; json computed_results; - bool saveResults; + bool saveResults = false; std::vector filesToTest; void save(json& results); json initializeResults(); diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index f77b23a..d66a5d0 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -6,6 +6,9 @@ #include "ArgumentsExperiment.h" namespace platform { ArgumentsExperiment::ArgumentsExperiment(argparse::ArgumentParser& program, experiment_t type) : arguments{ program }, type{ type } + { + } + void ArgumentsExperiment::add_arguments() { auto env = platform::DotEnv(); auto datasets = platform::Datasets(false, platform::Paths::datasets()); @@ -106,7 +109,6 @@ namespace platform { smooth_strat = arguments.get("smooth-strat"); stratified = arguments.get("stratified"); quiet = arguments.get("quiet"); - n_folds = arguments.get("folds"); score = arguments.get("score"); seeds = arguments.get>("seeds"); @@ -196,7 +198,6 @@ namespace platform { } } } - if (hyperparameters_file != "") { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best); } else { diff --git a/src/main/ArgumentsExperiment.h b/src/main/ArgumentsExperiment.h index 06cd08e..c4528b9 100644 --- a/src/main/ArgumentsExperiment.h +++ b/src/main/ArgumentsExperiment.h @@ -15,6 +15,7 @@ namespace platform { ArgumentsExperiment(argparse::ArgumentParser& program, experiment_t type); ~ArgumentsExperiment() = default; std::vector getFilesToTest() const { return filesToTest; } + void add_arguments(); void parse_args(int argc, char** argv); void parse(); Experiment& initializedExperiment();