From ecce7955f8e0d61fa616f218264631a2f6d04db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 15 Jan 2024 11:26:39 +0100 Subject: [PATCH] Add export command to b_grid --- src/Platform/b_grid.cc | 221 ++++++++++++++++++----------- src/Platform/modules/GridSearch.cc | 11 ++ src/Platform/modules/GridSearch.h | 1 + src/Platform/modules/Paths.h | 4 + 4 files changed, 158 insertions(+), 79 deletions(-) diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 7a750f8..8c841b1 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include "DotEnv.h" @@ -15,23 +16,24 @@ using json = nlohmann::json; const int MAXL = 133; -void manageArguments(argparse::ArgumentParser& program) +void assignModel(argparse::ArgumentParser& parser) { - auto env = platform::DotEnv(); - auto& group = program.add_mutually_exclusive_group(true); - program.add_argument("-m", "--model") - .help("Model to use " + platform::Models::instance()->tostring()) - .action([](const std::string& value) { - static const std::vector choices = platform::Models::instance()->getNames(); + auto models = platform::Models::instance(); + parser.add_argument("-m", "--model") + .help("Model to use " + models->tostring()) + .required() + .action([models](const std::string& value) { + static const std::vector choices = models->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring()); + throw std::runtime_error("Model must be one of " + models->tostring()); } ); - group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true); - group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true); - group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true); +} +void add_compute_args(argparse::ArgumentParser& program) +{ + auto env = platform::DotEnv(); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); @@ -70,11 +72,19 @@ void manageArguments(argparse::ArgumentParser& program) auto seed_values = env.getSeeds(); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); } - +std::string headerLine(const std::string& text, int utf = 0) +{ + int n = MAXL - text.length() - 3; + n = n < 0 ? 0 : n; + return "* " + text + std::string(n + utf, ' ') + "*\n"; +} void list_dump(std::string& model) { auto data = platform::GridData(platform::Paths::grid_input(model)); - std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl; + std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl; + std::cout << headerLine("Listing configuration input file (Grid)"); + std::cout << headerLine("Model: " + model); + std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl; int index = 0; int max_hyper = 15; int max_dataset = 7; @@ -96,17 +106,11 @@ void list_dump(std::string& model) std::cout << color; auto num_combinations = data.getNumCombinations(item.first); std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first - << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << item.second.dump() << std::endl; + << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << left << item.second.dump() << std::endl; odd = !odd; } std::cout << Colors::RESET() << std::endl; } -std::string headerLine(const std::string& text, int utf = 0) -{ - int n = MAXL - text.length() - 3; - n = n < 0 ? 0 : n; - return "* " + text + std::string(n + utf, ' ') + "*\n"; -} void list_results(json& results, std::string& model) { std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl; @@ -155,77 +159,136 @@ void list_results(json& results, std::string& model) /* * Main */ -int main(int argc, char** argv) +void dump(argparse::ArgumentParser& program) { - argparse::ArgumentParser program("b_grid", { project_version.begin(), project_version.end() }); - manageArguments(program); + auto model = program.get("model"); + list_dump(model); +} +void report(argparse::ArgumentParser& program) +{ + // List results struct platform::ConfigGrid config; - bool dump, compute; - try { - program.parse_args(argc, argv); - config.model = program.get("model"); - config.score = program.get("score"); - config.discretize = program.get("discretize"); - config.stratified = program.get("stratified"); - config.n_folds = program.get("folds"); - config.quiet = program.get("quiet"); - config.only = program.get("only"); - config.seeds = program.get>("seeds"); - config.nested = program.get("nested"); - config.continue_from = program.get("continue"); - if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) { - throw std::runtime_error("Cannot use --only without --continue"); - } - dump = program.get("dump"); - compute = program.get("compute"); - if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) { - throw std::runtime_error("Cannot use --dump with --continue or --only"); - } - auto excluded = program.get("exclude"); - config.excluded = json::parse(excluded); + config.model = program.get("model"); + auto grid_search = platform::GridSearch(config); + auto results = grid_search.loadResults(); + if (results.empty()) { + std::cout << "** No results found" << std::endl; + } else { + list_results(results, config.model); } - catch (const exception& err) { - cerr << err.what() << std::endl; - cerr << program; - exit(1); +} +void exportResults(argparse::ArgumentParser& program) +{ + // Generate a grid_.json file with the results of the grid search + // this file can be used by b_main to run the model with the best hyperparameters + struct platform::ConfigGrid config; + config.model = program.get("model"); + auto grid_search = platform::GridSearch(config); + auto results = grid_search.loadResults(); + auto output = json::array(); + if (results.empty()) { + std::cout << "** No results found" << std::endl; + } else { + grid_search.exportResults(results); + std::cout << "Exported results to " << platform::Paths::grid_export(config.model) << std::endl; } - /* - * Begin Processing - */ +} +void compute(argparse::ArgumentParser& program) +{ + struct platform::ConfigGrid config; + config.model = program.get("model"); + config.score = program.get("score"); + config.discretize = program.get("discretize"); + config.stratified = program.get("stratified"); + config.n_folds = program.get("folds"); + config.quiet = program.get("quiet"); + config.only = program.get("only"); + config.seeds = program.get>("seeds"); + config.nested = program.get("nested"); + config.continue_from = program.get("continue"); + if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) { + throw std::runtime_error("Cannot use --only without --continue"); + } + auto excluded = program.get("exclude"); + config.excluded = json::parse(excluded); + auto env = platform::DotEnv(); config.platform = env.get("platform"); platform::Paths::createPath(platform::Paths::grid()); auto grid_search = platform::GridSearch(config); platform::Timer timer; timer.start(); - if (dump) { - list_dump(config.model); - } else { - if (compute) { - struct platform::ConfigMPI mpi_config; - mpi_config.manager = 0; // which process is the manager - MPI_Init(&argc, &argv); - MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank); - MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs); - if (mpi_config.n_procs < 2) { - throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ..."); - } - grid_search.go(mpi_config); - if (mpi_config.rank == mpi_config.manager) { - auto results = grid_search.loadResults(); - list_results(results, config.model); - std::cout << "Process took " << timer.getDurationString() << std::endl; - } - MPI_Finalize(); - } else { - // List results - auto results = grid_search.loadResults(); - if (results.empty()) { - std::cout << "** No results found" << std::endl; - } else { - list_results(results, config.model); + struct platform::ConfigMPI mpi_config; + mpi_config.manager = 0; // which process is the manager + MPI_Init(nullptr, nullptr); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs); + if (mpi_config.n_procs < 2) { + throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ..."); + } + grid_search.go(mpi_config); + if (mpi_config.rank == mpi_config.manager) { + auto results = grid_search.loadResults(); + list_results(results, config.model); + std::cout << "Process took " << timer.getDurationString() << std::endl; + } + MPI_Finalize(); +} +int main(int argc, char** argv) +{ + // + // Manage arguments + // + argparse::ArgumentParser program("b_grid", { project_version.begin(), project_version.end() }); + // grid dump subparser + argparse::ArgumentParser dump_command("dump"); + dump_command.add_description("Dump the combinations of hyperparameters of a model."); + assignModel(dump_command); + + // grid report subparser + argparse::ArgumentParser report_command("report"); + assignModel(report_command); + report_command.add_description("Report the computed hyperparameters of a model."); + + // grid compute subparser + argparse::ArgumentParser compute_command("compute"); + compute_command.add_description("Compute using mpi the hyperparameters of a model."); + assignModel(compute_command); + add_compute_args(compute_command); + + // grid export subparser + argparse::ArgumentParser export_command("export"); + assignModel(export_command); + export_command.add_description("Export the computed hyperparameters to a file readable by b_main."); + + program.add_subparser(dump_command); + program.add_subparser(report_command); + program.add_subparser(compute_command); + program.add_subparser(export_command); + + // + // Process options + // + try { + program.parse_args(argc, argv); + bool found = false; + map commands = + { {"dump", &dump}, {"report", &report}, {"export", &exportResults}, {"compute", &compute} }; + for (const auto& command : commands) { + if (program.is_subcommand_used(command.first)) { + std::invoke(command.second, program.at(command.first)); + found = true; + break; } } + if (!found) { + throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n"); + } + } + catch (const exception& err) { + cerr << err.what() << std::endl; + cerr << program; + exit(1); } std::cout << "Done!" << std::endl; return 0; diff --git a/src/Platform/modules/GridSearch.cc b/src/Platform/modules/GridSearch.cc index 2ef0ad1..daff8b1 100644 --- a/src/Platform/modules/GridSearch.cc +++ b/src/Platform/modules/GridSearch.cc @@ -438,4 +438,15 @@ namespace platform { }; file << output.dump(4); } + void GridSearch::exportResults(json& results) + { + std::ofstream file(Paths::grid_export(config.model)); + auto output = json(); + for (const auto& item : results["results"].items()) { + auto key = item.key(); + auto value = item.value(); + output[key] = value["hyperparameters"]; + } + file << output.dump(4); + } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/modules/GridSearch.h b/src/Platform/modules/GridSearch.h index ec1b3cb..e665519 100644 --- a/src/Platform/modules/GridSearch.h +++ b/src/Platform/modules/GridSearch.h @@ -47,6 +47,7 @@ namespace platform { void go(struct ConfigMPI& config_mpi); ~GridSearch() = default; json loadResults(); + void exportResults(json& results); static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; } private: void save(json& results); diff --git a/src/Platform/modules/Paths.h b/src/Platform/modules/Paths.h index 6fd61cf..8d19040 100644 --- a/src/Platform/modules/Paths.h +++ b/src/Platform/modules/Paths.h @@ -34,6 +34,10 @@ namespace platform { { return grid() + "grid_" + model + "_output.json"; } + static std::string grid_export(const std::string& model) + { + return grid() + "grid_" + model + ".json"; + } }; } #endif \ No newline at end of file