Fix GridBase to eliminate uneeded GridData

This commit is contained in:
2025-03-20 15:54:13 +01:00
parent c9ab88e475
commit facf6f6ddd
5 changed files with 7 additions and 18 deletions

View File

@@ -1,11 +1,9 @@
#include <iostream> #include <iostream>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <map> #include <map>
#include <tuple>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <mpi.h> #include <mpi.h>
#include "main/Models.h" #include "main/Models.h"
#include "main/modelRegister.h"
#include "main/ArgumentsExperiment.h" #include "main/ArgumentsExperiment.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/Timer.hpp" #include "common/Timer.hpp"

View File

@@ -2,9 +2,10 @@
#include <cstddef> #include <cstddef>
#include "common/DotEnv.h" #include "common/DotEnv.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/DotEnv.h" #include "common/Colors.h"
#include "GridBase.h" #include "GridBase.h"
namespace platform { namespace platform {
GridBase::GridBase(struct ConfigGrid& config) GridBase::GridBase(struct ConfigGrid& config)
@@ -63,13 +64,11 @@ namespace platform {
* This way a task consists in process all combinations of hyperparameters for a dataset, seed and fold * This way a task consists in process all combinations of hyperparameters for a dataset, seed and fold
*/ */
auto tasks = json::array(); auto tasks = json::array();
auto grid = GridData(Paths::grid_input(config.model));
auto all_datasets = datasets.getNames(); auto all_datasets = datasets.getNames();
auto datasets_names = filterDatasets(datasets); auto datasets_names = filterDatasets(datasets);
for (int idx_dataset = 0; idx_dataset < datasets_names.size(); ++idx_dataset) { for (int idx_dataset = 0; idx_dataset < datasets_names.size(); ++idx_dataset) {
auto dataset = datasets_names[idx_dataset]; auto dataset = datasets_names[idx_dataset];
for (const auto& seed : config.seeds) { for (const auto& seed : config.seeds) {
auto combinations = grid.getGrid(dataset);
for (int n_fold = 0; n_fold < config.n_folds; n_fold++) { for (int n_fold = 0; n_fold < config.n_folds; n_fold++) {
json task = { json task = {
{ "dataset", dataset }, { "dataset", dataset },

View File

@@ -1,16 +1,12 @@
#ifndef GRIDBASE_H #ifndef GRIDBASE_H
#define GRIDBASE_H #define GRIDBASE_H
#include <string> #include <string>
#include <map>
#include <mpi.h> #include <mpi.h>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "common/Datasets.h" #include "common/Datasets.h"
#include "common/Timer.hpp" #include "common/Timer.hpp"
#include "common/Colors.h"
#include "main/HyperParameters.h" #include "main/HyperParameters.h"
#include "GridData.h"
#include "GridConfig.h" #include "GridConfig.h"
#include "bayesnet/network/Network.h"
namespace platform { namespace platform {

View File

@@ -1,18 +1,14 @@
#ifndef GRIDEXPERIMENT_H #ifndef GRIDEXPERIMENT_H
#define GRIDEXPERIMENT_H #define GRIDEXPERIMENT_H
#include <string> #include <string>
#include <map>
#include <mpi.h> #include <mpi.h>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "common/Datasets.h" #include "common/Datasets.h"
#include "common/DotEnv.h"
#include "main/Experiment.h" #include "main/Experiment.h"
#include "main/HyperParameters.h" #include "main/HyperParameters.h"
#include "main/ArgumentsExperiment.h" #include "main/ArgumentsExperiment.h"
#include "GridData.h"
#include "GridBase.h" #include "GridBase.h"
#include "bayesnet/network/Network.h"
namespace platform { namespace platform {

View File

@@ -1,10 +1,10 @@
#include <iostream> #include <iostream>
#include <cstddef>
#include <torch/torch.h> #include <torch/torch.h>
#include <folding.hpp> #include <folding.hpp>
#include "main/Models.h" #include "main/Models.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/Utils.h" #include "common/Utils.h"
#include "common/Colors.h"
#include "GridSearch.h" #include "GridSearch.h"
namespace platform { namespace platform {