Add quiet parameter
This commit is contained in:
@@ -48,7 +48,8 @@ namespace platform {
|
||||
double totalScore = 0.0;
|
||||
int numItems = 0;
|
||||
for (const auto& seed : config.seeds) {
|
||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||
if (!config.quiet)
|
||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||
Fold* fold;
|
||||
if (config.stratified)
|
||||
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||
@@ -66,13 +67,16 @@ namespace platform {
|
||||
auto X_test = X.index({ "...", test_t });
|
||||
auto y_test = y.index({ test_t });
|
||||
// Train model
|
||||
if (!config.quiet)
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
clf->fit(X_train, y_train, features, className, states);
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
// Test model
|
||||
if (!config.quiet)
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
totalScore += clf->score(X_test, y_test);
|
||||
numItems++;
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
std::cout << "\b\b\b, \b" << flush;
|
||||
if (!config.quiet)
|
||||
std::cout << "\b\b\b, \b" << flush;
|
||||
}
|
||||
delete fold;
|
||||
}
|
||||
@@ -91,12 +95,14 @@ namespace platform {
|
||||
// Generate hyperparameters grid & run gridsearch
|
||||
// Check each combination of hyperparameters for each dataset and each seed
|
||||
for (const auto& dataset : datasets.getNames()) {
|
||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||
if (!config.quiet)
|
||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||
int num = 0;
|
||||
double bestScore = 0.0;
|
||||
json bestHyperparameters;
|
||||
for (const auto& hyperparam_line : grid.getGrid()) {
|
||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||
if (!config.quiet)
|
||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||
double score = processFile(dataset, datasets, hyperparameters);
|
||||
if (score > bestScore) {
|
||||
@@ -104,15 +110,18 @@ namespace platform {
|
||||
bestHyperparameters = hyperparam_line;
|
||||
}
|
||||
}
|
||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||
if (!config.quiet) {
|
||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||
}
|
||||
results[dataset]["score"] = bestScore;
|
||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||
}
|
||||
// Save results
|
||||
save();
|
||||
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
||||
}
|
||||
void GridSearch::save()
|
||||
void GridSearch::save() const
|
||||
{
|
||||
std::ofstream file(config.output_file);
|
||||
file << results.dump(4);
|
||||
|
Reference in New Issue
Block a user