Add mpi parameter to b_grid
This commit is contained in:
parent
aa0936abd1
commit
234342f2de
@ -24,6 +24,10 @@ namespace platform {
|
|||||||
json excluded;
|
json excluded;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
};
|
};
|
||||||
|
struct ConfigMPI {
|
||||||
|
int rank;
|
||||||
|
int nprocs;
|
||||||
|
}
|
||||||
class GridSearch {
|
class GridSearch {
|
||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
|
@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
|
group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
|
||||||
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
|
program.add_argument("--mpi").help("Use MPI computing grid").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
||||||
@ -138,6 +139,22 @@ void list_results(json& results, std::string& model)
|
|||||||
std::cout << Colors::RESET() << std::endl;
|
std::cout << Colors::RESET() << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void initialize_mpi(struct platform::ConfigMPI& config)
|
||||||
|
{
|
||||||
|
int provided;
|
||||||
|
// MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||||
|
// if (provided != MPI_THREAD_MULTIPLE) {
|
||||||
|
// std::cerr << "MPI_Init_thread returned " << provided << " instead of " << MPI_THREAD_MULTIPLE << std::endl;
|
||||||
|
// exit(1);
|
||||||
|
// }
|
||||||
|
MPI_Init(nullptr, nullptr);
|
||||||
|
int rank, size;
|
||||||
|
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||||
|
MPI_Comm_size(MPI_COMM_WORLD, &size);
|
||||||
|
config.mpi_rank = rank;
|
||||||
|
config.mpi_size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Main
|
* Main
|
||||||
@ -147,6 +164,7 @@ int main(int argc, char** argv)
|
|||||||
argparse::ArgumentParser program("b_grid");
|
argparse::ArgumentParser program("b_grid");
|
||||||
manageArguments(program);
|
manageArguments(program);
|
||||||
struct platform::ConfigGrid config;
|
struct platform::ConfigGrid config;
|
||||||
|
struct platform::ConfigMPI mpi_config;
|
||||||
bool dump, compute;
|
bool dump, compute;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
@ -170,6 +188,11 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
auto excluded = program.get<std::string>("exclude");
|
auto excluded = program.get<std::string>("exclude");
|
||||||
config.excluded = json::parse(excluded);
|
config.excluded = json::parse(excluded);
|
||||||
|
if (program.get<bool>("mpi")) {
|
||||||
|
if (!compute) {
|
||||||
|
throw std::runtime_error("Cannot use --mpi without --compute");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (const exception& err) {
|
||||||
cerr << err.what() << std::endl;
|
cerr << err.what() << std::endl;
|
||||||
@ -189,8 +212,13 @@ int main(int argc, char** argv)
|
|||||||
list_dump(config.model);
|
list_dump(config.model);
|
||||||
} else {
|
} else {
|
||||||
if (compute) {
|
if (compute) {
|
||||||
grid_search.go();
|
if (program.get<bool>("mpi")) {
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
initialize_mpi(mpi_config);
|
||||||
|
grid_search.setMPIConfig(mpi_config);
|
||||||
|
} else {
|
||||||
|
grid_search.go();
|
||||||
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// List results
|
// List results
|
||||||
auto results = grid_search.getResults();
|
auto results = grid_search.getResults();
|
||||||
|
Loading…
Reference in New Issue
Block a user