Add quiet parameter
This commit is contained in:
parent
2121ba9b98
commit
f94e2d6a27
@ -48,7 +48,8 @@ namespace platform {
|
|||||||
double totalScore = 0.0;
|
double totalScore = 0.0;
|
||||||
int numItems = 0;
|
int numItems = 0;
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
if (!config.quiet)
|
||||||
|
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
if (config.stratified)
|
if (config.stratified)
|
||||||
fold = new StratifiedKFold(config.n_folds, y, seed);
|
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||||
@ -66,13 +67,16 @@ namespace platform {
|
|||||||
auto X_test = X.index({ "...", test_t });
|
auto X_test = X.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
// Train model
|
// Train model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
clf->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
// Test model
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
totalScore += clf->score(X_test, y_test);
|
totalScore += clf->score(X_test, y_test);
|
||||||
numItems++;
|
numItems++;
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
|
if (!config.quiet)
|
||||||
std::cout << "\b\b\b, \b" << flush;
|
std::cout << "\b\b\b, \b" << flush;
|
||||||
}
|
}
|
||||||
delete fold;
|
delete fold;
|
||||||
}
|
}
|
||||||
@ -91,12 +95,14 @@ namespace platform {
|
|||||||
// Generate hyperparameters grid & run gridsearch
|
// Generate hyperparameters grid & run gridsearch
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
for (const auto& dataset : datasets.getNames()) {
|
for (const auto& dataset : datasets.getNames()) {
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
if (!config.quiet)
|
||||||
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
int num = 0;
|
int num = 0;
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
json bestHyperparameters;
|
json bestHyperparameters;
|
||||||
for (const auto& hyperparam_line : grid.getGrid()) {
|
for (const auto& hyperparam_line : grid.getGrid()) {
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
if (!config.quiet)
|
||||||
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
double score = processFile(dataset, datasets, hyperparameters);
|
double score = processFile(dataset, datasets, hyperparameters);
|
||||||
if (score > bestScore) {
|
if (score > bestScore) {
|
||||||
@ -104,15 +110,18 @@ namespace platform {
|
|||||||
bestHyperparameters = hyperparam_line;
|
bestHyperparameters = hyperparam_line;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
if (!config.quiet) {
|
||||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
|
}
|
||||||
results[dataset]["score"] = bestScore;
|
results[dataset]["score"] = bestScore;
|
||||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
}
|
}
|
||||||
// Save results
|
// Save results
|
||||||
save();
|
save();
|
||||||
|
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
||||||
}
|
}
|
||||||
void GridSearch::save()
|
void GridSearch::save() const
|
||||||
{
|
{
|
||||||
std::ofstream file(config.output_file);
|
std::ofstream file(config.output_file);
|
||||||
file << results.dump(4);
|
file << results.dump(4);
|
||||||
|
@ -15,6 +15,7 @@ namespace platform {
|
|||||||
std::string path;
|
std::string path;
|
||||||
std::string input_file;
|
std::string input_file;
|
||||||
std::string output_file;
|
std::string output_file;
|
||||||
|
bool quiet;
|
||||||
bool discretize;
|
bool discretize;
|
||||||
bool stratified;
|
bool stratified;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
@ -24,7 +25,7 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void go();
|
||||||
void save();
|
void save() const;
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
private:
|
private:
|
||||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
|
@ -23,6 +23,7 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||||
@ -55,6 +56,7 @@ int main(int argc, char** argv)
|
|||||||
config.discretize = program.get<bool>("discretize");
|
config.discretize = program.get<bool>("discretize");
|
||||||
config.stratified = program.get<bool>("stratified");
|
config.stratified = program.get<bool>("stratified");
|
||||||
config.n_folds = program.get<int>("folds");
|
config.n_folds = program.get<int>("folds");
|
||||||
|
config.quiet = program.get<bool>("quiet");
|
||||||
config.seeds = program.get<std::vector<int>>("seeds");
|
config.seeds = program.get<std::vector<int>>("seeds");
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (const exception& err) {
|
||||||
|
Loading…
Reference in New Issue
Block a user