Fix duration of task not set

This commit is contained in:
Ricardo Montañana Gómez 2023-12-16 19:31:45 +01:00
parent 49b26bd04b
commit 9b8db37a4b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 12 additions and 19 deletions

View File

@ -152,6 +152,8 @@ namespace platform {
void GridSearch::process_task_mpi(struct ConfigMPI& config_mpi, json& task, Datasets& datasets, json& results)
{
// Process the task and store the result in the results json
Timer timer;
timer.start();
auto grid = GridData(Paths::grid_input(config.model));
auto dataset = task["dataset"].get<std::string>();
auto seed = task["seed"].get<int>();
@ -227,6 +229,7 @@ namespace platform {
results[dataset][std::to_string(n_fold)]["score"] = best_fold_score;
results[dataset][std::to_string(n_fold)]["hyperparameters"] = best_fold_hyper;
results[dataset][std::to_string(n_fold)]["seed"] = seed;
results[dataset][std::to_string(n_fold)]["duration"] = timer.getDuration();
std::cout << get_color_rank(config_mpi.rank) << "*" << std::flush;
}
void GridSearch::go_mpi(struct ConfigMPI& config_mpi)
@ -322,19 +325,22 @@ namespace platform {
auto grid = GridData(Paths::grid_input(config.model));
for (auto& [dataset, folds] : total_results.items()) {
double best_score = 0.0;
double duration = 0.0;
json best_hyper;
for (auto& [fold, result] : folds.items()) {
duration += result["duration"].get<double>();
if (result["score"] > best_score) {
best_score = result["score"];
best_hyper = result["hyperparameters"];
}
}
auto timer = Timer();
json result = {
{ "score", best_score },
{ "hyperparameters", best_hyper },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) },
{ "duration", 0 }
{ "duration", timer.translate2String(duration) }
};
best_results[dataset] = result;
}

View File

@ -28,6 +28,10 @@ namespace platform {
std::string getDurationString(bool lapse = false)
{
double duration = lapse ? getLapse() : getDuration();
return translate2String(duration);
}
std::string translate2String(double duration)
{
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
std::stringstream ss;

View File

@ -133,30 +133,13 @@ void list_results(json& results, std::string& model)
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
<< " " << setw(8) << value["duration"].get<string>() << " " << setw(8) << setprecision(6)
<< fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
}
void initialize_mpi(struct platform::ConfigMPI& config)
{
// int provided;
// MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
// if (provided != MPI_THREAD_MULTIPLE) {
// std::cerr << "MPI_Init_thread returned " << provided << " instead of " << MPI_THREAD_MULTIPLE << std::endl;
// exit(1);
// }
// MPI_Init(nullptr, nullptr);
// int rank, size;
// MPI_Comm_rank(MPI_COMM_WORLD, &rank);
// MPI_Comm_size(MPI_COMM_WORLD, &size);
// config.mpi_rank = rank;
// config.mpi_size = size;
}
/*
* Main
*/