Complete output interface of gridsearch
This commit is contained in:
parent
fb347ed5b9
commit
c2eb727fc7
@ -11,7 +11,14 @@ namespace platform {
|
||||
{
|
||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||
}
|
||||
void showProgress(int fold, const std::string& color, const std::string& phase)
|
||||
void showProgressComb(const int num, const int total, const std::string& color)
|
||||
{
|
||||
int spaces = int(log(total) / log(10)) + 1;
|
||||
int magic = 37 + 2 * spaces;
|
||||
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
|
||||
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
|
||||
}
|
||||
void showProgressFold(int fold, const std::string& color, const std::string& phase)
|
||||
{
|
||||
std::string prefix = phase == "a" ? "" : "\b\b\b\b";
|
||||
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||
@ -37,7 +44,6 @@ namespace platform {
|
||||
auto features = datasets.getFeatures(fileName);
|
||||
auto samples = datasets.getNSamples(fileName);
|
||||
auto className = datasets.getClassName(fileName);
|
||||
std::cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||
for (const auto& seed : config.seeds) {
|
||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||
Fold* fold;
|
||||
@ -54,10 +60,13 @@ namespace platform {
|
||||
// auto y_train = y.index({ train_t });
|
||||
// auto X_test = X.index({ "...", test_t });
|
||||
// auto y_test = y.index({ test_t });
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
// Train model
|
||||
// clf->fit(X_train, y_train, features, className, states);
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
sleep(1);
|
||||
std::cout << "\b\b\b, " << flush;
|
||||
}
|
||||
delete fold;
|
||||
}
|
||||
@ -68,12 +77,15 @@ namespace platform {
|
||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||
// Create model
|
||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||
std::cout << "* Doing " << grid.getNumCombinations(config.model) << " combinations for each dataset/seed/fold" << std::endl;
|
||||
auto totalComb = grid.getNumCombinations(config.model);
|
||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||
// Generate hyperparameters grid & run gridsearch
|
||||
// Check each combination of hyperparameters for each dataset and each seed
|
||||
// Check each combination of hyperparameters for each dataset and each seed
|
||||
for (const auto& dataset : datasets.getNames()) {
|
||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||
int num = 0;
|
||||
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
|
||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||
processFile(dataset, datasets, hyperparameters);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user