Add quiet mode to b_main

Reduce output when --quiet is set, not showing fold info
This commit is contained in:
2023-10-17 21:51:53 +02:00
parent 6765552a7c
commit 2c2159f192
3 changed files with 26 additions and 15 deletions

View File

@@ -102,12 +102,12 @@ namespace platform {
cout << data.dump(4) << endl;
}
void Experiment::go(vector<string> filesToProcess)
void Experiment::go(vector<string> filesToProcess, bool quiet)
{
cout << "*** Starting experiment: " << title << " ***" << endl;
for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush;
cross_validation(fileName);
cross_validation(fileName, quiet);
cout << endl;
}
}
@@ -132,7 +132,7 @@ namespace platform {
cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
}
void Experiment::cross_validation(const string& fileName)
void Experiment::cross_validation(const string& fileName, bool quiet)
{
auto datasets = platform::Datasets(discretized, Paths::datasets());
// Get dataset
@@ -141,7 +141,9 @@ namespace platform {
auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
if (!quiet) {
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
}
// Prepare Result
auto result = Result();
auto [values, counts] = at::_unique(y);
@@ -159,7 +161,8 @@ namespace platform {
Timer train_timer, test_timer;
int item = 0;
for (auto seed : randomSeeds) {
cout << "(" << seed << ") doing Fold: " << flush;
if (!quiet)
cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (stratified)
fold = new StratifiedKFold(nfolds, y, seed);
@@ -180,10 +183,12 @@ namespace platform {
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
// Train model
clf->fit(X_train, y_train, features, className, states);
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
nodes[item] = clf->getNumberOfNodes();
edges[item] = clf->getNumberOfEdges();
num_states[item] = clf->getNumberOfStates();
@@ -191,13 +196,15 @@ namespace platform {
// Score train
auto accuracy_train_value = clf->score(X_train, y_train);
// Test model
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start();
auto accuracy_test_value = clf->score(X_test, y_test);
test_time[item] = test_timer.getDuration();
accuracy_train[item] = accuracy_train_value;
accuracy_test[item] = accuracy_test_value;
cout << "\b\b\b, " << flush;
if (!quiet)
cout << "\b\b\b, " << flush;
// Store results and times in vector
result.addScoreTrain(accuracy_train_value);
result.addScoreTest(accuracy_test_value);
@@ -206,7 +213,8 @@ namespace platform {
item++;
clf.reset();
}
cout << "end. " << flush;
if (!quiet)
cout << "end. " << flush;
delete fold;
}
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());