First compiling version

This commit is contained in:
Ricardo Montañana Gómez 2023-12-12 18:57:57 +01:00
parent db9e80a70e
commit dbf2f35502
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 84 additions and 45 deletions

View File

@ -38,6 +38,39 @@ namespace platform {
}
return json();
}
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
// Load datasets
auto datasets_names = datasets.getNames();
if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution:
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
}
// Remove datasets already processed
vector< string >::iterator it = datasets_names.begin();
while (it != datasets_names.end()) {
if (*it != config.continue_from) {
it = datasets_names.erase(it);
} else {
if (config.only)
++it;
else
break;
}
}
}
// Exclude datasets
for (const auto& name : config.excluded) {
auto dataset = name.get<std::string>();
auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
if (it == datasets_names.end()) {
throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
}
datasets_names.erase(it);
}
return datasets_names;
}
void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
{
int spaces = int(log(total) / log(10)) + 1;
@ -67,7 +100,7 @@ namespace platform {
{
auto result = json::array();
auto datasets = Datasets(false, Paths::datasets());
auto datasets_names = datasets.getNames();
auto datasets_names = processDatasets(datasets);
auto grid = GridData(Paths::grid_input(config.model));
for (const auto& dataset : datasets_names) {
for (const auto& seed : config.seeds) {
@ -103,6 +136,10 @@ namespace platform {
}
return { start, end };
}
void status(struct ConfigMPI& config_mpi, std::string status)
{
std::cout << "* (" << config_mpi.rank << "): " << status << std::endl;
}
void GridSearch::go_MPI(struct ConfigMPI& config_mpi)
{
/*
@ -148,12 +185,13 @@ namespace platform {
int num_tasks = tasks.size();
auto [start, end] = partRange(num_tasks, config_mpi.n_procs, config_mpi.rank);
// 2.2 Each worker will process the combinations and return the best score obtained
auto datasets = Datasets(config.discretize, Paths::datasets());
for (int i = start; i < end; ++i) {
auto task = tasks[i];
auto dataset = task["dataset"].get<std::string>();
auto seed = task["seed"].get<int>();
auto hyperparam_line = task["hyperparameters"];
auto datasets = Datasets(config.discretize, Paths::datasets());
status(config_mpi, "Processing dataset " + dataset + " with seed " + std::to_string(seed) + " and hyperparameters " + hyperparam_line.dump());
auto [X, y] = datasets.getTensors(dataset);
auto states = datasets.getStates(dataset);
auto features = datasets.getFeatures(dataset);
@ -167,20 +205,53 @@ namespace platform {
else
fold = new KFold(config.n_folds, y.size(0), seed);
for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, dataset);
clf->setHyperparameters(hyperparameters.get(dataset));
status(config_mpi, "Processing fold " + std::to_string(nfold + 1));
auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
auto X_train = X.index({ "...", train_t });
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test
}
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
auto num = 0;
json result_fold;
double hypScore = 0.0;
double bestHypScore = 0.0;
json bestHypHyperparameters;
Fold* nested_fold;
if (config.stratified)
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
else
nested_fold = new KFold(config.nested, y_train.size(0), seed);
for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
// Nested level fold
status(config_mpi, "Processing nested fold " + std::to_string(n_nested_fold + 1));
auto [train_nested, test_nested] = nested_fold->getFold(n_nested_fold);
auto train_nested_t = torch::tensor(train_nested);
auto test_nested_t = torch::tensor(test_nested);
auto X_nexted_train = X_train.index({ "...", train_nested_t });
auto y_nested_train = y_train.index({ train_nested_t });
auto X_nested_test = X_train.index({ "...", test_nested_t });
auto y_nested_test = y_train.index({ test_nested_t });
// Build Classifier with selected hyperparameters
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, dataset);
clf->setHyperparameters(hyperparameters.get(dataset));
// Train model
clf->fit(X_nexted_train, y_nested_train, features, className, states);
// Test model
hypScore += clf->score(X_nested_test, y_nested_test);
}
delete nested_fold;
hypScore /= config.nested;
if (hypScore > bestHypScore) {
bestHypScore = hypScore;
bestHypHyperparameters = hyperparam_line;
}
}
delete fold;
}
}
void GridSearch::go()
@ -391,39 +462,6 @@ namespace platform {
}
return { goatScore, goatHyperparameters };
}
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
// Load datasets
auto datasets_names = datasets.getNames();
if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution:
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
}
// Remove datasets already processed
vector< string >::iterator it = datasets_names.begin();
while (it != datasets_names.end()) {
if (*it != config.continue_from) {
it = datasets_names.erase(it);
} else {
if (config.only)
++it;
else
break;
}
}
}
// Exclude datasets
for (const auto& name : config.excluded) {
auto dataset = name.get<std::string>();
auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
if (it == datasets_names.end()) {
throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
}
datasets_names.erase(it);
}
return datasets_names;
}
json GridSearch::initializeResults()
{
// Load previous results

View File

@ -2,6 +2,7 @@
#define GRIDSEARCH_H
#include <string>
#include <map>
#include <mpi.h>
#include <nlohmann/json.hpp>
#include "Datasets.h"
#include "HyperParameters.h"

View File

@ -189,8 +189,8 @@ int main(int argc, char** argv)
auto excluded = program.get<std::string>("exclude");
config.excluded = json::parse(excluded);
if (program.get<bool>("mpi")) {
if (!compute) {
throw std::runtime_error("Cannot use --mpi without --compute");
if (!compute || config.nested == 0) {
throw std::runtime_error("Cannot use --mpi without --compute or without --nested");
}
}
}