Add grid base class and static class
This commit is contained in:
@@ -9,18 +9,8 @@
|
||||
#include "GridSearch.h"
|
||||
|
||||
namespace platform {
|
||||
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||
GridSearch::GridSearch(struct ConfigGrid& config) : GridBase(config)
|
||||
{
|
||||
if (config.smooth_strategy == "ORIGINAL")
|
||||
smooth_type = bayesnet::Smoothing_t::ORIGINAL;
|
||||
else if (config.smooth_strategy == "LAPLACE")
|
||||
smooth_type = bayesnet::Smoothing_t::LAPLACE;
|
||||
else if (config.smooth_strategy == "CESTNIK")
|
||||
smooth_type = bayesnet::Smoothing_t::CESTNIK;
|
||||
else {
|
||||
std::cerr << "GridSearch: Unknown smoothing strategy: " << config.smooth_strategy << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
json GridSearch::loadResults()
|
||||
{
|
||||
@@ -63,7 +53,7 @@ namespace platform {
|
||||
}
|
||||
return datasets_names;
|
||||
}
|
||||
json GridSearch::build_tasks_mpi(int rank)
|
||||
json GridSearch::build_tasks_mpi()
|
||||
{
|
||||
auto tasks = json::array();
|
||||
auto grid = GridData(Paths::grid_input(config.model));
|
||||
@@ -155,7 +145,7 @@ namespace platform {
|
||||
json tasks;
|
||||
if (config_mpi.rank == config_mpi.manager) {
|
||||
timer.start();
|
||||
tasks = build_tasks_mpi(config_mpi.rank);
|
||||
tasks = build_tasks_mpi();
|
||||
auto tasks_str = tasks.dump();
|
||||
tasks_size = tasks_str.size();
|
||||
msg = new char[tasks_size + 1];
|
||||
@@ -179,13 +169,13 @@ namespace platform {
|
||||
// 2a. Producer delivers the tasks to the consumers
|
||||
//
|
||||
auto datasets_names = filterDatasets(datasets);
|
||||
json all_results = mpi_search_producer(datasets_names, tasks, config_mpi, MPI_Result);
|
||||
json all_results = MPI_SEARCH::producer(datasets_names, tasks, config_mpi, MPI_Result);
|
||||
std::cout << separator << std::endl;
|
||||
//
|
||||
// 3. Manager select the bests sccores for each dataset
|
||||
//
|
||||
auto results = initializeResults();
|
||||
select_best_results_folds(results, all_results, config.model);
|
||||
MPI_SEARCH::select_best_results_folds(results, all_results, config.model);
|
||||
//
|
||||
// 3.2 Save the results
|
||||
//
|
||||
@@ -194,7 +184,7 @@ namespace platform {
|
||||
//
|
||||
// 2b. Consumers process the tasks and send the results to the producer
|
||||
//
|
||||
mpi_search_consumer(datasets, tasks, config, config_mpi, MPI_Result);
|
||||
MPI_SEARCH::consumer(datasets, tasks, config, config_mpi, MPI_Result);
|
||||
}
|
||||
}
|
||||
json GridSearch::initializeResults()
|
||||
|
Reference in New Issue
Block a user