Fix first mistakes in structure
This commit is contained in:
parent
702f086706
commit
21c4c6df51
@ -1,4 +1,5 @@
|
||||
#include <iostream>
|
||||
#include <cstddef>
|
||||
#include <torch/torch.h>
|
||||
#include "GridSearch.h"
|
||||
#include "Models.h"
|
||||
@ -101,13 +102,20 @@ namespace platform {
|
||||
auto tasks = json::array();
|
||||
auto grid = GridData(Paths::grid_input(config.model));
|
||||
auto datasets = Datasets(false, Paths::datasets());
|
||||
auto all_datasets = datasets.getNames();
|
||||
auto datasets_names = processDatasets(datasets);
|
||||
for (const auto& dataset : datasets_names) {
|
||||
for (const auto& seed : config.seeds) {
|
||||
auto combinations = grid.getGrid(dataset);
|
||||
for (int n_fold = 0; n_fold < config.n_folds; n_fold++) {
|
||||
auto it = find(all_datasets.begin(), all_datasets.end(), dataset);
|
||||
if (it == all_datasets.end()) {
|
||||
throw std::invalid_argument("Dataset " + dataset + " not found");
|
||||
}
|
||||
auto idx_dataset = std::distance(all_datasets.begin(), it);
|
||||
json task = {
|
||||
{ "dataset", dataset },
|
||||
{ "idx_dataset", idx_dataset},
|
||||
{ "seed", seed },
|
||||
{ "fold", n_fold}
|
||||
};
|
||||
@ -126,6 +134,9 @@ namespace platform {
|
||||
std::cout << "|" << std::endl << "|" << std::flush;
|
||||
return tasks;
|
||||
}
|
||||
void process_task_mpi(struct ConfigMPI& config_mpi, int task, Task_Result* result)
|
||||
{
|
||||
}
|
||||
std::pair<int, int> GridSearch::part_range_mpi(int n_tasks, int nprocs, int rank)
|
||||
{
|
||||
int assigned = 0;
|
||||
@ -149,7 +160,48 @@ namespace platform {
|
||||
auto colors = { Colors::RED(), Colors::GREEN(), Colors::BLUE(), Colors::MAGENTA(), Colors::CYAN() };
|
||||
return *(colors.begin() + rank % colors.size());
|
||||
}
|
||||
void producer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result)
|
||||
{
|
||||
Task_Result result;
|
||||
int num_tasks = tasks.size();
|
||||
for (int i = 0; i < num_tasks; ++i) {
|
||||
MPI_Status status;
|
||||
MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_RESULT) {
|
||||
//Store result
|
||||
|
||||
}
|
||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD);
|
||||
}
|
||||
// Send end message to all workers
|
||||
for (int i = 0; i < config_mpi.n_procs; ++i) {
|
||||
MPI_Status status;
|
||||
MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_RESULT) {
|
||||
//Store result
|
||||
|
||||
}
|
||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD);
|
||||
}
|
||||
}
|
||||
void consumer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result)
|
||||
{
|
||||
Task_Result result;
|
||||
// Anounce to the producer
|
||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD);
|
||||
int task;
|
||||
while (true) {
|
||||
MPI_Status status;
|
||||
MPI_Recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_END) {
|
||||
break;
|
||||
}
|
||||
// Process task
|
||||
process_task_mpi(config_mpi, task, &result);
|
||||
// Send result to producer
|
||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD);
|
||||
}
|
||||
}
|
||||
void GridSearch::go_producer_consumer(struct ConfigMPI& config_mpi)
|
||||
{
|
||||
/*
|
||||
@ -182,13 +234,14 @@ namespace platform {
|
||||
// 0.1 Create the MPI result type
|
||||
//
|
||||
Task_Result result;
|
||||
int tasks_size;
|
||||
MPI_Datatype MPI_Result;
|
||||
MPI_Datatype type[3] = { MPI_UNSIGNED, MPI_UNSIGNED, MPI_DOUBLE };
|
||||
int blocklen[3] = { 1, 1, 1 };
|
||||
MPI_Aint disp[3];
|
||||
disp[0] = offsetof(struct MPI_Result, idx_dataset);
|
||||
disp[1] = offsetof(struct MPI_Result, idx_combination);
|
||||
disp[2] = offsetof(struct MPI_Result, score);
|
||||
disp[0] = offsetof(Task_Result, idx_dataset);
|
||||
disp[1] = offsetof(Task_Result, idx_combination);
|
||||
disp[2] = offsetof(Task_Result, score);
|
||||
MPI_Type_create_struct(3, blocklen, disp, type, &MPI_Result);
|
||||
MPI_Type_commit(&MPI_Result);
|
||||
//
|
||||
@ -217,51 +270,9 @@ namespace platform {
|
||||
// 2. All Workers will receive the tasks and start the process
|
||||
//
|
||||
if (config_mpi.rank == config_mpi.manager) {
|
||||
producer(tasks, &MPI_Result);
|
||||
producer(tasks, config_mpi, MPI_Result);
|
||||
} else {
|
||||
consumer(tasks, &MPI_Result);
|
||||
}
|
||||
}
|
||||
void producer(json& tasks, MPI_Datatpe& MPI_Result)
|
||||
{
|
||||
Task_Result result;
|
||||
int num_tasks = tasks.size();
|
||||
for (int i = 0; i < num_tasks; ++i) {
|
||||
MPI_Status status;
|
||||
MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_RESULT) {
|
||||
//Store result
|
||||
|
||||
}
|
||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD);
|
||||
}
|
||||
// Send end message to all workers
|
||||
for (int i = 0; i < config_mpi.n_procs; ++i) {
|
||||
MPI_Status status;
|
||||
MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_RESULT) {
|
||||
//Store result
|
||||
|
||||
}
|
||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD);
|
||||
}
|
||||
}
|
||||
void consumer(json& tasks, MPI_Datatpe& MPI_Result)
|
||||
{
|
||||
Task_Result result;
|
||||
// Anounce to the producer
|
||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD);
|
||||
int task;
|
||||
while (true) {
|
||||
MPI_Status status;
|
||||
MPI_recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||
if (status.MPI_TAG == TAG_END) {
|
||||
break;
|
||||
}
|
||||
// Process task
|
||||
process_task_mpi(config_mpi, task, datasets, results);
|
||||
// Send result to producer
|
||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD);
|
||||
consumer(tasks, config_mpi, MPI_Result);
|
||||
}
|
||||
}
|
||||
void GridSearch::go_mpi(struct ConfigMPI& config_mpi)
|
||||
|
@ -35,15 +35,16 @@ namespace platform {
|
||||
uint idx_combination;
|
||||
double score;
|
||||
} Task_Result;
|
||||
const TAG_QUERY = 1;
|
||||
const TAG_RESULT = 2;
|
||||
const TAG_TASK = 3;
|
||||
const TAG_END = 4;
|
||||
const int TAG_QUERY = 1;
|
||||
const int TAG_RESULT = 2;
|
||||
const int TAG_TASK = 3;
|
||||
const int TAG_END = 4;
|
||||
class GridSearch {
|
||||
public:
|
||||
explicit GridSearch(struct ConfigGrid& config);
|
||||
void go();
|
||||
void go_mpi(struct ConfigMPI& config_mpi);
|
||||
void go_producer_consumer(struct ConfigMPI& config_mpi);
|
||||
~GridSearch() = default;
|
||||
json getResults();
|
||||
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
|
||||
|
Loading…
Reference in New Issue
Block a user