Add --exclude parameter to b_grid to exclude datasets

This commit is contained in:
Ricardo Montañana Gómez 2023-12-08 12:09:08 +01:00
parent f0d6f0cc38
commit aa0936abd1
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 13 additions and 2 deletions

View File

@ -274,11 +274,9 @@ namespace platform {
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
// Load datasets
auto datasets_names = datasets.getNames();
if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution:
// remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
}
@ -295,6 +293,15 @@ namespace platform {
}
}
}
// Drop every dataset listed in config.excluded from the list to be processed.
// Each entry is read as a json string and must still be present in
// datasets_names; otherwise std::invalid_argument is thrown.
for (const auto& entry : config.excluded) {
    const auto excluded_name = entry.get<std::string>();
    const auto position = std::find(datasets_names.begin(), datasets_names.end(), excluded_name);
    if (position != datasets_names.end()) {
        // Found: remove this single occurrence and move on to the next entry.
        datasets_names.erase(position);
        continue;
    }
    throw std::invalid_argument("Dataset " + excluded_name + " already excluded or doesn't exist!");
}
return datasets_names;
}
json GridSearch::initializeResults()

View File

@ -21,6 +21,7 @@ namespace platform {
bool stratified;
int nested;
int n_folds;
// Datasets to leave out of the run, held as a json value; GridSearch::processDatasets
// reads each element as a string (filled from the --exclude option of b_grid).
json excluded;
std::vector<int> seeds;
};
class GridSearch {

View File

@ -35,6 +35,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]");
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
@ -167,6 +168,8 @@ int main(int argc, char** argv)
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only");
}
auto excluded = program.get<std::string>("exclude");
config.excluded = json::parse(excluded);
}
catch (const exception& err) {
cerr << err.what() << std::endl;