First compiling version
This commit is contained in:
parent
db9e80a70e
commit
dbf2f35502
@ -38,6 +38,39 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return json();
|
return json();
|
||||||
}
|
}
|
||||||
|
/**
 * Builds the list of dataset names this run has to work on.
 *
 * Starts from every name returned by `datasets.getNames()` and then:
 *  1. if `config.continue_from` differs from `NO_CONTINUE()`, drops the names that come
 *     before it; when `config.only` is set, drops every name other than it instead;
 *  2. removes each name listed in `config.excluded`.
 *
 * @param datasets  Object that provides the available dataset names.
 * @return The remaining dataset names, in their original relative order.
 * @throws std::invalid_argument if `config.continue_from` is not among the available names,
 *         or if an excluded name is not in the list when it is processed (this includes a
 *         name that appears twice in `config.excluded`).
 */
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
    auto names = datasets.getNames();
    const auto& resume_name = config.continue_from;
    if (resume_name != NO_CONTINUE()) {
        // Resuming a previous execution: the starting dataset must exist.
        const auto resume_at = std::find(names.begin(), names.end(), resume_name);
        if (resume_at == names.end()) {
            throw std::invalid_argument("Dataset " + config.continue_from + " not found");
        }
        if (config.only) {
            // Keep nothing but the requested dataset.
            names.erase(
                std::remove_if(names.begin(), names.end(),
                    [&resume_name](const std::string& candidate) { return candidate != resume_name; }),
                names.end());
        } else {
            // Discard the datasets that were already processed (everything before the resume point).
            names.erase(names.begin(), resume_at);
        }
    }
    // Apply the exclusion list, one name at a time.
    for (const auto& entry : config.excluded) {
        auto dataset = entry.get<std::string>();
        const auto position = std::find(names.begin(), names.end(), dataset);
        if (position == names.end()) {
            throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
        }
        names.erase(position);
    }
    return names;
}
|
||||||
void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
|
void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
|
||||||
{
|
{
|
||||||
int spaces = int(log(total) / log(10)) + 1;
|
int spaces = int(log(total) / log(10)) + 1;
|
||||||
@ -67,7 +100,7 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
auto result = json::array();
|
auto result = json::array();
|
||||||
auto datasets = Datasets(false, Paths::datasets());
|
auto datasets = Datasets(false, Paths::datasets());
|
||||||
auto datasets_names = datasets.getNames();
|
auto datasets_names = processDatasets(datasets);
|
||||||
auto grid = GridData(Paths::grid_input(config.model));
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
for (const auto& dataset : datasets_names) {
|
for (const auto& dataset : datasets_names) {
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
@ -103,6 +136,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return { start, end };
|
return { start, end };
|
||||||
}
|
}
|
||||||
|
void status(struct ConfigMPI& config_mpi, std::string status)
|
||||||
|
{
|
||||||
|
std::cout << "* (" << config_mpi.rank << "): " << status << std::endl;
|
||||||
|
}
|
||||||
void GridSearch::go_MPI(struct ConfigMPI& config_mpi)
|
void GridSearch::go_MPI(struct ConfigMPI& config_mpi)
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
@ -148,12 +185,13 @@ namespace platform {
|
|||||||
int num_tasks = tasks.size();
|
int num_tasks = tasks.size();
|
||||||
auto [start, end] = partRange(num_tasks, config_mpi.n_procs, config_mpi.rank);
|
auto [start, end] = partRange(num_tasks, config_mpi.n_procs, config_mpi.rank);
|
||||||
// 2.2 Each worker will process the combinations and return the best score obtained
|
// 2.2 Each worker will process the combinations and return the best score obtained
|
||||||
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
auto task = tasks[i];
|
auto task = tasks[i];
|
||||||
auto dataset = task["dataset"].get<std::string>();
|
auto dataset = task["dataset"].get<std::string>();
|
||||||
auto seed = task["seed"].get<int>();
|
auto seed = task["seed"].get<int>();
|
||||||
auto hyperparam_line = task["hyperparameters"];
|
auto hyperparam_line = task["hyperparameters"];
|
||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
status(config_mpi, "Processing dataset " + dataset + " with seed " + std::to_string(seed) + " and hyperparameters " + hyperparam_line.dump());
|
||||||
auto [X, y] = datasets.getTensors(dataset);
|
auto [X, y] = datasets.getTensors(dataset);
|
||||||
auto states = datasets.getStates(dataset);
|
auto states = datasets.getStates(dataset);
|
||||||
auto features = datasets.getFeatures(dataset);
|
auto features = datasets.getFeatures(dataset);
|
||||||
@ -167,20 +205,53 @@ namespace platform {
|
|||||||
else
|
else
|
||||||
fold = new KFold(config.n_folds, y.size(0), seed);
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
|
status(config_mpi, "Processing fold " + std::to_string(nfold + 1));
|
||||||
auto clf = Models::instance()->create(config.model);
|
|
||||||
auto valid = clf->getValidHyperparameters();
|
|
||||||
hyperparameters.check(valid, dataset);
|
|
||||||
clf->setHyperparameters(hyperparameters.get(dataset));
|
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
auto X_train = X.index({ "...", train_t });
|
auto X_train = X.index({ "...", train_t });
|
||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = X.index({ "...", test
|
auto X_test = X.index({ "...", test_t });
|
||||||
}
|
auto y_test = y.index({ test_t });
|
||||||
|
auto num = 0;
|
||||||
|
json result_fold;
|
||||||
|
double hypScore = 0.0;
|
||||||
|
double bestHypScore = 0.0;
|
||||||
|
json bestHypHyperparameters;
|
||||||
|
Fold* nested_fold;
|
||||||
|
if (config.stratified)
|
||||||
|
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
|
||||||
|
else
|
||||||
|
nested_fold = new KFold(config.nested, y_train.size(0), seed);
|
||||||
|
for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
|
||||||
|
// Nested level fold
|
||||||
|
status(config_mpi, "Processing nested fold " + std::to_string(n_nested_fold + 1));
|
||||||
|
auto [train_nested, test_nested] = nested_fold->getFold(n_nested_fold);
|
||||||
|
auto train_nested_t = torch::tensor(train_nested);
|
||||||
|
auto test_nested_t = torch::tensor(test_nested);
|
||||||
|
auto X_nexted_train = X_train.index({ "...", train_nested_t });
|
||||||
|
auto y_nested_train = y_train.index({ train_nested_t });
|
||||||
|
auto X_nested_test = X_train.index({ "...", test_nested_t });
|
||||||
|
auto y_nested_test = y_train.index({ test_nested_t });
|
||||||
|
// Build Classifier with selected hyperparameters
|
||||||
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
auto valid = clf->getValidHyperparameters();
|
||||||
|
hyperparameters.check(valid, dataset);
|
||||||
|
clf->setHyperparameters(hyperparameters.get(dataset));
|
||||||
|
// Train model
|
||||||
|
clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
||||||
|
// Test model
|
||||||
|
hypScore += clf->score(X_nested_test, y_nested_test);
|
||||||
|
}
|
||||||
|
delete nested_fold;
|
||||||
|
hypScore /= config.nested;
|
||||||
|
if (hypScore > bestHypScore) {
|
||||||
|
bestHypScore = hypScore;
|
||||||
|
bestHypHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
delete fold;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void GridSearch::go()
|
void GridSearch::go()
|
||||||
@ -391,39 +462,6 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return { goatScore, goatHyperparameters };
|
return { goatScore, goatHyperparameters };
|
||||||
}
|
}
|
||||||
/**
 * Builds the list of dataset names this run has to work on.
 *
 * Starts from every name returned by `datasets.getNames()` and then:
 *  1. if `config.continue_from` differs from `NO_CONTINUE()`, drops the names that come
 *     before it; when `config.only` is set, drops every name other than it instead;
 *  2. removes each name listed in `config.excluded`.
 *
 * @param datasets  Object that provides the available dataset names.
 * @return The remaining dataset names, in their original relative order.
 * @throws std::invalid_argument if `config.continue_from` is not among the available names,
 *         or if an excluded name is not in the list when it is processed (this includes a
 *         name that appears twice in `config.excluded`).
 */
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
    auto names = datasets.getNames();
    const auto& resume_name = config.continue_from;
    if (resume_name != NO_CONTINUE()) {
        // Resuming a previous execution: the starting dataset must exist.
        const auto resume_at = std::find(names.begin(), names.end(), resume_name);
        if (resume_at == names.end()) {
            throw std::invalid_argument("Dataset " + config.continue_from + " not found");
        }
        if (config.only) {
            // Keep nothing but the requested dataset.
            names.erase(
                std::remove_if(names.begin(), names.end(),
                    [&resume_name](const std::string& candidate) { return candidate != resume_name; }),
                names.end());
        } else {
            // Discard the datasets that were already processed (everything before the resume point).
            names.erase(names.begin(), resume_at);
        }
    }
    // Apply the exclusion list, one name at a time.
    for (const auto& entry : config.excluded) {
        auto dataset = entry.get<std::string>();
        const auto position = std::find(names.begin(), names.end(), dataset);
        if (position == names.end()) {
            throw std::invalid_argument("Dataset " + dataset + " already excluded or doesn't exist!");
        }
        names.erase(position);
    }
    return names;
}
|
|
||||||
json GridSearch::initializeResults()
|
json GridSearch::initializeResults()
|
||||||
{
|
{
|
||||||
// Load previous results
|
// Load previous results
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#define GRIDSEARCH_H
|
#define GRIDSEARCH_H
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <mpi.h>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "HyperParameters.h"
|
#include "HyperParameters.h"
|
||||||
|
@ -189,8 +189,8 @@ int main(int argc, char** argv)
|
|||||||
auto excluded = program.get<std::string>("exclude");
|
auto excluded = program.get<std::string>("exclude");
|
||||||
config.excluded = json::parse(excluded);
|
config.excluded = json::parse(excluded);
|
||||||
if (program.get<bool>("mpi")) {
|
if (program.get<bool>("mpi")) {
|
||||||
if (!compute) {
|
if (!compute || config.nested == 0) {
|
||||||
throw std::runtime_error("Cannot use --mpi without --compute");
|
throw std::runtime_error("Cannot use --mpi without --compute or without --nested");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user