Add datasets hyperparameter to b_main
This commit is contained in:
@@ -16,7 +16,8 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
auto datasets = platform::Datasets(false, platform::Paths::datasets());
|
auto datasets = platform::Datasets(false, platform::Paths::datasets());
|
||||||
program.add_argument("-d", "--dataset")
|
auto& group = program.add_mutually_exclusive_group(true);
|
||||||
|
group.add_argument("-d", "--dataset")
|
||||||
.help("Dataset file name: " + datasets.toString())
|
.help("Dataset file name: " + datasets.toString())
|
||||||
.default_value("all")
|
.default_value("all")
|
||||||
.action([](const std::string& value) {
|
.action([](const std::string& value) {
|
||||||
@@ -29,6 +30,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
|
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names").default_value(std::vector<std::string>());
|
||||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||||
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
|
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
|
||||||
@@ -74,11 +76,13 @@ int main(int argc, char** argv)
|
|||||||
json hyperparameters_json;
|
json hyperparameters_json;
|
||||||
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
|
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
|
std::vector<std::string> file_names;
|
||||||
std::vector<std::string> filesToTest;
|
std::vector<std::string> filesToTest;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
file_name = program.get<std::string>("dataset");
|
file_name = program.get<std::string>("dataset");
|
||||||
|
file_names = program.get<std::vector<std::string>>("datasets");
|
||||||
model_name = program.get<std::string>("model");
|
model_name = program.get<std::string>("model");
|
||||||
discretize_dataset = program.get<bool>("discretize");
|
discretize_dataset = program.get<bool>("discretize");
|
||||||
stratified = program.get<bool>("stratified");
|
stratified = program.get<bool>("stratified");
|
||||||
@@ -104,18 +108,26 @@ int main(int argc, char** argv)
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
|
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
|
||||||
if (file_name != "all") {
|
if (file_names.size() > 0) {
|
||||||
if (!datasets.isDataset(file_name)) {
|
filesToTest = file_names;
|
||||||
cerr << "Dataset " << file_name << " not found" << std::endl;
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
if (title == "") {
|
|
||||||
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
|
||||||
}
|
|
||||||
filesToTest.push_back(file_name);
|
|
||||||
} else {
|
|
||||||
filesToTest = datasets.getNames();
|
|
||||||
saveResults = true;
|
saveResults = true;
|
||||||
|
if (title == "") {
|
||||||
|
title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (file_name != "all") {
|
||||||
|
if (!datasets.isDataset(file_name)) {
|
||||||
|
cerr << "Dataset " << file_name << " not found" << std::endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
if (title == "") {
|
||||||
|
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
||||||
|
}
|
||||||
|
filesToTest.push_back(file_name);
|
||||||
|
} else {
|
||||||
|
filesToTest = datasets.getNames();
|
||||||
|
saveResults = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
platform::HyperParameters test_hyperparams;
|
platform::HyperParameters test_hyperparams;
|
||||||
if (hyperparameters_file != "") {
|
if (hyperparameters_file != "") {
|
||||||
|
Reference in New Issue
Block a user