First compiling version
parent db9e80a70e
commit dbf2f35502
@@ -38,6 +38,39 @@ namespace platform {
        }
        return json();
    }
    vector<std::string> GridSearch::processDatasets(Datasets& datasets)
    {
        // Load datasets
        auto datasets_names = datasets.getNames();
        if (config.continue_from != NO_CONTINUE()) {
            // Continue previous execution:
            if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
                throw std::invalid_argument("Dataset " + config.continue_from + " not found");
            }
            // Remove datasets already processed
            vector< string >::iterator it = datasets_names.begin();
            while (it != datasets_names.end()) {
                if (*it != config.continue_from) {
                    it = datasets_names.erase(it);
                } else {
                    if (config.only)
                        ++it;
                    else
                        break;
                }
            }
        }
        // Exclude datasets
        for (const auto& name : config.excluded) {
            auto dataset = name.get<std::string>();
            auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
            if (it == datasets_names.end()) {
                throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
            }
            datasets_names.erase(it);
        }
        return datasets_names;
    }
    void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
    {
        int spaces = int(log(total) / log(10)) + 1;
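For reference, the filtering that processDatasets performs can be exercised in isolation. The sketch below mirrors the continue_from / only / excluded handling on a plain std::vector<std::string>; the function name, its parameters, and the empty-string stand-in for NO_CONTINUE() are illustrative assumptions, not project code.

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical standalone version of the filtering above.
std::vector<std::string> filterNames(std::vector<std::string> names, const std::string& continue_from,
                                     bool only, const std::vector<std::string>& excluded)
{
    if (!continue_from.empty()) {  // empty string stands in for NO_CONTINUE()
        if (std::find(names.begin(), names.end(), continue_from) == names.end())
            throw std::invalid_argument("Dataset " + continue_from + " not found");
        auto it = names.begin();
        while (it != names.end()) {
            if (*it != continue_from) {
                it = names.erase(it);   // drop names other than the continuation point
            } else {
                if (only) ++it;         // keep erasing the rest: only this dataset survives
                else break;             // keep this dataset and everything after it
            }
        }
    }
    for (const auto& dataset : excluded) {
        auto it = std::find(names.begin(), names.end(), dataset);
        if (it == names.end())
            throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
        names.erase(it);
    }
    return names;
}

With names = {"a", "b", "c", "d"}, continue_from = "c" and only = false, the result is {"c", "d"}; with only = true it is just {"c"}.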
@@ -67,7 +100,7 @@ namespace platform {
    {
        auto result = json::array();
        auto datasets = Datasets(false, Paths::datasets());
        auto datasets_names = datasets.getNames();
        auto datasets_names = processDatasets(datasets);
        auto grid = GridData(Paths::grid_input(config.model));
        for (const auto& dataset : datasets_names) {
            for (const auto& seed : config.seeds) {
@@ -103,6 +136,10 @@ namespace platform {
        }
        return { start, end };
    }
    void status(struct ConfigMPI& config_mpi, std::string status)
    {
        std::cout << "* (" << config_mpi.rank << "): " << status << std::endl;
    }
    void GridSearch::go_MPI(struct ConfigMPI& config_mpi)
    {
        /*
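The return { start, end }; just above is the tail of the range-splitting helper (partRange, called in the next hunk), whose body is not part of this diff. A plausible block partition of num_tasks across n_procs MPI ranks, shown only as an assumption about what such a helper could look like:

#include <algorithm>
#include <utility>

// Hypothetical block partition: rank r gets the half-open range [start, end).
std::pair<int, int> partRange(int n_tasks, int n_procs, int rank)
{
    int chunk = n_tasks / n_procs;       // minimum number of tasks per rank
    int remainder = n_tasks % n_procs;   // the first `remainder` ranks take one extra task
    int start = rank * chunk + std::min(rank, remainder);
    int end = start + chunk + (rank < remainder ? 1 : 0);
    return { start, end };
}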
@@ -148,12 +185,13 @@ namespace platform {
        int num_tasks = tasks.size();
        auto [start, end] = partRange(num_tasks, config_mpi.n_procs, config_mpi.rank);
        // 2.2 Each worker will process the combinations and return the best score obtained
        auto datasets = Datasets(config.discretize, Paths::datasets());
        for (int i = start; i < end; ++i) {
            auto task = tasks[i];
            auto dataset = task["dataset"].get<std::string>();
            auto seed = task["seed"].get<int>();
            auto hyperparam_line = task["hyperparameters"];
            auto datasets = Datasets(config.discretize, Paths::datasets());
            status(config_mpi, "Processing dataset " + dataset + " with seed " + std::to_string(seed) + " and hyperparameters " + hyperparam_line.dump());
            auto [X, y] = datasets.getTensors(dataset);
            auto states = datasets.getStates(dataset);
            auto features = datasets.getFeatures(dataset);
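Each entry of tasks is a JSON object read through the three keys above; a minimal example of the expected shape (the concrete values are made up for illustration):

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main()
{
    // Hypothetical task carrying the keys the worker loop reads.
    json task = {
        { "dataset", "iris" },
        { "seed", 271 },
        { "hyperparameters", { { "max_iterations", 100 } } }
    };
    auto dataset = task["dataset"].get<std::string>();  // "iris"
    auto seed = task["seed"].get<int>();                 // 271
    auto hyperparam_line = task["hyperparameters"];      // nested json object
    std::cout << dataset << " " << seed << " " << hyperparam_line.dump() << std::endl;
}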
@@ -167,20 +205,53 @@ namespace platform {
            else
                fold = new KFold(config.n_folds, y.size(0), seed);
            for (int nfold = 0; nfold < config.n_folds; nfold++) {

                auto clf = Models::instance()->create(config.model);
                auto valid = clf->getValidHyperparameters();
                hyperparameters.check(valid, dataset);
                clf->setHyperparameters(hyperparameters.get(dataset));
                status(config_mpi, "Processing fold " + std::to_string(nfold + 1));
                auto [train, test] = fold->getFold(nfold);
                auto train_t = torch::tensor(train);
                auto test_t = torch::tensor(test);
                auto X_train = X.index({ "...", train_t });
                auto y_train = y.index({ train_t });
                auto X_test = X.index({ "...", test
            }

                auto X_test = X.index({ "...", test_t });
                auto y_test = y.index({ test_t });
                auto num = 0;
                json result_fold;
                double hypScore = 0.0;
                double bestHypScore = 0.0;
                json bestHypHyperparameters;
                Fold* nested_fold;
                if (config.stratified)
                    nested_fold = new StratifiedKFold(config.nested, y_train, seed);
                else
                    nested_fold = new KFold(config.nested, y_train.size(0), seed);
                for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
                    // Nested level fold
                    status(config_mpi, "Processing nested fold " + std::to_string(n_nested_fold + 1));
                    auto [train_nested, test_nested] = nested_fold->getFold(n_nested_fold);
                    auto train_nested_t = torch::tensor(train_nested);
                    auto test_nested_t = torch::tensor(test_nested);
                    auto X_nexted_train = X_train.index({ "...", train_nested_t });
                    auto y_nested_train = y_train.index({ train_nested_t });
                    auto X_nested_test = X_train.index({ "...", test_nested_t });
                    auto y_nested_test = y_train.index({ test_nested_t });
                    // Build Classifier with selected hyperparameters
                    auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
                    auto clf = Models::instance()->create(config.model);
                    auto valid = clf->getValidHyperparameters();
                    hyperparameters.check(valid, dataset);
                    clf->setHyperparameters(hyperparameters.get(dataset));
                    // Train model
                    clf->fit(X_nexted_train, y_nested_train, features, className, states);
                    // Test model
                    hypScore += clf->score(X_nested_test, y_nested_test);
                }
                delete nested_fold;
                hypScore /= config.nested;
                if (hypScore > bestHypScore) {
                    bestHypScore = hypScore;
                    bestHypHyperparameters = hyperparam_line;
                }
            }
            delete fold;
        }
    }
    void GridSearch::go()
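Stripped of the torch and model plumbing, the nested loop above averages the score over config.nested inner folds and keeps the hyperparameter combination with the best mean (hypScore, bestHypScore, bestHypHyperparameters). A self-contained restatement of that selection pattern; every name here is illustrative, not project code:

#include <numeric>
#include <vector>

// Hypothetical candidate: one hyperparameter combination and its inner-fold scores.
struct Candidate { int id; std::vector<double> fold_scores; };

// Return the id of the candidate with the highest mean score, or -1 if none beats 0.0.
int selectBest(const std::vector<Candidate>& candidates)
{
    int best = -1;
    double bestScore = 0.0;
    for (const auto& c : candidates) {
        if (c.fold_scores.empty()) continue;
        double mean = std::accumulate(c.fold_scores.begin(), c.fold_scores.end(), 0.0) / c.fold_scores.size();
        if (mean > bestScore) {
            bestScore = mean;
            best = c.id;
        }
    }
    return best;
}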
@@ -391,39 +462,6 @@ namespace platform {
        }
        return { goatScore, goatHyperparameters };
    }
    vector<std::string> GridSearch::processDatasets(Datasets& datasets)
    {
        // Load datasets
        auto datasets_names = datasets.getNames();
        if (config.continue_from != NO_CONTINUE()) {
            // Continue previous execution:
            if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
                throw std::invalid_argument("Dataset " + config.continue_from + " not found");
            }
            // Remove datasets already processed
            vector< string >::iterator it = datasets_names.begin();
            while (it != datasets_names.end()) {
                if (*it != config.continue_from) {
                    it = datasets_names.erase(it);
                } else {
                    if (config.only)
                        ++it;
                    else
                        break;
                }
            }
        }
        // Exclude datasets
        for (const auto& name : config.excluded) {
            auto dataset = name.get<std::string>();
            auto it = std::find(datasets_names.begin(), datasets_names.end(), dataset);
            if (it == datasets_names.end()) {
                throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
            }
            datasets_names.erase(it);
        }
        return datasets_names;
    }
    json GridSearch::initializeResults()
    {
        // Load previous results
@@ -2,6 +2,7 @@
#define GRIDSEARCH_H
#include <string>
#include <map>
#include <mpi.h>
#include <nlohmann/json.hpp>
#include "Datasets.h"
#include "HyperParameters.h"
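The new #include <mpi.h> supports the go_MPI path. The rank and process count that ConfigMPI carries (config_mpi.rank, config_mpi.n_procs) would typically come from the MPI runtime as sketched below; the struct definition and main here are stand-ins, not the project's actual initialisation code.

#include <mpi.h>

struct ConfigMPI {   // minimal stand-in with the fields used by GridSearch::go_MPI
    int rank;
    int n_procs;
};

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    ConfigMPI config_mpi;
    MPI_Comm_rank(MPI_COMM_WORLD, &config_mpi.rank);     // this process's rank
    MPI_Comm_size(MPI_COMM_WORLD, &config_mpi.n_procs);  // total number of processes
    // ... each rank processes its [start, end) share of the task list ...
    MPI_Finalize();
    return 0;
}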
@@ -189,8 +189,8 @@ int main(int argc, char** argv)
    auto excluded = program.get<std::string>("exclude");
    config.excluded = json::parse(excluded);
    if (program.get<bool>("mpi")) {
        if (!compute) {
            throw std::runtime_error("Cannot use --mpi without --compute");
        if (!compute || config.nested == 0) {
            throw std::runtime_error("Cannot use --mpi without --compute or without --nested");
        }
    }
}