diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index d0b84ed..fd2cbbe 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -48,7 +48,8 @@ namespace platform { double totalScore = 0.0; int numItems = 0; for (const auto& seed : config.seeds) { - std::cout << "(" << seed << ") doing Fold: " << flush; + if (!config.quiet) + std::cout << "(" << seed << ") doing Fold: " << flush; Fold* fold; if (config.stratified) fold = new StratifiedKFold(config.n_folds, y, seed); @@ -66,13 +67,16 @@ namespace platform { auto X_test = X.index({ "...", test_t }); auto y_test = y.index({ test_t }); // Train model + if (!config.quiet) + showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); clf->fit(X_train, y_train, features, className, states); - showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); - showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); + // Test model + if (!config.quiet) + showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); totalScore += clf->score(X_test, y_test); numItems++; - showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); - std::cout << "\b\b\b, \b" << flush; + if (!config.quiet) + std::cout << "\b\b\b, \b" << flush; } delete fold; } @@ -91,12 +95,14 @@ namespace platform { // Generate hyperparameters grid & run gridsearch // Check each combination of hyperparameters for each dataset and each seed for (const auto& dataset : datasets.getNames()) { - std::cout << "- " << setw(20) << left << dataset << " " << right << flush; + if (!config.quiet) + std::cout << "- " << setw(20) << left << dataset << " " << right << flush; int num = 0; double bestScore = 0.0; json bestHyperparameters; for (const auto& hyperparam_line : grid.getGrid()) { - showProgressComb(++num, totalComb, Colors::CYAN()); + if (!config.quiet) + showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); double score = processFile(dataset, datasets, hyperparameters); if (score > bestScore) { @@ -104,15 +110,18 @@ namespace platform { bestHyperparameters = hyperparam_line; } } - std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed - << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; + if (!config.quiet) { + std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed + << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; + } results[dataset]["score"] = bestScore; results[dataset]["hyperparameters"] = bestHyperparameters; } // Save results save(); + std::cout << "***************** Ending Gridsearch *******************" << std::endl; } - void GridSearch::save() + void GridSearch::save() const { std::ofstream file(config.output_file); file << results.dump(4); diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 81f06b5..c5528d0 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -15,6 +15,7 @@ namespace platform { std::string path; std::string input_file; std::string output_file; + bool quiet; bool discretize; bool stratified; int n_folds; @@ -24,7 +25,7 @@ namespace platform { public: explicit GridSearch(struct ConfigGrid& config); void go(); - void save(); + void save() const; ~GridSearch() = default; private: double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index b66050d..8887892 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -23,6 +23,7 @@ argparse::ArgumentParser manageArguments(std::string program_name) } ); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { @@ -55,6 +56,7 @@ int main(int argc, char** argv) config.discretize = program.get("discretize"); config.stratified = program.get("stratified"); config.n_folds = program.get("folds"); + config.quiet = program.get("quiet"); config.seeds = program.get>("seeds"); } catch (const exception& err) {