From 65a96851efb0148cf3a0a9cf6a91d0769443f896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 4 Jan 2024 11:01:59 +0100 Subject: [PATCH] Check min number of nested folds --- src/Platform/b_grid.cc | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index c54a21c..6e9796d 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -37,7 +37,20 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]"); - program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); + program.add_argument("--nested").help("Set the double/nested cross validation number of folds").default_value(5).scan<'i', int>().action([](const std::string& value) { + try { + auto k = stoi(value); + if (k < 2) { + throw std::runtime_error("Number of nested folds must be greater than 1"); + } + return k; + } + catch (const runtime_error& err) { + throw std::runtime_error(err.what()); + } + catch (...) { + throw std::runtime_error("Number of nested folds must be an integer"); + }}); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { try {