Add quiet mode to b_main

Reduce output when --quiet is set, not showing fold info
This commit is contained in:
Ricardo Montañana Gómez 2023-10-17 21:51:53 +02:00
parent 6765552a7c
commit 2c2159f192
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 26 additions and 15 deletions

View File

@ -102,12 +102,12 @@ namespace platform {
cout << data.dump(4) << endl; cout << data.dump(4) << endl;
} }
void Experiment::go(vector<string> filesToProcess) void Experiment::go(vector<string> filesToProcess, bool quiet)
{ {
cout << "*** Starting experiment: " << title << " ***" << endl; cout << "*** Starting experiment: " << title << " ***" << endl;
for (auto fileName : filesToProcess) { for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush; cout << "- " << setw(20) << left << fileName << " " << right << flush;
cross_validation(fileName); cross_validation(fileName, quiet);
cout << endl; cout << endl;
} }
} }
@ -132,7 +132,7 @@ namespace platform {
cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
} }
void Experiment::cross_validation(const string& fileName) void Experiment::cross_validation(const string& fileName, bool quiet)
{ {
auto datasets = platform::Datasets(discretized, Paths::datasets()); auto datasets = platform::Datasets(discretized, Paths::datasets());
// Get dataset // Get dataset
@ -141,7 +141,9 @@ namespace platform {
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
if (!quiet) {
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
}
// Prepare Result // Prepare Result
auto result = Result(); auto result = Result();
auto [values, counts] = at::_unique(y); auto [values, counts] = at::_unique(y);
@ -159,6 +161,7 @@ namespace platform {
Timer train_timer, test_timer; Timer train_timer, test_timer;
int item = 0; int item = 0;
for (auto seed : randomSeeds) { for (auto seed : randomSeeds) {
if (!quiet)
cout << "(" << seed << ") doing Fold: " << flush; cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold; Fold* fold;
if (stratified) if (stratified)
@ -180,9 +183,11 @@ namespace platform {
auto y_train = y.index({ train_t }); auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t }); auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t }); auto y_test = y.index({ test_t });
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a"); showProgress(nfold + 1, getColor(clf->getStatus()), "a");
// Train model // Train model
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b"); showProgress(nfold + 1, getColor(clf->getStatus()), "b");
nodes[item] = clf->getNumberOfNodes(); nodes[item] = clf->getNumberOfNodes();
edges[item] = clf->getNumberOfEdges(); edges[item] = clf->getNumberOfEdges();
@ -191,12 +196,14 @@ namespace platform {
// Score train // Score train
auto accuracy_train_value = clf->score(X_train, y_train); auto accuracy_train_value = clf->score(X_train, y_train);
// Test model // Test model
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c"); showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start(); test_timer.start();
auto accuracy_test_value = clf->score(X_test, y_test); auto accuracy_test_value = clf->score(X_test, y_test);
test_time[item] = test_timer.getDuration(); test_time[item] = test_timer.getDuration();
accuracy_train[item] = accuracy_train_value; accuracy_train[item] = accuracy_train_value;
accuracy_test[item] = accuracy_test_value; accuracy_test[item] = accuracy_test_value;
if (!quiet)
cout << "\b\b\b, " << flush; cout << "\b\b\b, " << flush;
// Store results and times in vector // Store results and times in vector
result.addScoreTrain(accuracy_train_value); result.addScoreTrain(accuracy_train_value);
@ -206,6 +213,7 @@ namespace platform {
item++; item++;
clf.reset(); clf.reset();
} }
if (!quiet)
cout << "end. " << flush; cout << "end. " << flush;
delete fold; delete fold;
} }

View File

@ -108,8 +108,8 @@ namespace platform {
Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; } Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
string get_file_name(); string get_file_name();
void save(const string& path); void save(const string& path);
void cross_validation(const string& fileName); void cross_validation(const string& fileName, bool quiet);
void go(vector<string> filesToProcess); void go(vector<string> filesToProcess, bool quiet);
void show(); void show();
void report(); void report();
}; };

View File

@ -30,6 +30,7 @@ argparse::ArgumentParser manageArguments()
); );
program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--title").default_value("").help("Experiment title");
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) {
@ -55,7 +56,7 @@ int main(int argc, char** argv)
{ {
string file_name, model_name, title; string file_name, model_name, title;
json hyperparameters_json; json hyperparameters_json;
bool discretize_dataset, stratified, saveResults; bool discretize_dataset, stratified, saveResults, quiet;
vector<int> seeds; vector<int> seeds;
vector<string> filesToTest; vector<string> filesToTest;
int n_folds; int n_folds;
@ -66,6 +67,7 @@ int main(int argc, char** argv)
model_name = program.get<string>("model"); model_name = program.get<string>("model");
discretize_dataset = program.get<bool>("discretize"); discretize_dataset = program.get<bool>("discretize");
stratified = program.get<bool>("stratified"); stratified = program.get<bool>("stratified");
quiet = program.get<bool>("quiet");
n_folds = program.get<int>("folds"); n_folds = program.get<int>("folds");
seeds = program.get<vector<int>>("seeds"); seeds = program.get<vector<int>>("seeds");
auto hyperparameters = program.get<string>("hyperparameters"); auto hyperparameters = program.get<string>("hyperparameters");
@ -109,11 +111,12 @@ int main(int argc, char** argv)
} }
platform::Timer timer; platform::Timer timer;
timer.start(); timer.start();
experiment.go(filesToTest); experiment.go(filesToTest, quiet);
experiment.setDuration(timer.getDuration()); experiment.setDuration(timer.getDuration());
if (saveResults) { if (saveResults) {
experiment.save(platform::Paths::results()); experiment.save(platform::Paths::results());
} }
if (!quiet)
experiment.report(); experiment.report();
cout << "Done!" << endl; cout << "Done!" << endl;
return 0; return 0;