From 2c2159f192be414ad259db3df15b731238f08493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 17 Oct 2023 21:51:53 +0200 Subject: [PATCH] Add quiet mode to b_main Reduce output when --quiet is set, not showing fold info --- src/Platform/Experiment.cc | 28 ++++++++++++++++++---------- src/Platform/Experiment.h | 4 ++-- src/Platform/main.cc | 9 ++++++--- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 311dbc7..219295c 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -102,12 +102,12 @@ namespace platform { cout << data.dump(4) << endl; } - void Experiment::go(vector filesToProcess) + void Experiment::go(vector filesToProcess, bool quiet) { cout << "*** Starting experiment: " << title << " ***" << endl; for (auto fileName : filesToProcess) { cout << "- " << setw(20) << left << fileName << " " << right << flush; - cross_validation(fileName); + cross_validation(fileName, quiet); cout << endl; } } @@ -132,7 +132,7 @@ namespace platform { cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; } - void Experiment::cross_validation(const string& fileName) + void Experiment::cross_validation(const string& fileName, bool quiet) { auto datasets = platform::Datasets(discretized, Paths::datasets()); // Get dataset @@ -141,7 +141,9 @@ namespace platform { auto features = datasets.getFeatures(fileName); auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); - cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; + if (!quiet) { + cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; + } // Prepare Result auto result = Result(); auto [values, counts] = at::_unique(y); @@ -159,7 +161,8 @@ namespace platform { Timer train_timer, test_timer; int item = 0; for (auto seed : randomSeeds) { - cout << "(" << seed << ") doing Fold: " << flush; + if (!quiet) + cout << "(" << seed << ") doing Fold: " << flush; Fold* fold; if (stratified) fold = new StratifiedKFold(nfolds, y, seed); @@ -180,10 +183,12 @@ namespace platform { auto y_train = y.index({ train_t }); auto X_test = X.index({ "...", test_t }); auto y_test = y.index({ test_t }); - showProgress(nfold + 1, getColor(clf->getStatus()), "a"); + if (!quiet) + showProgress(nfold + 1, getColor(clf->getStatus()), "a"); // Train model clf->fit(X_train, y_train, features, className, states); - showProgress(nfold + 1, getColor(clf->getStatus()), "b"); + if (!quiet) + showProgress(nfold + 1, getColor(clf->getStatus()), "b"); nodes[item] = clf->getNumberOfNodes(); edges[item] = clf->getNumberOfEdges(); num_states[item] = clf->getNumberOfStates(); @@ -191,13 +196,15 @@ namespace platform { // Score train auto accuracy_train_value = clf->score(X_train, y_train); // Test model - showProgress(nfold + 1, getColor(clf->getStatus()), "c"); + if (!quiet) + showProgress(nfold + 1, getColor(clf->getStatus()), "c"); test_timer.start(); auto accuracy_test_value = clf->score(X_test, y_test); test_time[item] = test_timer.getDuration(); accuracy_train[item] = accuracy_train_value; accuracy_test[item] = accuracy_test_value; - cout << "\b\b\b, " << flush; + if (!quiet) + cout << "\b\b\b, " << flush; // Store results and times in vector result.addScoreTrain(accuracy_train_value); result.addScoreTest(accuracy_test_value); @@ -206,7 +213,8 @@ namespace platform { item++; clf.reset(); } - cout << "end. " << flush; + if (!quiet) + cout << "end. " << flush; delete fold; } result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); diff --git a/src/Platform/Experiment.h b/src/Platform/Experiment.h index 5653e93..1af372e 100644 --- a/src/Platform/Experiment.h +++ b/src/Platform/Experiment.h @@ -108,8 +108,8 @@ namespace platform { Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; } string get_file_name(); void save(const string& path); - void cross_validation(const string& fileName); - void go(vector filesToProcess); + void cross_validation(const string& fileName, bool quiet); + void go(vector filesToProcess, bool quiet); void show(); void report(); }; diff --git a/src/Platform/main.cc b/src/Platform/main.cc index ecdf258..033c8a1 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -30,6 +30,7 @@ argparse::ArgumentParser manageArguments() ); program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) { @@ -55,7 +56,7 @@ int main(int argc, char** argv) { string file_name, model_name, title; json hyperparameters_json; - bool discretize_dataset, stratified, saveResults; + bool discretize_dataset, stratified, saveResults, quiet; vector seeds; vector filesToTest; int n_folds; @@ -66,6 +67,7 @@ int main(int argc, char** argv) model_name = program.get("model"); discretize_dataset = program.get("discretize"); stratified = program.get("stratified"); + quiet = program.get("quiet"); n_folds = program.get("folds"); seeds = program.get>("seeds"); auto hyperparameters = program.get("hyperparameters"); @@ -109,12 +111,13 @@ int main(int argc, char** argv) } platform::Timer timer; timer.start(); - experiment.go(filesToTest); + experiment.go(filesToTest, quiet); experiment.setDuration(timer.getDuration()); if (saveResults) { experiment.save(platform::Paths::results()); } - experiment.report(); + if (!quiet) + experiment.report(); cout << "Done!" << endl; return 0; }