From 2e95e8999d8bfe2dd222fcf7e95644df426add05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 3 Dec 2023 12:37:25 +0100 Subject: [PATCH] Complete nested gridsearch --- src/Platform/GridSearch.cc | 54 ++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 9eed00c..49c3a71 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -172,8 +172,9 @@ namespace platform { auto states = datasets.getStates(fileName); auto features = datasets.getFeatures(fileName); auto className = datasets.getClassName(fileName); - double bestScore = 0.0; - json bestHyperparameters; + int spcs_combinations = int(log(combinations.size()) / log(10)) + 1; + double goatScore = 0.0; + json goatHyperparameters; // for dataset // for seed // for fold // for hyperparameters // for nested fold for (const auto& seed : config.seeds) { Fold* fold; @@ -182,9 +183,11 @@ namespace platform { else fold = new KFold(config.n_folds, y.size(0), seed); double bestScore = 0.0; + json bestHyperparameters; std::cout << "(" << seed << ") doing Fold: " << flush; for (int nfold = 0; nfold < config.n_folds; nfold++) { - std::cout << nfold + 1 << flush; + if (!config.quiet) + std::cout << Colors::GREEN() << nfold + 1 << " " << flush; // First level fold auto [train, test] = fold->getFold(nfold); auto train_t = torch::tensor(train); @@ -195,9 +198,12 @@ namespace platform { auto y_test = y.index({ test_t }); auto num = 0; json result_fold; - double score = 0.0; + double hypScore = 0.0; + double bestHypScore = 0.0; + json bestHypHyperparameters; for (const auto& hyperparam_line : combinations) { - std::cout << "[" << ++num << "/" << combinations.size() << "] " << std::flush; + std::cout << "[" << setw(spcs_combinations) << ++num << "/" << setw(spcs_combinations) + << combinations.size() << "] " << std::flush; Fold* nested_fold; if (config.stratified) nested_fold = new StratifiedKFold(config.nested, y_train, seed); @@ -221,21 +227,47 @@ namespace platform { // Train model if (!config.quiet) showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "a"); - // clf->fit(X_nexted_train, y_nested_train, features, className, states); + clf->fit(X_nexted_train, y_nested_train, features, className, states); // Test model if (!config.quiet) showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "b"); - // score += clf->score(X_nested_test, y_nested_test); - score = 0.0; + hypScore += clf->score(X_nested_test, y_nested_test); + if (!config.quiet) + std::cout << "\b\b\b, \b" << flush; } + std::cout << string(3 * config.nested + 2 * spcs_combinations + 4, '\b') << flush; delete nested_fold; - score = score / config.nested; - + hypScore /= config.nested; + if (hypScore > bestHypScore) { + bestHypScore = hypScore; + bestHypHyperparameters = hyperparam_line; + } } + // Build Classifier with selected hyperparameters + auto clf = Models::instance()->create(config.model); + clf->setHyperparameters(bestHypHyperparameters); + // Train model + if (!config.quiet) + showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); + clf->fit(X_train, y_train, features, className, states); + // Test model + if (!config.quiet) + showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); + double score = clf->score(X_test, y_test); + if (!config.quiet) + std::cout << "\b\b\b\b\b, \b" << flush; + if (score > bestScore) { + bestScore = score; + bestHyperparameters = bestHypHyperparameters; + } + } + if (bestScore > goatScore) { + goatScore = bestScore; + goatHyperparameters = bestHyperparameters; } delete fold; } - return { bestScore, bestHyperparameters }; + return { goatScore, goatHyperparameters }; } vector GridSearch::processDatasets(Datasets& datasets) {