diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 73c9ba3..1241ea8 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -54,19 +54,18 @@ namespace platform { for (int nfold = 0; nfold < config.n_folds; nfold++) { auto clf = Models::instance()->create(config.model); auto [train, test] = fold->getFold(nfold); - // auto train_t = torch::tensor(train); - // auto test_t = torch::tensor(test); - // auto X_train = X.index({ "...", train_t }); - // auto y_train = y.index({ train_t }); - // auto X_test = X.index({ "...", test_t }); - // auto y_test = y.index({ test_t }); + auto train_t = torch::tensor(train); + auto test_t = torch::tensor(test); + auto X_train = X.index({ "...", train_t }); + auto y_train = y.index({ train_t }); + auto X_test = X.index({ "...", test_t }); + auto y_test = y.index({ test_t }); showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); // Train model // clf->fit(X_train, y_train, features, className, states); showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); - sleep(1); - std::cout << "\b\b\b, " << flush; + std::cout << "\b\b\b, \b" << flush; } delete fold; } @@ -89,7 +88,7 @@ namespace platform { auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); processFile(dataset, datasets, hyperparameters); } - std::cout << std::endl; + std::cout << "end." << std::endl; } // Save results save();