From 460d20a4025ee0a9bfa4cfb4decea2fb096a1ab5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 29 Nov 2023 00:26:48 +0100 Subject: [PATCH] Add reports to gridsearch --- .gitmodules | 10 +++++ lib/argparse | 2 +- src/Platform/GridSearch.cc | 12 ++++++ src/Platform/GridSearch.h | 2 + src/Platform/b_best.cc | 7 ++- src/Platform/b_grid.cc | 88 +++++++++++++++++++++++++++++++++++--- src/Platform/b_main.cc | 7 ++- src/Platform/b_manage.cc | 7 ++- 8 files changed, 116 insertions(+), 19 deletions(-) diff --git a/.gitmodules b/.gitmodules index dbb94fc..6be5a87 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,15 +1,25 @@ [submodule "lib/mdlp"] path = lib/mdlp url = https://github.com/rmontanana/mdlp + main = main + update = merge [submodule "lib/catch2"] path = lib/catch2 + main = v2.x + update = merge url = https://github.com/catchorg/Catch2.git [submodule "lib/argparse"] path = lib/argparse url = https://github.com/p-ranav/argparse + master = master + update = merge [submodule "lib/json"] path = lib/json url = https://github.com/nlohmann/json.git + master = master + update = merge [submodule "lib/libxlsxwriter"] path = lib/libxlsxwriter url = https://github.com/jmcnamara/libxlsxwriter.git + main = main + update = merge diff --git a/lib/argparse b/lib/argparse index b0930ab..69dabd8 160000 --- a/lib/argparse +++ b/lib/argparse @@ -1 +1 @@ -Subproject commit b0930ab0288185815d6dc67af59de7014a6272f7 +Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6 diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 2074514..4379178 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -32,6 +32,18 @@ namespace platform { this->config.output_file = config.path + "grid_" + config.model + "_output.json"; this->config.input_file = config.path + "grid_" + config.model + "_input.json"; } + std::vector GridSearch::dump() + { + return GridData(config.input_file).getGrid(); + } + json GridSearch::getResults() + { + std::ifstream file(config.output_file); + if (file.is_open()) { + return json::parse(file); + } + return json(); + } void showProgressComb(const int num, const int total, const std::string& color) { int spaces = int(log(total) / log(10)) + 1; diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index b52c1e9..61bc242 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -28,6 +28,8 @@ namespace platform { explicit GridSearch(struct ConfigGrid& config); void go(); ~GridSearch() = default; + std::vector dump(); + json getResults(); private: void save(json& results) const; double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); diff --git a/src/Platform/b_best.cc b/src/Platform/b_best.cc index 1ed73c7..4112f4b 100644 --- a/src/Platform/b_best.cc +++ b/src/Platform/b_best.cc @@ -5,9 +5,8 @@ #include "Colors.h" -argparse::ArgumentParser manageArguments(int argc, char** argv) +void manageArguments(argparse::ArgumentParser& program, int argc, char** argv) { - argparse::ArgumentParser program("b_sbest"); program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)"); program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied"); program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true); @@ -28,12 +27,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) catch (...) { throw std::runtime_error("Number of folds must be an decimal number"); }}); - return program; } int main(int argc, char** argv) { - auto program = manageArguments(argc, argv); + argparse::ArgumentParser program("b_sbest"); + manageArguments(program, argc, argv); std::string model, score; bool build, report, friedman, excel; double level; diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 6a63332..19f1e2c 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -6,12 +6,13 @@ #include "GridSearch.h" #include "Paths.h" #include "Timer.h" +#include "Colors.h" -argparse::ArgumentParser manageArguments(std::string program_name) +void manageArguments(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); - argparse::ArgumentParser program(program_name); + auto& group = program.add_mutually_exclusive_group(true); program.add_argument("-m", "--model") .help("Model to use " + platform::Models::instance()->tostring()) .action([](const std::string& value) { @@ -22,6 +23,9 @@ argparse::ArgumentParser manageArguments(std::string program_name) throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring()); } ); + group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true); + group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true); + group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); @@ -44,13 +48,14 @@ argparse::ArgumentParser manageArguments(std::string program_name) }}); auto seed_values = env.getSeeds(); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); - return program; } int main(int argc, char** argv) { - auto program = manageArguments("b_grid"); + argparse::ArgumentParser program("b_grid"); + manageArguments(program); struct platform::ConfigGrid config; + bool dump, compute, list; try { program.parse_args(argc, argv); config.model = program.get("model"); @@ -65,6 +70,12 @@ int main(int argc, char** argv) if (config.continue_from == "No" && config.only) { throw std::runtime_error("Cannot use --only without --continue"); } + dump = program.get("dump"); + compute = program.get("compute"); + list = program.get("list"); + if (dump && (config.continue_from != "No" || config.only)) { + throw std::runtime_error("Cannot use --dump with --continue or --only"); + } } catch (const exception& err) { cerr << err.what() << std::endl; @@ -80,8 +91,73 @@ int main(int argc, char** argv) auto grid_search = platform::GridSearch(config); platform::Timer timer; timer.start(); - grid_search.go(); - std::cout << "Process took " << timer.getDurationString() << std::endl; + if (dump) { + auto combinations = grid_search.dump(); + auto total = combinations.size(); + int spaces = int(log(total) / log(10)) + 1; + std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl; + int index = 0; + int max = 0; + for (auto const& item : combinations) { + if (item.dump().size() > spaces) { + max = item.dump().size(); + } + } + std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces) + << "Hyperparameters" << std::endl; + std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl; + bool odd = true; + for (auto const& item : combinations) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + std::cout << color; + std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl; + odd = !odd; + } + std::cout << Colors::RESET() << std::endl; + } else { + if (compute) { + grid_search.go(); + std::cout << "Process took " << timer.getDurationString() << std::endl; + } else { + std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " + << config.model << std::endl << std::endl; + auto results = grid_search.getResults(); + if (results.empty()) { + std::cout << "No results found" << std::endl; + } else { + int spaces = 0; + int hyperparameters_spaces = 0; + for (const auto& item : results.items()) { + auto key = item.key(); + auto value = item.value(); + if (key.size() > spaces) { + spaces = key.size(); + } + if (value["hyperparameters"].dump().size() > hyperparameters_spaces) { + hyperparameters_spaces = value["hyperparameters"].dump().size(); + } + } + std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " " + << setw(8) << "Score" << " " << "Hyperparameters" << std::endl; + std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " " + << string(hyperparameters_spaces, '=') << std::endl; + bool odd = true; + int index = 0; + for (const auto& item : results.items()) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + auto key = item.key(); + auto value = item.value(); + std::cout << color; + std::cout << std::setw(3) << std::right << index++ << " "; + std::cout << left << setw(spaces) << key << " " << value["date"].get() + << " " << setw(8) << setprecision(6) << fixed << right + << value["score"].get() << " " << value["hyperparameters"].dump() << std::endl; + odd = !odd; + } + std::cout << Colors::RESET() << std::endl; + } + } + } std::cout << "Done!" << std::endl; return 0; } diff --git a/src/Platform/b_main.cc b/src/Platform/b_main.cc index c09f071..0dee793 100644 --- a/src/Platform/b_main.cc +++ b/src/Platform/b_main.cc @@ -11,10 +11,9 @@ using json = nlohmann::json; -argparse::ArgumentParser manageArguments(std::string program_name) +void manageArguments(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); - argparse::ArgumentParser program(program_name); program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ @@ -50,18 +49,18 @@ argparse::ArgumentParser manageArguments(std::string program_name) }}); auto seed_values = env.getSeeds(); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); - return program; } int main(int argc, char** argv) { + argparse::ArgumentParser program("b_main"); + manageArguments(program); std::string file_name, model_name, title, hyperparameters_file; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet; std::vector seeds; std::vector filesToTest; int n_folds; - auto program = manageArguments("b_main"); try { program.parse_args(argc, argv); file_name = program.get("dataset"); diff --git a/src/Platform/b_manage.cc b/src/Platform/b_manage.cc index 1067902..a6e7be2 100644 --- a/src/Platform/b_manage.cc +++ b/src/Platform/b_manage.cc @@ -3,9 +3,8 @@ #include "ManageResults.h" -argparse::ArgumentParser manageArguments(int argc, char** argv) +void manageArguments(argparse::ArgumentParser& program, int argc, char** argv) { - argparse::ArgumentParser program("b_manage"); program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>(); program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); @@ -29,12 +28,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) std::cerr << program; exit(1); } - return program; } int main(int argc, char** argv) { - auto program = manageArguments(argc, argv); + auto program = argparse::ArgumentParser("b_manage"); + manageArguments(program, argc, argv); int number = program.get("number"); std::string model = program.get("model"); std::string score = program.get("score");