From 7aaf6d1bf8c1fccb5594779dbaa31e6a7520f5a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 18 Jan 2025 13:09:45 +0100 Subject: [PATCH] Add conditional saveResults to GridExperiment --- src/commands/b_grid.cpp | 5 ++++- src/commands/b_main.cpp | 2 +- src/grid/GridExperiment.cpp | 4 ++++ src/grid/GridExperiment.h | 2 ++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index 94cdcee..7653319 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -53,6 +53,7 @@ void add_experiment_args(argparse::ArgumentParser& program) group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector()); group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test."); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); + program.add_argument("--save").help("Save result (always save even if a dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true); @@ -322,7 +323,9 @@ void experiment(argparse::ArgumentParser& program) std::cout << "* Report of the computed hyperparameters" << std::endl; auto duration = timer.getDuration(); experiment.setDuration(duration); - experiment.saveResult(); + if (grid_experiment.haveToSaveResults()) { + experiment.saveResult(); + } experiment.report(grid_experiment.numFiles() == 1); std::cout << "Process took " << duration << std::endl; } diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index e4a02c4..4fc38f2 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -68,7 +68,7 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true); program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); - program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); + program.add_argument("--save").help("Save result (always save even if a dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { try { diff --git a/src/grid/GridExperiment.cpp b/src/grid/GridExperiment.cpp index 67db3b7..67ce91b 100644 --- a/src/grid/GridExperiment.cpp +++ b/src/grid/GridExperiment.cpp @@ -21,6 +21,7 @@ namespace platform { datasets_file = program.get("datasets-file"); model_name = program.get("model"); discretize_dataset = program.get("discretize"); + saveResults = program.get("save"); discretize_algo = program.get("discretize-algo"); smooth_strat = program.get("smooth-strat"); stratified = program.get("stratified"); @@ -61,6 +62,7 @@ namespace platform { filesToTest.push_back(line); } catalog.close(); + saveResults = true; if (title == "") { title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\ + model_name + " " + to_string(n_folds) + " folds"; @@ -77,6 +79,7 @@ namespace platform { } } filesToTest = file_names; + saveResults = true; if (title == "") { title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds"; } @@ -92,6 +95,7 @@ namespace platform { filesToTest.push_back(file_name); } else { filesToTest = datasets.getNames(); + saveResults = true; } } } diff --git a/src/grid/GridExperiment.h b/src/grid/GridExperiment.h index f03da41..dddb8e9 100644 --- a/src/grid/GridExperiment.h +++ b/src/grid/GridExperiment.h @@ -23,10 +23,12 @@ namespace platform { json getResults(); Experiment& getExperiment() { return experiment; } size_t numFiles() const { return filesToTest.size(); } + bool haveToSaveResults() const { return saveResults; } private: argparse::ArgumentParser& arguments; Experiment experiment; json computed_results; + bool saveResults; std::vector filesToTest; void save(json& results); json initializeResults();