Fix first mistakes in structure
This commit is contained in:
parent
702f086706
commit
21c4c6df51
@ -1,4 +1,5 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <cstddef>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include "GridSearch.h"
|
#include "GridSearch.h"
|
||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
@ -101,13 +102,20 @@ namespace platform {
|
|||||||
auto tasks = json::array();
|
auto tasks = json::array();
|
||||||
auto grid = GridData(Paths::grid_input(config.model));
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
auto datasets = Datasets(false, Paths::datasets());
|
auto datasets = Datasets(false, Paths::datasets());
|
||||||
|
auto all_datasets = datasets.getNames();
|
||||||
auto datasets_names = processDatasets(datasets);
|
auto datasets_names = processDatasets(datasets);
|
||||||
for (const auto& dataset : datasets_names) {
|
for (const auto& dataset : datasets_names) {
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
auto combinations = grid.getGrid(dataset);
|
auto combinations = grid.getGrid(dataset);
|
||||||
for (int n_fold = 0; n_fold < config.n_folds; n_fold++) {
|
for (int n_fold = 0; n_fold < config.n_folds; n_fold++) {
|
||||||
|
auto it = find(all_datasets.begin(), all_datasets.end(), dataset);
|
||||||
|
if (it == all_datasets.end()) {
|
||||||
|
throw std::invalid_argument("Dataset " + dataset + " not found");
|
||||||
|
}
|
||||||
|
auto idx_dataset = std::distance(all_datasets.begin(), it);
|
||||||
json task = {
|
json task = {
|
||||||
{ "dataset", dataset },
|
{ "dataset", dataset },
|
||||||
|
{ "idx_dataset", idx_dataset},
|
||||||
{ "seed", seed },
|
{ "seed", seed },
|
||||||
{ "fold", n_fold}
|
{ "fold", n_fold}
|
||||||
};
|
};
|
||||||
@ -126,6 +134,9 @@ namespace platform {
|
|||||||
std::cout << "|" << std::endl << "|" << std::flush;
|
std::cout << "|" << std::endl << "|" << std::flush;
|
||||||
return tasks;
|
return tasks;
|
||||||
}
|
}
|
||||||
|
void process_task_mpi(struct ConfigMPI& config_mpi, int task, Task_Result* result)
|
||||||
|
{
|
||||||
|
}
|
||||||
std::pair<int, int> GridSearch::part_range_mpi(int n_tasks, int nprocs, int rank)
|
std::pair<int, int> GridSearch::part_range_mpi(int n_tasks, int nprocs, int rank)
|
||||||
{
|
{
|
||||||
int assigned = 0;
|
int assigned = 0;
|
||||||
@ -149,7 +160,48 @@ namespace platform {
|
|||||||
auto colors = { Colors::RED(), Colors::GREEN(), Colors::BLUE(), Colors::MAGENTA(), Colors::CYAN() };
|
auto colors = { Colors::RED(), Colors::GREEN(), Colors::BLUE(), Colors::MAGENTA(), Colors::CYAN() };
|
||||||
return *(colors.begin() + rank % colors.size());
|
return *(colors.begin() + rank % colors.size());
|
||||||
}
|
}
|
||||||
|
void producer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result)
|
||||||
|
{
|
||||||
|
Task_Result result;
|
||||||
|
int num_tasks = tasks.size();
|
||||||
|
for (int i = 0; i < num_tasks; ++i) {
|
||||||
|
MPI_Status status;
|
||||||
|
MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||||
|
if (status.MPI_TAG == TAG_RESULT) {
|
||||||
|
//Store result
|
||||||
|
|
||||||
|
}
|
||||||
|
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD);
|
||||||
|
}
|
||||||
|
// Send end message to all workers
|
||||||
|
for (int i = 0; i < config_mpi.n_procs; ++i) {
|
||||||
|
MPI_Status status;
|
||||||
|
MPI_Recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||||
|
if (status.MPI_TAG == TAG_RESULT) {
|
||||||
|
//Store result
|
||||||
|
|
||||||
|
}
|
||||||
|
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void consumer(json& tasks, struct ConfigMPI& config_mpi, MPI_Datatype& MPI_Result)
|
||||||
|
{
|
||||||
|
Task_Result result;
|
||||||
|
// Anounce to the producer
|
||||||
|
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD);
|
||||||
|
int task;
|
||||||
|
while (true) {
|
||||||
|
MPI_Status status;
|
||||||
|
MPI_Recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
||||||
|
if (status.MPI_TAG == TAG_END) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Process task
|
||||||
|
process_task_mpi(config_mpi, task, &result);
|
||||||
|
// Send result to producer
|
||||||
|
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD);
|
||||||
|
}
|
||||||
|
}
|
||||||
void GridSearch::go_producer_consumer(struct ConfigMPI& config_mpi)
|
void GridSearch::go_producer_consumer(struct ConfigMPI& config_mpi)
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
@ -182,13 +234,14 @@ namespace platform {
|
|||||||
// 0.1 Create the MPI result type
|
// 0.1 Create the MPI result type
|
||||||
//
|
//
|
||||||
Task_Result result;
|
Task_Result result;
|
||||||
|
int tasks_size;
|
||||||
MPI_Datatype MPI_Result;
|
MPI_Datatype MPI_Result;
|
||||||
MPI_Datatype type[3] = { MPI_UNSIGNED, MPI_UNSIGNED, MPI_DOUBLE };
|
MPI_Datatype type[3] = { MPI_UNSIGNED, MPI_UNSIGNED, MPI_DOUBLE };
|
||||||
int blocklen[3] = { 1, 1, 1 };
|
int blocklen[3] = { 1, 1, 1 };
|
||||||
MPI_Aint disp[3];
|
MPI_Aint disp[3];
|
||||||
disp[0] = offsetof(struct MPI_Result, idx_dataset);
|
disp[0] = offsetof(Task_Result, idx_dataset);
|
||||||
disp[1] = offsetof(struct MPI_Result, idx_combination);
|
disp[1] = offsetof(Task_Result, idx_combination);
|
||||||
disp[2] = offsetof(struct MPI_Result, score);
|
disp[2] = offsetof(Task_Result, score);
|
||||||
MPI_Type_create_struct(3, blocklen, disp, type, &MPI_Result);
|
MPI_Type_create_struct(3, blocklen, disp, type, &MPI_Result);
|
||||||
MPI_Type_commit(&MPI_Result);
|
MPI_Type_commit(&MPI_Result);
|
||||||
//
|
//
|
||||||
@ -217,51 +270,9 @@ namespace platform {
|
|||||||
// 2. All Workers will receive the tasks and start the process
|
// 2. All Workers will receive the tasks and start the process
|
||||||
//
|
//
|
||||||
if (config_mpi.rank == config_mpi.manager) {
|
if (config_mpi.rank == config_mpi.manager) {
|
||||||
producer(tasks, &MPI_Result);
|
producer(tasks, config_mpi, MPI_Result);
|
||||||
} else {
|
} else {
|
||||||
consumer(tasks, &MPI_Result);
|
consumer(tasks, config_mpi, MPI_Result);
|
||||||
}
|
|
||||||
}
|
|
||||||
void producer(json& tasks, MPI_Datatpe& MPI_Result)
|
|
||||||
{
|
|
||||||
Task_Result result;
|
|
||||||
int num_tasks = tasks.size();
|
|
||||||
for (int i = 0; i < num_tasks; ++i) {
|
|
||||||
MPI_Status status;
|
|
||||||
MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
|
||||||
if (status.MPI_TAG == TAG_RESULT) {
|
|
||||||
//Store result
|
|
||||||
|
|
||||||
}
|
|
||||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_TASK, MPI_COMM_WORLD);
|
|
||||||
}
|
|
||||||
// Send end message to all workers
|
|
||||||
for (int i = 0; i < config_mpi.n_procs; ++i) {
|
|
||||||
MPI_Status status;
|
|
||||||
MPI_recv(&result, 1, MPI_Result, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
|
||||||
if (status.MPI_TAG == TAG_RESULT) {
|
|
||||||
//Store result
|
|
||||||
|
|
||||||
}
|
|
||||||
MPI_Send(&i, 1, MPI_INT, status.MPI_SOURCE, TAG_END, MPI_COMM_WORLD);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void consumer(json& tasks, MPI_Datatpe& MPI_Result)
|
|
||||||
{
|
|
||||||
Task_Result result;
|
|
||||||
// Anounce to the producer
|
|
||||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_QUERY, MPI_COMM_WORLD);
|
|
||||||
int task;
|
|
||||||
while (true) {
|
|
||||||
MPI_Status status;
|
|
||||||
MPI_recv(&task, 1, MPI_INT, config_mpi.manager, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
|
|
||||||
if (status.MPI_TAG == TAG_END) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// Process task
|
|
||||||
process_task_mpi(config_mpi, task, datasets, results);
|
|
||||||
// Send result to producer
|
|
||||||
MPI_Send(&result, 1, MPI_Result, config_mpi.manager, TAG_RESULT, MPI_COMM_WORLD);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void GridSearch::go_mpi(struct ConfigMPI& config_mpi)
|
void GridSearch::go_mpi(struct ConfigMPI& config_mpi)
|
||||||
|
@ -35,15 +35,16 @@ namespace platform {
|
|||||||
uint idx_combination;
|
uint idx_combination;
|
||||||
double score;
|
double score;
|
||||||
} Task_Result;
|
} Task_Result;
|
||||||
const TAG_QUERY = 1;
|
const int TAG_QUERY = 1;
|
||||||
const TAG_RESULT = 2;
|
const int TAG_RESULT = 2;
|
||||||
const TAG_TASK = 3;
|
const int TAG_TASK = 3;
|
||||||
const TAG_END = 4;
|
const int TAG_END = 4;
|
||||||
class GridSearch {
|
class GridSearch {
|
||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void go();
|
||||||
void go_mpi(struct ConfigMPI& config_mpi);
|
void go_mpi(struct ConfigMPI& config_mpi);
|
||||||
|
void go_producer_consumer(struct ConfigMPI& config_mpi);
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
json getResults();
|
json getResults();
|
||||||
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
|
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
|
||||||
|
Loading…
Reference in New Issue
Block a user