Add traintest split in gridsearch
This commit is contained in:
@@ -118,17 +118,18 @@ namespace platform {
|
||||
json task = tasks[n_task];
|
||||
auto model = config.model;
|
||||
auto grid = GridData(Paths::grid_input(model));
|
||||
auto dataset = task["dataset"].get<std::string>();
|
||||
auto dataset_name = task["dataset"].get<std::string>();
|
||||
auto idx_dataset = task["idx_dataset"].get<int>();
|
||||
auto seed = task["seed"].get<int>();
|
||||
auto n_fold = task["fold"].get<int>();
|
||||
bool stratified = config.stratified;
|
||||
// Generate the hyperparamters combinations
|
||||
auto combinations = grid.getGrid(dataset);
|
||||
auto [X, y] = datasets.getTensors(dataset);
|
||||
auto states = datasets.getStates(dataset);
|
||||
auto features = datasets.getFeatures(dataset);
|
||||
auto className = datasets.getClassName(dataset);
|
||||
auto& dataset = datasets.getDataset(dataset_name);
|
||||
auto combinations = grid.getGrid(dataset_name);
|
||||
auto [X, y] = dataset.getTensors();
|
||||
auto states = dataset.getStates();
|
||||
auto features = dataset.getFeatures();
|
||||
auto className = dataset.getClassName();
|
||||
//
|
||||
// Start working on task
|
||||
//
|
||||
@@ -138,12 +139,7 @@ namespace platform {
|
||||
else
|
||||
fold = new folding::KFold(config.n_folds, y.size(0), seed);
|
||||
auto [train, test] = fold->getFold(n_fold);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
auto X_train = X.index({ "...", train_t });
|
||||
auto y_train = y.index({ train_t });
|
||||
auto X_test = X.index({ "...", test_t });
|
||||
auto y_test = y.index({ test_t });
|
||||
auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
|
||||
double best_fold_score = 0.0;
|
||||
int best_idx_combination = -1;
|
||||
json best_fold_hyper;
|
||||
@@ -168,8 +164,8 @@ namespace platform {
|
||||
// Build Classifier with selected hyperparameters
|
||||
auto clf = Models::instance()->create(config.model);
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, dataset);
|
||||
clf->setHyperparameters(hyperparameters.get(dataset));
|
||||
hyperparameters.check(valid, dataset_name);
|
||||
clf->setHyperparameters(hyperparameters.get(dataset_name));
|
||||
// Train model
|
||||
clf->fit(X_nested_train, y_nested_train, features, className, states);
|
||||
// Test model
|
||||
@@ -188,7 +184,7 @@ namespace platform {
|
||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), best_fold_hyper);
|
||||
auto clf = Models::instance()->create(config.model);
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, dataset);
|
||||
hyperparameters.check(valid, dataset_name);
|
||||
clf->setHyperparameters(best_fold_hyper);
|
||||
clf->fit(X_train, y_train, features, className, states);
|
||||
best_fold_score = clf->score(X_test, y_test);
|
||||
|
Reference in New Issue
Block a user