Add --exclude parameter to b_grid to exclude datasets

This commit is contained in:
Ricardo Montañana Gómez 2023-12-08 12:09:08 +01:00
parent f0d6f0cc38
commit aa0936abd1
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 13 additions and 2 deletions

View File

@ -274,11 +274,9 @@ namespace platform {
vector<std::string> GridSearch::processDatasets(Datasets& datasets) vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{ {
// Load datasets // Load datasets
auto datasets_names = datasets.getNames(); auto datasets_names = datasets.getNames();
if (config.continue_from != NO_CONTINUE()) { if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution: // Continue previous execution:
// remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found"); throw std::invalid_argument("Dataset " + config.continue_from + " not found");
} }
@ -295,6 +293,15 @@ namespace platform {
} }
} }
} }
// Exclude datasets
for (const auto& name : config.excluded) {
auto dataset = name.get<std::string>();
auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
if (it == datasets_names.end()) {
throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
}
datasets_names.erase(it);
}
return datasets_names; return datasets_names;
} }
json GridSearch::initializeResults() json GridSearch::initializeResults()

View File

@ -21,6 +21,7 @@ namespace platform {
bool stratified; bool stratified;
int nested; int nested;
int n_folds; int n_folds;
json excluded;
std::vector<int> seeds; std::vector<int> seeds;
}; };
class GridSearch { class GridSearch {

View File

@ -35,6 +35,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]");
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
@ -167,6 +168,8 @@ int main(int argc, char** argv)
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) { if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only"); throw std::runtime_error("Cannot use --dump with --continue or --only");
} }
auto excluded = program.get<std::string>("exclude");
config.excluded = json::parse(excluded);
} }
catch (const exception& err) { catch (const exception& err) {
cerr << err.what() << std::endl; cerr << err.what() << std::endl;