Add grid base class and static class

This commit is contained in:
2024-12-20 18:54:08 +01:00
parent 1a336a094e
commit f88944de36
7 changed files with 524 additions and 323 deletions

View File

@@ -9,18 +9,8 @@
#include "GridSearch.h"
namespace platform {
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
GridSearch::GridSearch(struct ConfigGrid& config) : GridBase(config)
{
if (config.smooth_strategy == "ORIGINAL")
smooth_type = bayesnet::Smoothing_t::ORIGINAL;
else if (config.smooth_strategy == "LAPLACE")
smooth_type = bayesnet::Smoothing_t::LAPLACE;
else if (config.smooth_strategy == "CESTNIK")
smooth_type = bayesnet::Smoothing_t::CESTNIK;
else {
std::cerr << "GridSearch: Unknown smoothing strategy: " << config.smooth_strategy << std::endl;
exit(1);
}
}
json GridSearch::loadResults()
{
@@ -63,7 +53,7 @@ namespace platform {
}
return datasets_names;
}
json GridSearch::build_tasks_mpi(int rank)
json GridSearch::build_tasks_mpi()
{
auto tasks = json::array();
auto grid = GridData(Paths::grid_input(config.model));
@@ -155,7 +145,7 @@ namespace platform {
json tasks;
if (config_mpi.rank == config_mpi.manager) {
timer.start();
tasks = build_tasks_mpi(config_mpi.rank);
tasks = build_tasks_mpi();
auto tasks_str = tasks.dump();
tasks_size = tasks_str.size();
msg = new char[tasks_size + 1];
@@ -179,13 +169,13 @@ namespace platform {
// 2a. Producer delivers the tasks to the consumers
//
auto datasets_names = filterDatasets(datasets);
json all_results = mpi_search_producer(datasets_names, tasks, config_mpi, MPI_Result);
json all_results = MPI_SEARCH::producer(datasets_names, tasks, config_mpi, MPI_Result);
std::cout << separator << std::endl;
//
// 3. Manager select the bests sccores for each dataset
//
auto results = initializeResults();
select_best_results_folds(results, all_results, config.model);
MPI_SEARCH::select_best_results_folds(results, all_results, config.model);
//
// 3.2 Save the results
//
@@ -194,7 +184,7 @@ namespace platform {
//
// 2b. Consumers process the tasks and send the results to the producer
//
mpi_search_consumer(datasets, tasks, config, config_mpi, MPI_Result);
MPI_SEARCH::consumer(datasets, tasks, config, config_mpi, MPI_Result);
}
}
json GridSearch::initializeResults()