Add generate-fold-files to b_main

This commit is contained in:
2024-05-28 10:52:08 +02:00
parent b34af13eea
commit f5d5c35002
7 changed files with 58 additions and 10 deletions

View File

@@ -47,6 +47,7 @@ void manageArguments(argparse::ArgumentParser& program)
);
program.add_argument("--title").default_value("").help("Experiment title");
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true);
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
@@ -75,7 +76,7 @@ int main(int argc, char** argv)
manageArguments(program);
std::string file_name, model_name, title, hyperparameters_file, datasets_file;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files;
std::vector<int> seeds;
std::vector<std::string> file_names;
std::vector<std::string> filesToTest;
@@ -95,6 +96,7 @@ int main(int argc, char** argv)
hyperparameters_json = json::parse(hyperparameters);
hyperparameters_file = program.get<std::string>("hyper-file");
no_train_score = program.get<bool>("no-train-score");
generate_fold_files = program.get<bool>("generate-fold-files");
if (hyperparameters_file != "" && hyperparameters != "{}") {
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
}
@@ -184,7 +186,7 @@ int main(int argc, char** argv)
}
platform::Timer timer;
timer.start();
experiment.go(filesToTest, quiet, no_train_score);
experiment.go(filesToTest, quiet, no_train_score, generate_fold_files);
experiment.setDuration(timer.getDuration());
if (saveResults) {
experiment.saveResult();