Add quiet parameter

This commit is contained in:
Ricardo Montañana Gómez 2023-11-24 21:16:20 +01:00
parent 2121ba9b98
commit f94e2d6a27
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 23 additions and 11 deletions

View File

@ -48,7 +48,8 @@ namespace platform {
double totalScore = 0.0;
int numItems = 0;
for (const auto& seed : config.seeds) {
std::cout << "(" << seed << ") doing Fold: " << flush;
if (!config.quiet)
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (config.stratified)
fold = new StratifiedKFold(config.n_folds, y, seed);
@ -66,13 +67,16 @@ namespace platform {
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
// Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_train, y_train, features, className, states);
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
// Test model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test);
numItems++;
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
std::cout << "\b\b\b, \b" << flush;
if (!config.quiet)
std::cout << "\b\b\b, \b" << flush;
}
delete fold;
}
@ -91,12 +95,14 @@ namespace platform {
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets.getNames()) {
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid()) {
showProgressComb(++num, totalComb, Colors::CYAN());
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFile(dataset, datasets, hyperparameters);
if (score > bestScore) {
@ -104,15 +110,18 @@ namespace platform {
bestHyperparameters = hyperparam_line;
}
}
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters;
}
// Save results
save();
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
}
void GridSearch::save()
void GridSearch::save() const
{
std::ofstream file(config.output_file);
file << results.dump(4);

View File

@ -15,6 +15,7 @@ namespace platform {
std::string path;
std::string input_file;
std::string output_file;
bool quiet;
bool discretize;
bool stratified;
int n_folds;
@ -24,7 +25,7 @@ namespace platform {
public:
explicit GridSearch(struct ConfigGrid& config);
void go();
void save();
void save() const;
~GridSearch() = default;
private:
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);

View File

@ -23,6 +23,7 @@ argparse::ArgumentParser manageArguments(std::string program_name)
}
);
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
@ -55,6 +56,7 @@ int main(int argc, char** argv)
config.discretize = program.get<bool>("discretize");
config.stratified = program.get<bool>("stratified");
config.n_folds = program.get<int>("folds");
config.quiet = program.get<bool>("quiet");
config.seeds = program.get<std::vector<int>>("seeds");
}
catch (const exception& err) {