Add reports to gridsearch

This commit is contained in:
Ricardo Montañana Gómez 2023-11-29 00:26:48 +01:00
parent 8dbbb65a2f
commit 460d20a402
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 116 additions and 19 deletions

10
.gitmodules vendored
View File

@ -1,15 +1,25 @@
# Third-party dependencies vendored as git submodules.
# `branch` names the remote branch followed by `git submodule update --remote`;
# `update = merge` merges that branch into the submodule checkout.
[submodule "lib/mdlp"]
	path = lib/mdlp
	url = https://github.com/rmontanana/mdlp
	branch = main
	update = merge
[submodule "lib/catch2"]
	path = lib/catch2
	# Pinned to the v2.x maintenance branch rather than the default branch.
	branch = v2.x
	update = merge
	url = https://github.com/catchorg/Catch2.git
[submodule "lib/argparse"]
	path = lib/argparse
	url = https://github.com/p-ranav/argparse
	branch = master
	update = merge
[submodule "lib/json"]
	path = lib/json
	url = https://github.com/nlohmann/json.git
	branch = master
	update = merge
[submodule "lib/libxlsxwriter"]
	path = lib/libxlsxwriter
	url = https://github.com/jmcnamara/libxlsxwriter.git
	branch = main
	update = merge

@ -1 +1 @@
Subproject commit b0930ab0288185815d6dc67af59de7014a6272f7
Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6

View File

@ -32,6 +32,18 @@ namespace platform {
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
}
std::vector<json> GridSearch::dump()
{
return GridData(config.input_file).getGrid();
}
// Reads the configured output file and returns its parsed JSON content.
// When the file cannot be opened a default-constructed json is returned instead.
json GridSearch::getResults()
{
    std::ifstream results_stream(config.output_file);
    if (!results_stream.is_open()) {
        return json();
    }
    return json::parse(results_stream);
}
void showProgressComb(const int num, const int total, const std::string& color)
{
int spaces = int(log(total) / log(10)) + 1;

View File

@ -28,6 +28,8 @@ namespace platform {
explicit GridSearch(struct ConfigGrid& config);
void go();
~GridSearch() = default;
std::vector<json> dump();
json getResults();
private:
void save(json& results) const;
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);

View File

@ -5,9 +5,8 @@
#include "Colors.h"
argparse::ArgumentParser manageArguments(int argc, char** argv)
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
{
argparse::ArgumentParser program("b_sbest");
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
@ -28,12 +27,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
catch (...) {
throw std::runtime_error("Number of folds must be an decimal number");
}});
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
argparse::ArgumentParser program("b_sbest");
manageArguments(program, argc, argv);
std::string model, score;
bool build, report, friedman, excel;
double level;

View File

