From 21c4c6df512de60e3dd80c17ed36ec288b7bd574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 25 Dec 2023 19:33:52 +0100 Subject: [PATCH] Fix first mistakes in structure --- kk | 0 src/Platform/GridSearch.cc | 105 ++++++++++++++++++++----------------- src/Platform/GridSearch.h | 9 ++-- 3 files changed, 63 insertions(+), 51 deletions(-) create mode 100644 kk diff --git a/kk b/kk new file mode 100644 index 0000000..e69de29 diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index f8fd2ee..c2ef440 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -1,4 +1,5 @@ #include +#include #include #include "GridSearch.h" #include "Models.h" @@ -101,13 +102,20 @@ namespace platform { auto tasks = json::array(); auto grid = GridData(Paths::grid_input(config.model)); auto datasets = Datasets(false, Paths::datasets()); + auto all_datasets = datasets.getNames(); auto datasets_names = processDatasets(datasets); for (const auto& dataset : datasets_names) { for (const auto& seed : config.seeds) { auto combinations = grid.getGrid(dataset); for (int n_fold = 0; n_fold < config.n_folds; n_fold++) { + auto it = find(all_datasets.begin(), all_datasets.end(), dataset); + if (it == all_datasets.end()) { + throw std::invalid_argument("Dataset " + dataset + " not found"); + } + auto idx_dataset = std::distance(all_datasets.begin(), it); json task = { { "dataset", dataset }, + { "idx_dataset", idx_dataset}, { "seed", seed }, { "fold", n_fold} }; @@ -126,6 +134,9 @@ namespace platform { std::cout << "|" << std::endl << "|" << std::flush; return tasks; } + void process_task_mpi(struct ConfigMPI& config_mpi, int task, Task_Result* result) + { + } std::pair GridSearch::part_range_mpi(int n_tasks, int nprocs, int rank) { int assigned = 0; @@ -149,7 +160,48 @@ namespace platform { auto colors = { Colors::RED(), Colors::GREEN(), Colors::BLUE(), Colors::MAGENTA(), Colors::CYAN() }; return *(colors.begin() + rank % colors.size()); } + void producer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result) + { + Task_Result result; + int num_tasks = tasks.size(); + for (int i = 0; i < num_tasks; ++i) { + MPI_Status status; + MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + if (status.MPI_TAG == TAG_RESULT) { + //Store result + } + MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD); + } + // Send end message to all workers + for (int i = 0; i < config_mpi.n_procs; ++i) { + MPI_Status status; + MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + if (status.MPI_TAG == TAG_RESULT) { + //Store result + + } + MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD); + } + } + void consumer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result) + { + Task_Result result; + // Anounce to the producer + MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD); + int task; + while (true) { + MPI_Status status; + MPI_Recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + if (status.MPI_TAG == TAG_END) { + break; + } + // Process task + process_task_mpi(config_mpi, task, &result); + // Send result to producer + MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD); + } + } void GridSearch::go_producer_consumer(struct ConfigMPI& config_mpi) { /* @@ -182,13 +234,14 @@ namespace platform { // 0.1 Create the MPI result type // Task_Result result; + int tasks_size; MPI_Datatype MPI_Result; MPI_Datatype type[3] = { MPI_UNSIGNED, MPI_UNSIGNED, MPI_DOUBLE }; int blocklen[3] = { 1, 1, 1 }; MPI_Aint disp[3]; - disp[0] = offsetof(struct MPI_Result, idx_dataset); - disp[1] = offsetof(struct MPI_Result, idx_combination); - disp[2] = offsetof(struct MPI_Result, score); + disp[0] = offsetof(Task_Result, idx_dataset); + disp[1] = offsetof(Task_Result, idx_combination); + disp[2] = offsetof(Task_Result, score); MPI_Type_create_struct(3, blocklen, disp, type, &MPI_Result); MPI_Type_commit(&MPI_Result); // @@ -217,51 +270,9 @@ namespace platform { // 2. All Workers will receive the tasks and start the process // if (config_mpi.rank == config_mpi.manager) { - producer(tasks, &MPI_Result); + producer(tasks, config_mpi, MPI_Result); } else { - consumer(tasks, &MPI_Result); - } - } - void producer(json& tasks, MPI_Datatpe& MPI_Result) - { - Task_Result result; - int num_tasks = tasks.size(); - for (int i = 0; i < num_tasks; ++i) { - MPI_Status status; - MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - if (status.MPI_TAG == TAG_RESULT) { - //Store result - - } - MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD); - } - // Send end message to all workers - for (int i = 0; i < config_mpi.n_procs; ++i) { - MPI_Status status; - MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - if (status.MPI_TAG == TAG_RESULT) { - //Store result - - } - MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD); - } - } - void consumer(json& tasks, MPI_Datatpe& MPI_Result) - { - Task_Result result; - // Anounce to the producer - MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD); - int task; - while (true) { - MPI_Status status; - MPI_recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - if (status.MPI_TAG == TAG_END) { - break; - } - // Process task - process_task_mpi(config_mpi, task, datasets, results); - // Send result to producer - MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD); + consumer(tasks, config_mpi, MPI_Result); } } void GridSearch::go_mpi(struct ConfigMPI& config_mpi) diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 1a868a8..a9e2f6e 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -35,15 +35,16 @@ namespace platform { uint idx_combination; double score; } Task_Result; - const TAG_QUERY = 1; - const TAG_RESULT = 2; - const TAG_TASK = 3; - const TAG_END = 4; + const int TAG_QUERY = 1; + const int TAG_RESULT = 2; + const int TAG_TASK = 3; + const int TAG_END = 4; class GridSearch { public: explicit GridSearch(struct ConfigGrid& config); void go(); void go_mpi(struct ConfigMPI& config_mpi); + void go_producer_consumer(struct ConfigMPI& config_mpi); ~GridSearch() = default; json getResults(); static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }