Add conditional saveResults to GridExperiment

This commit is contained in:
2025-01-18 13:09:45 +01:00
parent eb430a84c4
commit 7aaf6d1bf8
4 changed files with 11 additions and 2 deletions

View File

@@ -53,6 +53,7 @@ void add_experiment_args(argparse::ArgumentParser& program)
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>()); group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test."); group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--save").help("Save result (always save even if a dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true); program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true);
@@ -322,7 +323,9 @@ void experiment(argparse::ArgumentParser& program)
std::cout << "* Report of the computed hyperparameters" << std::endl; std::cout << "* Report of the computed hyperparameters" << std::endl;
auto duration = timer.getDuration(); auto duration = timer.getDuration();
experiment.setDuration(duration); experiment.setDuration(duration);
experiment.saveResult(); if (grid_experiment.haveToSaveResults()) {
experiment.saveResult();
}
experiment.report(grid_experiment.numFiles() == 1); experiment.report(grid_experiment.numFiles() == 1);
std::cout << "Process took " << duration << std::endl; std::cout << "Process took " << duration << std::endl;
} }

View File

@@ -68,7 +68,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true); program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true);
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--save").help("Save result (always save even if a dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
try { try {

View File

@@ -21,6 +21,7 @@ namespace platform {
datasets_file = program.get<std::string>("datasets-file"); datasets_file = program.get<std::string>("datasets-file");
model_name = program.get<std::string>("model"); model_name = program.get<std::string>("model");
discretize_dataset = program.get<bool>("discretize"); discretize_dataset = program.get<bool>("discretize");
saveResults = program.get<bool>("save");
discretize_algo = program.get<std::string>("discretize-algo"); discretize_algo = program.get<std::string>("discretize-algo");
smooth_strat = program.get<std::string>("smooth-strat"); smooth_strat = program.get<std::string>("smooth-strat");
stratified = program.get<bool>("stratified"); stratified = program.get<bool>("stratified");
@@ -61,6 +62,7 @@ namespace platform {
filesToTest.push_back(line); filesToTest.push_back(line);
} }
catalog.close(); catalog.close();
saveResults = true;
if (title == "") { if (title == "") {
title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\ title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\
+ model_name + " " + to_string(n_folds) + " folds"; + model_name + " " + to_string(n_folds) + " folds";
@@ -77,6 +79,7 @@ namespace platform {
} }
} }
filesToTest = file_names; filesToTest = file_names;
saveResults = true;
if (title == "") { if (title == "") {
title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds"; title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
} }
@@ -92,6 +95,7 @@ namespace platform {
filesToTest.push_back(file_name); filesToTest.push_back(file_name);
} else { } else {
filesToTest = datasets.getNames(); filesToTest = datasets.getNames();
saveResults = true;
} }
} }
} }

View File

@@ -23,10 +23,12 @@ namespace platform {
json getResults(); json getResults();
Experiment& getExperiment() { return experiment; } Experiment& getExperiment() { return experiment; }
size_t numFiles() const { return filesToTest.size(); } size_t numFiles() const { return filesToTest.size(); }
bool haveToSaveResults() const { return saveResults; }
private: private:
argparse::ArgumentParser& arguments; argparse::ArgumentParser& arguments;
Experiment experiment; Experiment experiment;
json computed_results; json computed_results;
bool saveResults;
std::vector<std::string> filesToTest; std::vector<std::string> filesToTest;
void save(json& results); void save(json& results);
json initializeResults(); json initializeResults();