diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 35b171e..9f91c6a 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -274,11 +274,9 @@ namespace platform { vector GridSearch::processDatasets(Datasets& datasets) { // Load datasets - auto datasets_names = datasets.getNames(); if (config.continue_from != NO_CONTINUE()) { // Continue previous execution: - // remove datasets already processed if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { throw std::invalid_argument("Dataset " + config.continue_from + " not found"); } @@ -295,6 +293,15 @@ namespace platform { } } } + // Exclude datasets + for (const auto& name : config.excluded) { + auto dataset = name.get(); + auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset); + if (it == datasets_names.end()) { + throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!"); + } + datasets_names.erase(it); + } return datasets_names; } json GridSearch::initializeResults() diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index d24beb6..e325ca5 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -21,6 +21,7 @@ namespace platform { bool stratified; int nested; int n_folds; + json excluded; std::vector seeds; }; class GridSearch { diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 4d8a4e9..a5af2a6 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -35,6 +35,7 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); + program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]"); program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { @@ -167,6 +168,8 @@ int main(int argc, char** argv) if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) { throw std::runtime_error("Cannot use --dump with --continue or --only"); } + auto excluded = program.get("exclude"); + config.excluded = json::parse(excluded); } catch (const exception& err) { cerr << err.what() << std::endl;