From bbe5302ab1aceb74f063eb8a9b4c8c3310d3f674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 22 Nov 2023 16:38:50 +0100 Subject: [PATCH] Add info to output --- src/Platform/GridSearch.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 73c9ba3..1241ea8 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -54,19 +54,18 @@ namespace platform { for (int nfold = 0; nfold < config.n_folds; nfold++) { auto clf = Models::instance()->create(config.model); auto [train, test] = fold->getFold(nfold); - // auto train_t = torch::tensor(train); - // auto test_t = torch::tensor(test); - // auto X_train = X.index({ "...", train_t }); - // auto y_train = y.index({ train_t }); - // auto X_test = X.index({ "...", test_t }); - // auto y_test = y.index({ test_t }); + auto train_t = torch::tensor(train); + auto test_t = torch::tensor(test); + auto X_train = X.index({ "...", train_t }); + auto y_train = y.index({ train_t }); + auto X_test = X.index({ "...", test_t }); + auto y_test = y.index({ test_t }); showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); // Train model // clf->fit(X_train, y_train, features, className, states); showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); - sleep(1); - std::cout << "\b\b\b, " << flush; + std::cout << "\b\b\b, \b" << flush; } delete fold; } @@ -89,7 +88,7 @@ namespace platform { auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); processFile(dataset, datasets, hyperparameters); } - std::cout << std::endl; + std::cout << "end." << std::endl; } // Save results save();