@ -6,12 +6,13 @@
#include "GridSearch.h"
#include "Paths.h"
#include "Timer.h"
#include "Colors.h"
argparse::ArgumentParser manageArguments(std::string program_name)
void manageArguments(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
argparse::ArgumentParser program(program_name);
auto& group = program.add_mutually_exclusive_group(true);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring())
.action([](const std::string& value) {
@ -22,6 +23,9 @@ argparse::ArgumentParser manageArguments(std::string program_name)
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
}
);
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
@ -44,13 +48,14 @@ argparse::ArgumentParser manageArguments(std::string program_name)
}});
auto seed_values = env.getSeeds();
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments("b_grid");
argparse::ArgumentParser program("b_grid");
manageArguments(program);
struct platform::ConfigGrid config;
bool dump, compute, list;
try {
program.parse_args(argc, argv);
config.model = program.get<std::string>("model");
@ -65,6 +70,12 @@ int main(int argc, char** argv)
if (config.continue_from == "No" && config.only) {
throw std::runtime_error("Cannot use --only without --continue");
}
dump = program.get<bool>("dump");
compute = program.get<bool>("compute");
list = program.get<bool>("list");
if (dump && (config.continue_from != "No" || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only");
}
}
catch (const exception& err) {
cerr << err.what() << std::endl;
@ -80,8 +91,73 @@ int main(int argc, char** argv)
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
timer.start();
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
// Mode dispatch: --dump prints every grid combination, --compute runs the grid search,
// and otherwise the previously stored results are listed.
if (dump) {
    const auto combinations = grid_search.dump();
    const auto total = combinations.size();
    // Serialize each combination once; the text is needed both to size the ruler and to print.
    std::vector<std::string> rows;
    rows.reserve(total);
    std::size_t widest = 0;
    for (const auto& item : combinations) {
        rows.push_back(item.dump());
        // Track the true maximum length (compare against the running maximum).
        if (rows.back().size() > widest) {
            widest = rows.back().size();
        }
    }
    // Index column width = decimal digits of the largest index. Counting the characters of
    // the printed number is exact and well defined even when there are no combinations.
    const int spaces = static_cast<int>(std::to_string(total).size());
    std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
    std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
        << "Hyperparameters" << std::endl;
    std::cout << string(spaces, '=') << " " << string(widest, '=') << std::endl;
    bool odd = true;
    int index = 0;
    for (const auto& row : rows) {
        // Alternate row colors to make long listings easier to follow.
        auto color = odd ? Colors::CYAN() : Colors::BLUE();
        std::cout << color;
        std::cout << setw(spaces) << fixed << right << ++index << left << " " << row << std::endl;
        odd = !odd;
    }
    std::cout << Colors::RESET() << std::endl;
} else {
    if (compute) {
        grid_search.go();
        std::cout << "Process took " << timer.getDurationString() << std::endl;
    } else {
        std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
            << config.model << std::endl << std::endl;
        auto results = grid_search.getResults();
        if (results.empty()) {
            std::cout << "No results found" << std::endl;
        } else {
            // First pass: find the widest dataset name and the widest serialized hyperparameters.
            int spaces = 0;
            int hyperparameters_spaces = 0;
            for (const auto& item : results.items()) {
                auto key = item.key();
                auto value = item.value();
                const int key_width = static_cast<int>(key.size());
                if (key_width > spaces) {
                    spaces = key_width;
                }
                const int hyper_width = static_cast<int>(value["hyperparameters"].dump().size());
                if (hyper_width > hyperparameters_spaces) {
                    hyperparameters_spaces = hyper_width;
                }
            }
            std::cout << Colors::GREEN() << " #  " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
                << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
            std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
                << string(hyperparameters_spaces, '=') << std::endl;
            // Second pass: one colored row per dataset.
            bool odd = true;
            int index = 0;
            for (const auto& item : results.items()) {
                auto color = odd ? Colors::CYAN() : Colors::BLUE();
                auto key = item.key();
                auto value = item.value();
                std::cout << color;
                std::cout << std::setw(3) << std::right << index++ << " ";
                std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
                    << " " << setw(8) << setprecision(6) << fixed << right
                    << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
                odd = !odd;
            }
            std::cout << Colors::RESET() << std::endl;
        }
    }
}
std::cout << "Done!" << std::endl;
return 0;
}

View File

@ -11,10 +11,9 @@
using json = nlohmann::json;
argparse::ArgumentParser manageArguments(std::string program_name)
void manageArguments(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
argparse::ArgumentParser program(program_name);
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
@ -50,18 +49,18 @@ argparse::ArgumentParser manageArguments(std::string program_name)
}});
auto seed_values = env.getSeeds();
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
return program;
}
int main(int argc, char** argv)
{
argparse::ArgumentParser program("b_main");
manageArguments(program);
std::string file_name, model_name, title, hyperparameters_file;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet;
std::vector<int> seeds;
std::vector<std::string> filesToTest;
int n_folds;
auto program = manageArguments("b_main");
try {
program.parse_args(argc, argv);
file_name = program.get<std::string>("dataset");

View File

@ -3,9 +3,8 @@
#include "ManageResults.h"
argparse::ArgumentParser manageArguments(int argc, char** argv)
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
{
argparse::ArgumentParser program("b_manage");
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
@ -29,12 +28,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
std::cerr << program;
exit(1);
}
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
auto program = argparse::ArgumentParser("b_manage");
manageArguments(program, argc, argv);
int number = program.get<int>("number");
std::string model = program.get<std::string>("model");
std::string score = program.get<std::string>("score");