Add --exclude parameter to b_grid to exclude datasets
This commit is contained in:
parent
f0d6f0cc38
commit
aa0936abd1
@ -274,11 +274,9 @@ namespace platform {
|
|||||||
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
||||||
{
|
{
|
||||||
// Load datasets
|
// Load datasets
|
||||||
|
|
||||||
auto datasets_names = datasets.getNames();
|
auto datasets_names = datasets.getNames();
|
||||||
if (config.continue_from != NO_CONTINUE()) {
|
if (config.continue_from != NO_CONTINUE()) {
|
||||||
// Continue previous execution:
|
// Continue previous execution:
|
||||||
// remove datasets already processed
|
|
||||||
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||||
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
||||||
}
|
}
|
||||||
@ -295,6 +293,15 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Exclude datasets
|
||||||
|
for (const auto& name : config.excluded) {
|
||||||
|
auto dataset = name.get<std::string>();
|
||||||
|
auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
|
||||||
|
if (it == datasets_names.end()) {
|
||||||
|
throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
|
||||||
|
}
|
||||||
|
datasets_names.erase(it);
|
||||||
|
}
|
||||||
return datasets_names;
|
return datasets_names;
|
||||||
}
|
}
|
||||||
json GridSearch::initializeResults()
|
json GridSearch::initializeResults()
|
||||||
|
@ -21,6 +21,7 @@ namespace platform {
|
|||||||
bool stratified;
|
bool stratified;
|
||||||
int nested;
|
int nested;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
|
json excluded;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
};
|
};
|
||||||
class GridSearch {
|
class GridSearch {
|
||||||
|
@ -35,6 +35,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
||||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]");
|
||||||
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||||
@ -167,6 +168,8 @@ int main(int argc, char** argv)
|
|||||||
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
|
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
|
||||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||||
}
|
}
|
||||||
|
auto excluded = program.get<std::string>("exclude");
|
||||||
|
config.excluded = json::parse(excluded);
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (const exception& err) {
|
||||||
cerr << err.what() << std::endl;
|
cerr << err.what() << std::endl;
|
||||||
|
Loading…
Reference in New Issue
Block a user