Add datasets-file to b_main
This commit is contained in:
@@ -30,7 +30,8 @@ void manageArguments(argparse::ArgumentParser& program)
|
||||
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
|
||||
}
|
||||
);
|
||||
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names").default_value(std::vector<std::string>());
|
||||
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
|
||||
group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
|
||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
|
||||
@@ -72,7 +73,7 @@ int main(int argc, char** argv)
|
||||
{
|
||||
argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() });
|
||||
manageArguments(program);
|
||||
std::string file_name, model_name, title, hyperparameters_file;
|
||||
std::string file_name, model_name, title, hyperparameters_file, datasets_file;
|
||||
json hyperparameters_json;
|
||||
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
|
||||
std::vector<int> seeds;
|
||||
@@ -83,6 +84,7 @@ int main(int argc, char** argv)
|
||||
program.parse_args(argc, argv);
|
||||
file_name = program.get<std::string>("dataset");
|
||||
file_names = program.get<std::vector<std::string>>("datasets");
|
||||
datasets_file = program.get<std::string>("datasets-file");
|
||||
model_name = program.get<std::string>("model");
|
||||
discretize_dataset = program.get<bool>("discretize");
|
||||
stratified = program.get<bool>("stratified");
|
||||
@@ -108,27 +110,59 @@ int main(int argc, char** argv)
|
||||
exit(1);
|
||||
}
|
||||
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
|
||||
if (file_names.size() > 0) {
|
||||
filesToTest = file_names;
|
||||
saveResults = true;
|
||||
if (title == "") {
|
||||
title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
|
||||
if (datasets_file != "") {
|
||||
ifstream catalog(datasets_file);
|
||||
if (catalog.is_open()) {
|
||||
std::string line;
|
||||
while (getline(catalog, line)) {
|
||||
if (line.empty() || line[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
if (!datasets.isDataset(line)) {
|
||||
cerr << "Dataset " << line << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
filesToTest.push_back(line);
|
||||
}
|
||||
catalog.close();
|
||||
saveResults = true;
|
||||
if (title == "") {
|
||||
title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\
|
||||
+ model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("Unable to open catalog file. [" + datasets_file + "]");
|
||||
}
|
||||
} else {
|
||||
if (file_name != "all") {
|
||||
if (!datasets.isDataset(file_name)) {
|
||||
cerr << "Dataset " << file_name << " not found" << std::endl;
|
||||
exit(1);
|
||||
if (file_names.size() > 0) {
|
||||
for (auto file : file_names) {
|
||||
if (!datasets.isDataset(file)) {
|
||||
cerr << "Dataset " << file << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (title == "") {
|
||||
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
filesToTest = datasets.getNames();
|
||||
filesToTest = file_names;
|
||||
saveResults = true;
|
||||
if (title == "") {
|
||||
title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
} else {
|
||||
if (file_name != "all") {
|
||||
if (!datasets.isDataset(file_name)) {
|
||||
cerr << "Dataset " << file_name << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
if (title == "") {
|
||||
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
filesToTest = datasets.getNames();
|
||||
saveResults = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
platform::HyperParameters test_hyperparams;
|
||||
if (hyperparameters_file != "") {
|
||||
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file);
|
||||
|
Reference in New Issue
Block a user