From 234342f2de535ae9dbaa3ad4a9c060ab5c22ad3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 10 Dec 2023 22:33:17 +0100 Subject: [PATCH] Add mpi parameter to b_grid --- src/Platform/GridSearch.h | 4 ++++ src/Platform/b_grid.cc | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index e325ca5..70bbf47 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -24,6 +24,10 @@ namespace platform { json excluded; std::vector seeds; }; + struct ConfigMPI { + int rank; + int nprocs; + } class GridSearch { public: explicit GridSearch(struct ConfigGrid& config); diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index a5af2a6..947a305 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program) group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true); group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--mpi").help("Use MPI computing grid").default_value(false).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); @@ -138,6 +139,22 @@ void list_results(json& results, std::string& model) std::cout << Colors::RESET() << std::endl; } +void initialize_mpi(struct platform::ConfigMPI& config) +{ + int provided; + // MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); + // if (provided != MPI_THREAD_MULTIPLE) { + // std::cerr << "MPI_Init_thread returned " << provided << " instead of " << MPI_THREAD_MULTIPLE << std::endl; + // exit(1); + // } + MPI_Init(nullptr, nullptr); + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + config.mpi_rank = rank; + config.mpi_size = size; +} + /* * Main @@ -147,6 +164,7 @@ int main(int argc, char** argv) argparse::ArgumentParser program("b_grid"); manageArguments(program); struct platform::ConfigGrid config; + struct platform::ConfigMPI mpi_config; bool dump, compute; try { program.parse_args(argc, argv); @@ -170,6 +188,11 @@ int main(int argc, char** argv) } auto excluded = program.get("exclude"); config.excluded = json::parse(excluded); + if (program.get("mpi")) { + if (!compute) { + throw std::runtime_error("Cannot use --mpi without --compute"); + } + } } catch (const exception& err) { cerr << err.what() << std::endl; @@ -189,8 +212,13 @@ int main(int argc, char** argv) list_dump(config.model); } else { if (compute) { - grid_search.go(); - std::cout << "Process took " << timer.getDurationString() << std::endl; + if (program.get("mpi")) { + initialize_mpi(mpi_config); + grid_search.setMPIConfig(mpi_config); + } else { + grid_search.go(); + std::cout << "Process took " << timer.getDurationString() << std::endl; + } } else { // List results auto results = grid_search.getResults();