Fix GridBase to eliminate uneeded GridData

This commit is contained in:
2025-03-20 15:54:13 +01:00
parent c9ab88e475
commit facf6f6ddd
5 changed files with 7 additions and 18 deletions

View File

@@ -1,11 +1,9 @@
#include <iostream> #include <iostream>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <map> #include <map>
#include <tuple>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <mpi.h> #include <mpi.h>
#include "main/Models.h" #include "main/Models.h"
#include "main/modelRegister.h"
#include "main/ArgumentsExperiment.h" #include "main/ArgumentsExperiment.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/Timer.hpp" #include "common/Timer.hpp"

View File

@@ -2,9 +2,10 @@
#include <cstddef> #include <cstddef>
#include "common/DotEnv.h" #include "common/DotEnv.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/DotEnv.h" #include "common/Colors.h"
#include "GridBase.h" #include "GridBase.h"
namespace platform { namespace platform {
GridBase::GridBase(struct ConfigGrid& config) GridBase::GridBase(struct ConfigGrid& config)
@@ -63,13 +64,11 @@ namespace platform {
* This way a task consists in process all combinations of hyperparameters for a dataset, seed and fold * This way a task consists in process all combinations of hyperparameters for a dataset, seed and fold
*/ */
auto tasks = json::array(); auto tasks = json::array();
auto grid = GridData(Paths::grid_input(config.model));
auto all_datasets = datasets.getNames(); auto all_datasets = datasets.getNames();
auto datasets_names = filterDatasets(datasets); auto datasets_names = filterDatasets(datasets);
for (int idx_dataset = 0; idx_dataset < datasets_names.size(); ++idx_dataset) { for (int idx_dataset = 0; idx_dataset < datasets_names.size(); ++idx_dataset) {
auto dataset = datasets_names[idx_dataset]; auto dataset = datasets_names[idx_dataset];
for (const auto& seed : config.seeds) { for (const auto& seed : config.seeds) {
auto combinations = grid.getGrid(dataset);
for (int n_fold = 0; n_fold < config.n_folds; n_fold++) { for (int n_fold = 0; n_fold < config.n_folds; n_fold++) {
json task = { json task = {
{ "dataset", dataset }, { "dataset", dataset },
@@ -312,4 +311,4 @@ namespace platform {
} }
} }
} }

View File

@@ -1,16 +1,12 @@
#ifndef GRIDBASE_H #ifndef GRIDBASE_H
#define GRIDBASE_H #define GRIDBASE_H
#include <string> #include <string>
#include <map>
#include <mpi.h> #include <mpi.h>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "common/Datasets.h" #include "common/Datasets.h"
#include "common/Timer.hpp" #include "common/Timer.hpp"
#include "common/Colors.h"
#include "main/HyperParameters.h" #include "main/HyperParameters.h"
#include "GridData.h"
#include "GridConfig.h" #include "GridConfig.h"
#include "bayesnet/network/Network.h"
namespace platform { namespace platform {
@@ -40,4 +36,4 @@ namespace platform {
bayesnet::Smoothing_t smooth_type{ bayesnet::Smoothing_t::NONE }; bayesnet::Smoothing_t smooth_type{ bayesnet::Smoothing_t::NONE };
}; };
} /* namespace platform */ } /* namespace platform */
#endif #endif

View File

@@ -1,18 +1,14 @@
#ifndef GRIDEXPERIMENT_H #ifndef GRIDEXPERIMENT_H
#define GRIDEXPERIMENT_H #define GRIDEXPERIMENT_H
#include <string> #include <string>
#include <map>
#include <mpi.h> #include <mpi.h>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "common/Datasets.h" #include "common/Datasets.h"
#include "common/DotEnv.h"
#include "main/Experiment.h" #include "main/Experiment.h"
#include "main/HyperParameters.h" #include "main/HyperParameters.h"
#include "main/ArgumentsExperiment.h" #include "main/ArgumentsExperiment.h"
#include "GridData.h"
#include "GridBase.h" #include "GridBase.h"
#include "bayesnet/network/Network.h"
namespace platform { namespace platform {
@@ -39,4 +35,4 @@ namespace platform {
void consumer_go(struct ConfigGrid& config, struct ConfigMPI& config_mpi, json& tasks, int n_task, Datasets& datasets, Task_Result* result); void consumer_go(struct ConfigGrid& config, struct ConfigMPI& config_mpi, json& tasks, int n_task, Datasets& datasets, Task_Result* result);
}; };
} /* namespace platform */ } /* namespace platform */
#endif #endif

View File

@@ -1,10 +1,10 @@
#include <iostream> #include <iostream>
#include <cstddef>
#include <torch/torch.h> #include <torch/torch.h>
#include <folding.hpp> #include <folding.hpp>
#include "main/Models.h" #include "main/Models.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/Utils.h" #include "common/Utils.h"
#include "common/Colors.h"
#include "GridSearch.h" #include "GridSearch.h"
namespace platform { namespace platform {
@@ -256,4 +256,4 @@ namespace platform {
// //
std::cout << get_color_rank(config_mpi.rank) << std::flush; std::cout << get_color_rank(config_mpi.rank) << std::flush;
} }
} /* namespace platform */ } /* namespace platform */