From dbf2f355029b395eb961a265df145f27a37a7fe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 12 Dec 2023 18:57:57 +0100 Subject: [PATCH] First compiling version --- src/Platform/GridSearch.cc | 124 ++++++++++++++++++++++++------------- src/Platform/GridSearch.h | 1 + src/Platform/b_grid.cc | 4 +- 3 files changed, 84 insertions(+), 45 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 7b629ba..4f1f6e7 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -38,6 +38,39 @@ namespace platform { } return json(); } + vector GridSearch::processDatasets(Datasets& datasets) + { + // Load datasets + auto datasets_names = datasets.getNames(); + if (config.continue_from != NO_CONTINUE()) { + // Continue previous execution: + if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { + throw std::invalid_argument("Dataset " + config.continue_from + " not found"); + } + // Remove datasets already processed + vector< string >::iterator it = datasets_names.begin(); + while (it != datasets_names.end()) { + if (*it != config.continue_from) { + it = datasets_names.erase(it); + } else { + if (config.only) + ++it; + else + break; + } + } + } + // Exclude datasets + for (const auto& name : config.excluded) { + auto dataset = name.get(); + auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset); + if (it == datasets_names.end()) { + throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!"); + } + datasets_names.erase(it); + } + return datasets_names; + } void showProgressComb(const int num, const int n_folds, const int total, const std::string& color) { int spaces = int(log(total) / log(10)) + 1; @@ -67,7 +100,7 @@ namespace platform { { auto result = json::array(); auto datasets = Datasets(false, Paths::datasets()); - auto datasets_names = datasets.getNames(); + auto datasets_names = processDatasets(datasets); auto grid = GridData(Paths::grid_input(config.model)); for (const auto& dataset : datasets_names) { for (const auto& seed : config.seeds) { @@ -103,6 +136,10 @@ namespace platform { } return { start, end }; } + void status(struct ConfigMPI& config_mpi, std::string status) + { + std::cout << "* (" << config_mpi.rank << "): " << status << std::endl; + } void GridSearch::go_MPI(struct ConfigMPI& config_mpi) { /* @@ -148,12 +185,13 @@ namespace platform { int num_tasks = tasks.size(); auto [start, end] = partRange(num_tasks, config_mpi.n_procs, config_mpi.rank); // 2.2 Each worker will process the combinations and return the best score obtained + auto datasets = Datasets(config.discretize, Paths::datasets()); for (int i = start; i < end; ++i) { auto task = tasks[i]; auto dataset = task["dataset"].get(); auto seed = task["seed"].get(); auto hyperparam_line = task["hyperparameters"]; - auto datasets = Datasets(config.discretize, Paths::datasets()); + status(config_mpi, "Processing dataset " + dataset + " with seed " + std::to_string(seed) + " and hyperparameters " + hyperparam_line.dump()); auto [X, y] = datasets.getTensors(dataset); auto states = datasets.getStates(dataset); auto features = datasets.getFeatures(dataset); @@ -167,20 +205,53 @@ namespace platform { else fold = new KFold(config.n_folds, y.size(0), seed); for (int nfold = 0; nfold < config.n_folds; nfold++) { - - auto clf = Models::instance()->create(config.model); - auto valid = clf->getValidHyperparameters(); - hyperparameters.check(valid, dataset); - clf->setHyperparameters(hyperparameters.get(dataset)); + status(config_mpi, "Processing fold " + std::to_string(nfold + 1)); auto [train, test] = fold->getFold(nfold); auto train_t = torch::tensor(train); auto test_t = torch::tensor(test); auto X_train = X.index({ "...", train_t }); auto y_train = y.index({ train_t }); - auto X_test = X.index({ "...", test - } - + auto X_test = X.index({ "...", test_t }); + auto y_test = y.index({ test_t }); + auto num = 0; + json result_fold; + double hypScore = 0.0; + double bestHypScore = 0.0; + json bestHypHyperparameters; + Fold* nested_fold; + if (config.stratified) + nested_fold = new StratifiedKFold(config.nested, y_train, seed); + else + nested_fold = new KFold(config.nested, y_train.size(0), seed); + for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) { + // Nested level fold + status(config_mpi, "Processing nested fold " + std::to_string(n_nested_fold + 1)); + auto [train_nested, test_nested] = nested_fold->getFold(n_nested_fold); + auto train_nested_t = torch::tensor(train_nested); + auto test_nested_t = torch::tensor(test_nested); + auto X_nexted_train = X_train.index({ "...", train_nested_t }); + auto y_nested_train = y_train.index({ train_nested_t }); + auto X_nested_test = X_train.index({ "...", test_nested_t }); + auto y_nested_test = y_train.index({ test_nested_t }); + // Build Classifier with selected hyperparameters + auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); + auto clf = Models::instance()->create(config.model); + auto valid = clf->getValidHyperparameters(); + hyperparameters.check(valid, dataset); + clf->setHyperparameters(hyperparameters.get(dataset)); + // Train model + clf->fit(X_nexted_train, y_nested_train, features, className, states); + // Test model + hypScore += clf->score(X_nested_test, y_nested_test); + } + delete nested_fold; + hypScore /= config.nested; + if (hypScore > bestHypScore) { + bestHypScore = hypScore; + bestHypHyperparameters = hyperparam_line; + } } + delete fold; } } void GridSearch::go() @@ -391,39 +462,6 @@ namespace platform { } return { goatScore, goatHyperparameters }; } - vector GridSearch::processDatasets(Datasets& datasets) - { - // Load datasets - auto datasets_names = datasets.getNames(); - if (config.continue_from != NO_CONTINUE()) { - // Continue previous execution: - if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { - throw std::invalid_argument("Dataset " + config.continue_from + " not found"); - } - // Remove datasets already processed - vector< string >::iterator it = datasets_names.begin(); - while (it != datasets_names.end()) { - if (*it != config.continue_from) { - it = datasets_names.erase(it); - } else { - if (config.only) - ++it; - else - break; - } - } - } - // Exclude datasets - for (const auto& name : config.excluded) { - auto dataset = name.get(); - auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset); - if (it == datasets_names.end()) { - throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!"); - } - datasets_names.erase(it); - } - return datasets_names; - } json GridSearch::initializeResults() { // Load previous results diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 330696d..4c757fa 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -2,6 +2,7 @@ #define GRIDSEARCH_H #include #include +#include #include #include "Datasets.h" #include "HyperParameters.h" diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index dde5d14..e5bcd03 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -189,8 +189,8 @@ int main(int argc, char** argv) auto excluded = program.get("exclude"); config.excluded = json::parse(excluded); if (program.get("mpi")) { - if (!compute) { - throw std::runtime_error("Cannot use --mpi without --compute"); + if (!compute || config.nested == 0) { + throw std::runtime_error("Cannot use --mpi without --compute or without --nested"); } } }