Add header to grid output and report
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "Colors.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
const int MAXL = 133;
|
||||
|
||||
void manageArguments(argparse::ArgumentParser& program)
|
||||
{
|
||||
@@ -27,13 +28,14 @@ void manageArguments(argparse::ArgumentParser& program)
|
||||
}
|
||||
);
|
||||
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
|
||||
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
|
||||
group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
|
||||
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||
try {
|
||||
@@ -83,13 +85,29 @@ void list_dump(std::string& model)
|
||||
}
|
||||
std::cout << Colors::RESET() << std::endl;
|
||||
}
|
||||
std::string headerLine(const std::string& text, int utf = 0)
|
||||
{
|
||||
int n = MAXL - text.length() - 3;
|
||||
n = n < 0 ? 0 : n;
|
||||
return "* " + text + std::string(n + utf, ' ') + "*\n";
|
||||
}
|
||||
void list_results(json& results, std::string& model)
|
||||
{
|
||||
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
|
||||
<< model << std::endl << std::endl;
|
||||
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
|
||||
std::cout << headerLine("Listing computed hyperparameters for model " + model);
|
||||
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
|
||||
std::cout << headerLine("Score: " + results["score"].get<std::string>());
|
||||
std::cout << headerLine(
|
||||
"Random seeds: " + results["seeds"].dump()
|
||||
+ " Discretized: " + (results["discretize"].get<bool>() ? "True" : "False")
|
||||
+ " Stratified: " + (results["stratified"].get<bool>() ? "True" : "False")
|
||||
+ " #Folds: " + std::to_string(results["n_folds"].get<int>())
|
||||
+ " Nested: " + (results["nested"].get<int>() == 0 ? "False" : to_string(results["nested"].get<int>()))
|
||||
);
|
||||
std::cout << std::string(MAXL, '*') << std::endl;
|
||||
int spaces = 0;
|
||||
int hyperparameters_spaces = 0;
|
||||
for (const auto& item : results.items()) {
|
||||
for (const auto& item : results["results"].items()) {
|
||||
auto key = item.key();
|
||||
auto value = item.value();
|
||||
if (key.size() > spaces) {
|
||||
@@ -105,7 +123,7 @@ void list_results(json& results, std::string& model)
|
||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
||||
bool odd = true;
|
||||
int index = 0;
|
||||
for (const auto& item : results.items()) {
|
||||
for (const auto& item : results["results"].items()) {
|
||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||
auto key = item.key();
|
||||
auto value = item.value();
|
||||
@@ -119,12 +137,16 @@ void list_results(json& results, std::string& model)
|
||||
std::cout << Colors::RESET() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Main
|
||||
*/
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
argparse::ArgumentParser program("b_grid");
|
||||
manageArguments(program);
|
||||
struct platform::ConfigGrid config;
|
||||
bool dump, compute, list;
|
||||
bool dump, compute;
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
config.model = program.get<std::string>("model");
|
||||
@@ -135,13 +157,13 @@ int main(int argc, char** argv)
|
||||
config.quiet = program.get<bool>("quiet");
|
||||
config.only = program.get<bool>("only");
|
||||
config.seeds = program.get<std::vector<int>>("seeds");
|
||||
config.nested = program.get<int>("nested");
|
||||
config.continue_from = program.get<std::string>("continue");
|
||||
if (config.continue_from == "No" && config.only) {
|
||||
throw std::runtime_error("Cannot use --only without --continue");
|
||||
}
|
||||
dump = program.get<bool>("dump");
|
||||
compute = program.get<bool>("compute");
|
||||
list = program.get<bool>("list");
|
||||
if (dump && (config.continue_from != "No" || config.only)) {
|
||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||
}
|
||||
@@ -163,7 +185,10 @@ int main(int argc, char** argv)
|
||||
list_dump(config.model);
|
||||
} else {
|
||||
if (compute) {
|
||||
grid_search.go();
|
||||
if (config.nested == 0)
|
||||
grid_search.goSingle();
|
||||
else
|
||||
grid_search.goNested();
|
||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||
} else {
|
||||
// List results
|
||||
|
Reference in New Issue
Block a user