Add reports to gridsearch

This commit is contained in:
2023-11-29 00:26:48 +01:00
parent 8dbbb65a2f
commit 460d20a402
8 changed files with 116 additions and 19 deletions

View File

@@ -6,12 +6,13 @@
#include "GridSearch.h"
#include "Paths.h"
#include "Timer.h"
#include "Colors.h"
argparse::ArgumentParser manageArguments(std::string program_name)
void manageArguments(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
argparse::ArgumentParser program(program_name);
auto& group = program.add_mutually_exclusive_group(true);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring())
.action([](const std::string& value) {
@@ -22,6 +23,9 @@ argparse::ArgumentParser manageArguments(std::string program_name)
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
}
);
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
@@ -44,13 +48,14 @@ argparse::ArgumentParser manageArguments(std::string program_name)
}});
auto seed_values = env.getSeeds();
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments("b_grid");
argparse::ArgumentParser program("b_grid");
manageArguments(program);
struct platform::ConfigGrid config;
bool dump, compute, list;
try {
program.parse_args(argc, argv);
config.model = program.get<std::string>("model");
@@ -65,6 +70,12 @@ int main(int argc, char** argv)
if (config.continue_from == "No" && config.only) {
throw std::runtime_error("Cannot use --only without --continue");
}
dump = program.get<bool>("dump");
compute = program.get<bool>("compute");
list = program.get<bool>("list");
if (dump && (config.continue_from != "No" || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only");
}
}
catch (const exception& err) {
cerr << err.what() << std::endl;
@@ -80,8 +91,73 @@ int main(int argc, char** argv)
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
timer.start();
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
if (dump) {
auto combinations = grid_search.dump();
auto total = combinations.size();
int spaces = int(log(total) / log(10)) + 1;
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
int index = 0;
int max = 0;
for (auto const& item : combinations) {
if (item.dump().size() > spaces) {
max = item.dump().size();
}
}
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
<< "Hyperparameters" << std::endl;
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
bool odd = true;
for (auto const& item : combinations) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
std::cout << color;
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
} else {
if (compute) {
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
} else {
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
<< config.model << std::endl << std::endl;
auto results = grid_search.getResults();
if (results.empty()) {
std::cout << "No results found" << std::endl;
} else {
int spaces = 0;
int hyperparameters_spaces = 0;
for (const auto& item : results.items()) {
auto key = item.key();
auto value = item.value();
if (key.size() > spaces) {
spaces = key.size();
}
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
hyperparameters_spaces = value["hyperparameters"].dump().size();
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results.items()) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
auto key = item.key();
auto value = item.value();
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
}
}
}
std::cout << "Done!" << std::endl;
return 0;
}