Refactor arguments management for Experimentation
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <mpi.h>
|
||||
#include "main/Models.h"
|
||||
#include "main/modelRegister.h"
|
||||
#include "main/ArgumentsExperiment.h"
|
||||
#include "common/Paths.h"
|
||||
#include "common/Timer.h"
|
||||
#include "common/Colors.h"
|
||||
@@ -32,76 +33,7 @@ void assignModel(argparse::ArgumentParser& parser)
|
||||
}
|
||||
);
|
||||
}
|
||||
void add_experiment_args(argparse::ArgumentParser& program)
|
||||
{
|
||||
auto env = platform::DotEnv();
|
||||
auto datasets = platform::Datasets(false, platform::Paths::datasets());
|
||||
auto& group = program.add_mutually_exclusive_group(true);
|
||||
group.add_argument("-d", "--dataset")
|
||||
.help("Dataset file name: " + datasets.toString())
|
||||
.default_value("all")
|
||||
.action([](const std::string& value) {
|
||||
auto datasets = platform::Datasets(false, platform::Paths::datasets());
|
||||
static std::vector<std::string> choices_datasets(datasets.getNames());
|
||||
choices_datasets.push_back("all");
|
||||
if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
|
||||
return value;
|
||||
}
|
||||
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
|
||||
}
|
||||
);
|
||||
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
|
||||
group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
|
||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||
program.add_argument("--save").help("Save result (always save even if a dataset is supplied)").default_value(false).implicit_value(true);
|
||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
|
||||
program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true);
|
||||
program.add_argument("-m", "--model")
|
||||
.help("Model to use: " + platform::Models::instance()->toString())
|
||||
.action([](const std::string& value) {
|
||||
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
|
||||
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||
return value;
|
||||
}
|
||||
throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString());
|
||||
}
|
||||
);
|
||||
program.add_argument("--title").default_value("").help("Experiment title");
|
||||
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||
auto valid_choices = env.valid_tokens("discretize_algo");
|
||||
auto& disc_arg = program.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo"));
|
||||
for (auto choice : valid_choices) {
|
||||
disc_arg.choices(choice);
|
||||
}
|
||||
valid_choices = env.valid_tokens("smooth_strat");
|
||||
auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat"));
|
||||
for (auto choice : valid_choices) {
|
||||
smooth_arg.choices(choice);
|
||||
}
|
||||
auto& score_arg = program.add_argument("-s", "--score").help("Score to use. Valid values: " + env.valid_values("score")).default_value(env.get("score"));
|
||||
valid_choices = env.valid_tokens("score");
|
||||
for (auto choice : valid_choices) {
|
||||
score_arg.choices(choice);
|
||||
}
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||
try {
|
||||
auto k = stoi(value);
|
||||
if (k < 2) {
|
||||
throw std::runtime_error("Number of folds must be greater than 1");
|
||||
}
|
||||
return k;
|
||||
}
|
||||
catch (const runtime_error& err) {
|
||||
throw std::runtime_error(err.what());
|
||||
}
|
||||
catch (...) {
|
||||
throw std::runtime_error("Number of folds must be an integer");
|
||||
}});
|
||||
auto seed_values = env.getSeeds();
|
||||
program.add_argument("--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||
}
|
||||
|
||||
void add_search_args(argparse::ArgumentParser& program)
|
||||
{
|
||||
auto env = platform::DotEnv();
|
||||
@@ -276,9 +208,6 @@ void search(argparse::ArgumentParser& program)
|
||||
}
|
||||
auto excluded = program.get<std::string>("exclude");
|
||||
config.excluded = json::parse(excluded);
|
||||
|
||||
auto env = platform::DotEnv();
|
||||
config.platform = env.get("platform");
|
||||
platform::Paths::createPath(platform::Paths::grid());
|
||||
auto grid_search = platform::GridSearch(config);
|
||||
platform::Timer timer;
|
||||
@@ -303,10 +232,9 @@ void search(argparse::ArgumentParser& program)
|
||||
void experiment(argparse::ArgumentParser& program)
|
||||
{
|
||||
struct platform::ConfigGrid config;
|
||||
|
||||
auto env = platform::DotEnv();
|
||||
config.platform = env.get("platform");
|
||||
auto grid_experiment = platform::GridExperiment(program, config);
|
||||
auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID);
|
||||
arguments.parse();
|
||||
auto grid_experiment = platform::GridExperiment(arguments, config);
|
||||
platform::Timer timer;
|
||||
timer.start();
|
||||
struct platform::ConfigMPI mpi_config;
|
||||
@@ -326,7 +254,7 @@ void experiment(argparse::ArgumentParser& program)
|
||||
if (grid_experiment.haveToSaveResults()) {
|
||||
experiment.saveResult();
|
||||
}
|
||||
experiment.report(grid_experiment.numFiles() == 1);
|
||||
experiment.report();
|
||||
std::cout << "Process took " << duration << std::endl;
|
||||
}
|
||||
MPI_Finalize();
|
||||
@@ -356,9 +284,7 @@ int main(int argc, char** argv)
|
||||
// grid experiment subparser
|
||||
argparse::ArgumentParser experiment_command("experiment");
|
||||
experiment_command.add_description("Experiment like b_main using mpi.");
|
||||
assignModel(experiment_command);
|
||||
add_experiment_args(experiment_command);
|
||||
|
||||
auto arguments = platform::ArgumentsExperiment(experiment_command, platform::experiment_t::GRID);
|
||||
program.add_subparser(dump_command);
|
||||
program.add_subparser(report_command);
|
||||
program.add_subparser(search_command);
|
||||
|
Reference in New Issue
Block a user