From 42d61c6fc44df213bbe2fe4da93399f02a404dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 15 Apr 2024 18:14:21 +0200 Subject: [PATCH] Add datasets-file to b_main --- src/commands/b_main.cpp | 68 ++++++++++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 83dfb56..1aaa792 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -30,7 +30,8 @@ void manageArguments(argparse::ArgumentParser& program) throw std::runtime_error("Dataset must be one of: " + datasets.toString()); } ); - group.add_argument("--datasets").nargs(1, 50).help("Datasets file names").default_value(std::vector()); + group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector()); + group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test."); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); @@ -72,7 +73,7 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string file_name, model_name, title, hyperparameters_file; + std::string file_name, model_name, title, hyperparameters_file, datasets_file; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score; std::vector seeds; @@ -83,6 +84,7 @@ int main(int argc, char** argv) program.parse_args(argc, argv); file_name = program.get("dataset"); file_names = program.get>("datasets"); + datasets_file = program.get("datasets-file"); model_name = program.get("model"); discretize_dataset = program.get("discretize"); stratified = program.get("stratified"); @@ -108,27 +110,59 @@ int main(int argc, char** argv) exit(1); } auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets()); - if (file_names.size() > 0) { - filesToTest = file_names; - saveResults = true; - if (title == "") { - title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds"; + if (datasets_file != "") { + ifstream catalog(datasets_file); + if (catalog.is_open()) { + std::string line; + while (getline(catalog, line)) { + if (line.empty() || line[0] == '#') { + continue; + } + if (!datasets.isDataset(line)) { + cerr << "Dataset " << line << " not found" << std::endl; + exit(1); + } + filesToTest.push_back(line); + } + catalog.close(); + saveResults = true; + if (title == "") { + title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\ + + model_name + " " + to_string(n_folds) + " folds"; + } + } else { + throw std::invalid_argument("Unable to open catalog file. [" + datasets_file + "]"); } } else { - if (file_name != "all") { - if (!datasets.isDataset(file_name)) { - cerr << "Dataset " << file_name << " not found" << std::endl; - exit(1); + if (file_names.size() > 0) { + for (auto file : file_names) { + if (!datasets.isDataset(file)) { + cerr << "Dataset " << file << " not found" << std::endl; + exit(1); + } } - if (title == "") { - title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds"; - } - filesToTest.push_back(file_name); - } else { - filesToTest = datasets.getNames(); + filesToTest = file_names; saveResults = true; + if (title == "") { + title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds"; + } + } else { + if (file_name != "all") { + if (!datasets.isDataset(file_name)) { + cerr << "Dataset " << file_name << " not found" << std::endl; + exit(1); + } + if (title == "") { + title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds"; + } + filesToTest.push_back(file_name); + } else { + filesToTest = datasets.getNames(); + saveResults = true; + } } } + platform::HyperParameters test_hyperparams; if (hyperparameters_file != "") { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file);