Add --exclude parameter to b_grid to exclude datasets
This commit is contained in:
parent
f0d6f0cc38
commit
aa0936abd1
@ -274,11 +274,9 @@ namespace platform {
|
||||
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
||||
{
|
||||
// Get the dataset names
|
||||
|
||||
auto datasets_names = datasets.getNames();
|
||||
if (config.continue_from != NO_CONTINUE()) {
|
||||
// Continue previous execution:
|
||||
// remove datasets already processed
|
||||
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
||||
}
|
||||
@ -295,6 +293,15 @@ namespace platform {
|
||||
}
|
||||
}
|
||||
}
|
||||
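// Remove every dataset listed in config.excluded from datasets_names; throw if an entry is not (or no longer) in the list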
|
||||
for (const auto& name : config.excluded) {
|
||||
auto dataset = name.get<std::string>();
|
||||
auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
|
||||
if (it == datasets_names.end()) {
|
||||
throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
|
||||
}
|
||||
datasets_names.erase(it);
|
||||
}
|
||||
return datasets_names;
|
||||
}
|
||||
json GridSearch::initializeResults()
|
||||
|
@ -21,6 +21,7 @@ namespace platform {
|
||||
bool stratified;
|
||||
int nested;
|
||||
int n_folds;
|
||||
json excluded;
|
||||
std::vector<int> seeds;
|
||||
};
|
||||
class GridSearch {
|
||||
|
@ -35,6 +35,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||
program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]");
|
||||
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||
@ -167,6 +168,8 @@ int main(int argc, char** argv)
|
||||
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
|
||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||
}
|
||||
auto excluded = program.get<std::string>("exclude");
|
||||
config.excluded = json::parse(excluded);
|
||||
}
|
||||
catch (const exception& err) {
|
||||
cerr << err.what() << std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user