Complete output interface of gridsearch

This commit is contained in:
Ricardo Montañana Gómez 2023-11-22 16:30:04 +01:00
parent fb347ed5b9
commit c2eb727fc7
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -11,7 +11,14 @@ namespace platform {
{
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
}
void showProgress(int fold, const std::string& color, const std::string& phase)
void showProgressComb(const int num, const int total, const std::string& color)
{
int spaces = int(log(total) / log(10)) + 1;
int magic = 37 + 2 * spaces;
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
}
void showProgressFold(int fold, const std::string& color, const std::string& phase)
{
std::string prefix = phase == "a" ? "" : "\b\b\b\b";
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
@ -37,7 +44,6 @@ namespace platform {
auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
std::cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
for (const auto& seed : config.seeds) {
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
@ -54,10 +60,13 @@ namespace platform {
// auto y_train = y.index({ train_t });
// auto X_test = X.index({ "...", test_t });
// auto y_test = y.index({ test_t });
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
// Train model
// clf->fit(X_train, y_train, features, className, states);
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
sleep(1);
std::cout << "\b\b\b, " << flush;
}
delete fold;
}
@ -68,12 +77,15 @@ namespace platform {
auto datasets = Datasets(config.discretize, Paths::datasets());
// Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "* Doing " << grid.getNumCombinations(config.model) << " combinations for each dataset/seed/fold" << std::endl;
auto totalComb = grid.getNumCombinations(config.model);
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets.getNames()) {
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
processFile(dataset, datasets, hyperparameters);
}