diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index c54a21c..6e9796d 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -37,7 +37,20 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]"); - program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); + program.add_argument("--nested").help("Set the double/nested cross validation number of folds").default_value(5).scan<'i', int>().action([](const std::string& value) { + try { + auto k = stoi(value); + if (k < 2) { + throw std::runtime_error("Number of nested folds must be greater than 1"); + } + return k; + } + catch (const runtime_error& err) { + throw std::runtime_error(err.what()); + } + catch (...) { + throw std::runtime_error("Number of nested folds must be an integer"); + }}); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { try {