From 8705adf3ee062a71f15c6eb7b93e20fe96d06f59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 20 Dec 2024 12:51:33 +0100 Subject: [PATCH] Begin b_grid experiment --- src/commands/b_grid.cpp | 117 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index e816dfa..86b5669 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -31,6 +31,80 @@ void assignModel(argparse::ArgumentParser& parser) } ); } +void add_experiment_args(argparse::ArgumentParser& program) +{ + auto env = platform::DotEnv(); + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + auto& group = program.add_mutually_exclusive_group(true); + group.add_argument("-d", "--dataset") + .help("Dataset file name: " + datasets.toString()) + .default_value("all") + .action([](const std::string& value) { + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + static std::vector choices_datasets(datasets.getNames()); + choices_datasets.push_back("all"); + if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) { + return value; + } + throw std::runtime_error("Dataset must be one of: " + datasets.toString()); + } + ); + group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector()); + group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test."); + program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); + program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ + "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); + program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true); + program.add_argument("-m", "--model") + .help("Model to use: " + platform::Models::instance()->toString()) + .action([](const std::string& value) { + static const std::vector choices = platform::Models::instance()->getNames(); + if (find(choices.begin(), choices.end(), value) != choices.end()) { + return value; + } + throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString()); + } + ); + program.add_argument("--title").default_value("").help("Experiment title"); + program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + auto valid_choices = env.valid_tokens("discretize_algo"); + auto& disc_arg = program.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); + for (auto choice : valid_choices) { + disc_arg.choices(choice); + } + valid_choices = env.valid_tokens("smooth_strat"); + auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat")); + for (auto choice : valid_choices) { + smooth_arg.choices(choice); + } + auto& score_arg = program.add_argument("-s", "--score").help("Score to use. Valid values: " + env.valid_values("score")).default_value(env.get("score")); + valid_choices = env.valid_tokens("score"); + for (auto choice : valid_choices) { + score_arg.choices(choice); + } + program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true); + program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true); + program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); + program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); + program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); + program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); + program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { + try { + auto k = stoi(value); + if (k < 2) { + throw std::runtime_error("Number of folds must be greater than 1"); + } + return k; + } + catch (const runtime_error& err) { + throw std::runtime_error(err.what()); + } + catch (...) { + throw std::runtime_error("Number of folds must be an integer"); + }}); + auto seed_values = env.getSeeds(); + program.add_argument("--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); +} void add_compute_args(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); @@ -228,6 +302,40 @@ void compute(argparse::ArgumentParser& program) } MPI_Finalize(); } +void experiment(argparse::ArgumentParser& program) +{ + struct platform::ConfigGrid config; + config.model = program.get("model"); + config.score = program.get("score"); + config.discretize = program.get("discretize"); + config.stratified = program.get("stratified"); + config.smooth_strategy = program.get("smooth-strat"); + config.n_folds = program.get("folds"); + config.quiet = program.get("quiet"); + config.seeds = program.get>("seeds"); + + auto env = platform::DotEnv(); + config.platform = env.get("platform"); + platform::Paths::createPath(platform::Paths::grid()); + // auto grid_experiment = platform::GridExperiment(config); + platform::Timer timer; + timer.start(); + struct platform::ConfigMPI mpi_config; + mpi_config.manager = 0; // which process is the manager + MPI_Init(nullptr, nullptr); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs); + if (mpi_config.n_procs < 2) { + throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ..."); + } + // grid_experiment.go(mpi_config); + if (mpi_config.rank == mpi_config.manager) { + // auto results = grid_experiment.loadResults(); + // list_results(results, config.model); + std::cout << "Process took " << timer.getDurationString() << std::endl; + } + MPI_Finalize(); +} int main(int argc, char** argv) { // @@ -250,9 +358,16 @@ int main(int argc, char** argv) assignModel(compute_command); add_compute_args(compute_command); + // grid experiment subparser + argparse::ArgumentParser experiment_command("experiment"); + experiment_command.add_description("Experiment like b_main using mpi."); + assignModel(experiment_command); + add_experiment_args(experiment_command); + program.add_subparser(dump_command); program.add_subparser(report_command); program.add_subparser(compute_command); + program.add_subparser(experiment_command); // // Process options @@ -260,7 +375,7 @@ int main(int argc, char** argv) try { program.parse_args(argc, argv); bool found = false; - map commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute} }; + map commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute}, { "experiment",&experiment } }; for (const auto& command : commands) { if (program.is_subcommand_used(command.first)) { std::invoke(command.second, program.at(command.first));