diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d1bc2a..e33b67c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,12 +25,18 @@ set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") - # Options # ------- option(ENABLE_CLANG_TIDY "Enable to add clang tidy." OFF) option(ENABLE_TESTING "Unit testing build" OFF) option(CODE_COVERAGE "Collect coverage from test library" OFF) +option(MPI_ENABLED "Enable MPI options" ON) + +if (MPI_ENABLED) + find_package(MPI REQUIRED) + message("MPI_CXX_LIBRARIES=${MPI_CXX_LIBRARIES}") + message("MPI_CXX_INCLUDE_DIRS=${MPI_CXX_INCLUDE_DIRS}") +endif (MPI_ENABLED) # Boost Library set(Boost_USE_STATIC_LIBS OFF) diff --git a/README.md b/README.md index 2acf581..ad0dd4a 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,16 @@ Bayesian Network Classifier with libtorch from scratch Before compiling BayesNet. +### MPI + +In Linux just install openmpi & openmpi-devel packages. + +In Mac OS X, install mpich with brew and if cmake doesn't find it, edit mpicxx wrapper to remove the ",-commons,use_dylibs" from final_ldflags + +```bash +vi /opt/homebrew/bin/mpicx +``` + ### boost library [Getting Started]() diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 947a305..069f8a2 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -141,18 +141,18 @@ void list_results(json& results, std::string& model) void initialize_mpi(struct platform::ConfigMPI& config) { - int provided; + // int provided; // MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); // if (provided != MPI_THREAD_MULTIPLE) { // std::cerr << "MPI_Init_thread returned " << provided << " instead of " << MPI_THREAD_MULTIPLE << std::endl; // exit(1); // } - MPI_Init(nullptr, nullptr); - int rank, size; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - config.mpi_rank = rank; - config.mpi_size = size; + // MPI_Init(nullptr, nullptr); + // int rank, size; + // MPI_Comm_rank(MPI_COMM_WORLD, &rank); + // MPI_Comm_size(MPI_COMM_WORLD, &size); + // config.mpi_rank = rank; + // config.mpi_size = size; } @@ -213,8 +213,11 @@ int main(int argc, char** argv) } else { if (compute) { if (program.get("mpi")) { - initialize_mpi(mpi_config); - grid_search.setMPIConfig(mpi_config); + MPI_Init(nullptr, nullptr); + MPI_Comm_rank(MPI_COMM_WORLD, &config.rank); + MPI_Comm_size(MPI_COMM_WORLD, &config.size); + grid_search.go_mpi(); + MPI_Finzalize(); } else { grid_search.go(); std::cout << "Process took " << timer.getDurationString() << std::endl;