Add quiet parameter

This commit is contained in:
Ricardo Montañana Gómez 2023-11-24 21:16:20 +01:00
parent 2121ba9b98
commit f94e2d6a27
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 23 additions and 11 deletions

View File

@ -48,7 +48,8 @@ namespace platform {
double totalScore = 0.0; double totalScore = 0.0;
int numItems = 0; int numItems = 0;
for (const auto& seed : config.seeds) { for (const auto& seed : config.seeds) {
std::cout << "(" << seed << ") doing Fold: " << flush; if (!config.quiet)
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold; Fold* fold;
if (config.stratified) if (config.stratified)
fold = new StratifiedKFold(config.n_folds, y, seed); fold = new StratifiedKFold(config.n_folds, y, seed);
@ -66,13 +67,16 @@ namespace platform {
auto X_test = X.index({ "...", test_t }); auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t }); auto y_test = y.index({ test_t });
// Train model // Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); // Test model
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test); totalScore += clf->score(X_test, y_test);
numItems++; numItems++;
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); if (!config.quiet)
std::cout << "\b\b\b, \b" << flush; std::cout << "\b\b\b, \b" << flush;
} }
delete fold; delete fold;
} }
@ -91,12 +95,14 @@ namespace platform {
// Generate hyperparameters grid & run gridsearch // Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed // Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets.getNames()) { for (const auto& dataset : datasets.getNames()) {
std::cout << "- " << setw(20) << left << dataset << " " << right << flush; if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0; int num = 0;
double bestScore = 0.0; double bestScore = 0.0;
json bestHyperparameters; json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid()) { for (const auto& hyperparam_line : grid.getGrid()) {
showProgressComb(++num, totalComb, Colors::CYAN()); if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFile(dataset, datasets, hyperparameters); double score = processFile(dataset, datasets, hyperparameters);
if (score > bestScore) { if (score > bestScore) {
@ -104,15 +110,18 @@ namespace platform {
bestHyperparameters = hyperparam_line; bestHyperparameters = hyperparam_line;
} }
} }
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed if (!config.quiet) {
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
results[dataset]["score"] = bestScore; results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters; results[dataset]["hyperparameters"] = bestHyperparameters;
} }
// Save results // Save results
save(); save();
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
} }
void GridSearch::save() void GridSearch::save() const
{ {
std::ofstream file(config.output_file); std::ofstream file(config.output_file);
file << results.dump(4); file << results.dump(4);

View File

@ -15,6 +15,7 @@ namespace platform {
std::string path; std::string path;
std::string input_file; std::string input_file;
std::string output_file; std::string output_file;
bool quiet;
bool discretize; bool discretize;
bool stratified; bool stratified;
int n_folds; int n_folds;
@ -24,7 +25,7 @@ namespace platform {
public: public:
explicit GridSearch(struct ConfigGrid& config); explicit GridSearch(struct ConfigGrid& config);
void go(); void go();
void save(); void save() const;
~GridSearch() = default; ~GridSearch() = default;
private: private:
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);

View File

@ -23,6 +23,7 @@ argparse::ArgumentParser manageArguments(std::string program_name)
} }
); );
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
@ -55,6 +56,7 @@ int main(int argc, char** argv)
config.discretize = program.get<bool>("discretize"); config.discretize = program.get<bool>("discretize");
config.stratified = program.get<bool>("stratified"); config.stratified = program.get<bool>("stratified");
config.n_folds = program.get<int>("folds"); config.n_folds = program.get<int>("folds");
config.quiet = program.get<bool>("quiet");
config.seeds = program.get<std::vector<int>>("seeds"); config.seeds = program.get<std::vector<int>>("seeds");
} }
catch (const exception& err) { catch (const exception& err) {