From c2eb727fc7cda664d2c183cfb61f04a5f3f631e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 22 Nov 2023 16:30:04 +0100 Subject: [PATCH] Complete output interface of gridsearch --- src/Platform/GridSearch.cc | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index e5f072a..73c9ba3 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -11,7 +11,14 @@ namespace platform { { this->config.output_file = config.path + "grid_" + config.model + "_output.json"; } - void showProgress(int fold, const std::string& color, const std::string& phase) + void showProgressComb(const int num, const int total, const std::string& color) + { + int spaces = int(log(total) / log(10)) + 1; + int magic = 37 + 2 * spaces; + std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b'); + std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush; + } + void showProgressFold(int fold, const std::string& color, const std::string& phase) { std::string prefix = phase == "a" ? "" : "\b\b\b\b"; std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; @@ -37,7 +44,6 @@ namespace platform { auto features = datasets.getFeatures(fileName); auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); - std::cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; for (const auto& seed : config.seeds) { std::cout << "(" << seed << ") doing Fold: " << flush; Fold* fold; @@ -54,10 +60,13 @@ namespace platform { // auto y_train = y.index({ train_t }); // auto X_test = X.index({ "...", test_t }); // auto y_test = y.index({ test_t }); - showProgress(nfold + 1, getColor(clf->getStatus()), "a"); + showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); // Train model // clf->fit(X_train, y_train, features, className, states); - showProgress(nfold + 1, getColor(clf->getStatus()), "b"); + showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); + showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); + sleep(1); + std::cout << "\b\b\b, " << flush; } delete fold; } @@ -68,12 +77,15 @@ namespace platform { auto datasets = Datasets(config.discretize, Paths::datasets()); // Create model std::cout << "***************** Starting Gridsearch *****************" << std::endl; - std::cout << "* Doing " << grid.getNumCombinations(config.model) << " combinations for each dataset/seed/fold" << std::endl; + auto totalComb = grid.getNumCombinations(config.model); + std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl; // Generate hyperparameters grid & run gridsearch - // Check each combination of hyperparameters for each dataset and each seed + // Check each combination of hyperparameters for each dataset and each seed for (const auto& dataset : datasets.getNames()) { std::cout << "- " << setw(20) << left << dataset << " " << right << flush; + int num = 0; for (const auto& hyperparam_line : grid.getGrid(config.model)) { + showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); processFile(dataset, datasets, hyperparameters); }