From d6603dd638c953bbd242220b175177c1bc48a395 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 14 May 2025 11:46:15 +0200 Subject: [PATCH] Add folder parameter to best, grid and main --- src/commands/b_best.cpp | 3 +++ src/commands/b_grid.cpp | 3 ++- src/commands/b_main.cpp | 3 ++- src/main/ArgumentsExperiment.cpp | 8 +++++++- src/main/ArgumentsExperiment.h | 4 +++- src/main/Experiment.cpp | 6 +++--- src/main/Experiment.h | 2 +- src/results/Result.cpp | 4 ++-- src/results/Result.h | 2 +- 9 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index 970ff1f..fb4b5cd 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -44,6 +44,9 @@ int main(int argc, char** argv) program.parse_args(argc, argv); model = program.get("model"); folder = program.get("folder"); + if (folder.back() != '/') { + folder += '/'; + } dataset = program.get("dataset"); score = program.get("score"); friedman = program.get("friedman"); diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index b6efd56..7e246a5 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -231,6 +231,7 @@ void experiment(argparse::ArgumentParser& program) { struct platform::ConfigGrid config; auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID); + auto path_results = arguments.getPathResults(); arguments.parse(); auto grid_experiment = platform::GridExperiment(arguments, config); platform::Timer timer; @@ -250,7 +251,7 @@ void experiment(argparse::ArgumentParser& program) auto duration = timer.getDuration(); experiment.setDuration(duration); if (grid_experiment.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } experiment.report(); std::cout << "Process took " << duration << std::endl; diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index f04a79f..03002d5 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -18,6 +18,7 @@ int main(int argc, char** argv) */ // Initialize the experiment class with the command line arguments auto experiment = arguments.initializedExperiment(); + auto path_results = arguments.getPathResults(); platform::Timer timer; timer.start(); experiment.go(); @@ -27,7 +28,7 @@ int main(int argc, char** argv) experiment.report(); } if (arguments.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } if (arguments.doGraph()) { experiment.saveGraph(); diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index aa8199e..8d778ba 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -13,6 +13,7 @@ namespace platform { auto env = platform::DotEnv(); auto datasets = platform::Datasets(false, platform::Paths::datasets()); auto& group = arguments.add_mutually_exclusive_group(true); + group.add_argument("-d", "--dataset") .help("Dataset file name: " + datasets.toString()) .default_value("all") @@ -43,6 +44,7 @@ namespace platform { } ); arguments.add_argument("--title").default_value("").help("Experiment title"); + arguments.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); auto valid_choices = env.valid_tokens("discretize_algo"); auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); @@ -103,6 +105,10 @@ namespace platform { file_name = arguments.get("dataset"); file_names = arguments.get>("datasets"); datasets_file = arguments.get("datasets-file"); + path_results = arguments.get("folder"); + if (path_results.back() != '/') { + path_results += '/'; + } model_name = arguments.get("model"); discretize_dataset = arguments.get("discretize"); discretize_algo = arguments.get("discretize-algo"); @@ -119,7 +125,7 @@ namespace platform { hyper_best = arguments.get("hyper-best"); if (hyper_best) { // Build the best results file_name - hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name); + hyperparameters_file = path_results + platform::Paths::bestResultsFile(score, model_name); // ignore this parameter hyperparameters = "{}"; } else { diff --git a/src/main/ArgumentsExperiment.h b/src/main/ArgumentsExperiment.h index c4528b9..5f2fbc2 100644 --- a/src/main/ArgumentsExperiment.h +++ b/src/main/ArgumentsExperiment.h @@ -22,11 +22,13 @@ namespace platform { bool isQuiet() const { return quiet; } bool haveToSaveResults() const { return saveResults; } bool doGraph() const { return graph; } + std::string getPathResults() const { return path_results; } private: Experiment experiment; experiment_t type; argparse::ArgumentParser& arguments; - std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score; + std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat; + std::string score, path_results; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best; std::vector seeds; diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 3e33e28..438735d 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -7,12 +7,12 @@ namespace platform { using json = nlohmann::ordered_json; - void Experiment::saveResult() + void Experiment::saveResult(const std::string& path) { result.setSchemaVersion("1.0"); result.check(); - result.save(); - std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl; + result.save(path); + std::cout << "Result saved in " << path << result.getFilename() << std::endl; } void Experiment::report() { diff --git a/src/main/Experiment.h b/src/main/Experiment.h index bfad61c..52c08fe 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -45,7 +45,7 @@ namespace platform { std::vector getRandomSeeds() const { return randomSeeds; } void cross_validation(const std::string& fileName); void go(); - void saveResult(); + void saveResult(const std::string& path); void show(); void saveGraph(); void report(); diff --git a/src/results/Result.cpp b/src/results/Result.cpp index c143874..f37ca61 100644 --- a/src/results/Result.cpp +++ b/src/results/Result.cpp @@ -69,9 +69,9 @@ namespace platform { platform::JsonValidator validator(platform::SchemaV1_0::schema); return validator.validate(data); } - void Result::save() + void Result::save(const std::string& path) { - std::ofstream file(Paths::results() + getFilename()); + std::ofstream file(path + getFilename()); file << data; file.close(); } diff --git a/src/results/Result.h b/src/results/Result.h index 3f45c70..a16748d 100644 --- a/src/results/Result.h +++ b/src/results/Result.h @@ -15,7 +15,7 @@ namespace platform { public: Result(); Result& load(const std::string& path, const std::string& filename); - void save(); + void save(const std::string& path); std::vector check(); // Getters json getJson();