From dee9c674da2aee8bb4cb7ed10c2d0e99a6d5416c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 29 Nov 2023 18:24:34 +0100 Subject: [PATCH] Refactor grid input hyperparameter file --- src/Platform/GridData.cc | 30 +++++++-- src/Platform/GridData.h | 11 ++-- src/Platform/GridSearch.cc | 23 +++---- src/Platform/GridSearch.h | 6 +- src/Platform/Paths.h | 8 +++ src/Platform/b_grid.cc | 129 +++++++++++++++++++++---------------- 6 files changed, 120 insertions(+), 87 deletions(-) diff --git a/src/Platform/GridData.cc b/src/Platform/GridData.cc index b2a8149..e93ee17 100644 --- a/src/Platform/GridData.cc +++ b/src/Platform/GridData.cc @@ -4,12 +4,19 @@ namespace platform { GridData::GridData(const std::string& fileName) { + json grid_file; std::ifstream resultData(fileName); if (resultData.is_open()) { - grid = json::parse(resultData); + grid_file = json::parse(resultData); } else { throw std::invalid_argument("Unable to open input file. [" + fileName + "]"); } + for (const auto& item : grid_file.items()) { + auto key = item.key(); + auto value = item.value(); + grid[key] = value; + } + } int GridData::computeNumCombinations(const json& line) { @@ -19,10 +26,11 @@ namespace platform { } return numCombinations; } - int GridData::getNumCombinations() + int GridData::getNumCombinations(const std::string& dataset) { int numCombinations = 0; - for (const auto& line : grid) { + auto selected = decide_dataset(dataset); + for (const auto& line : grid.at(selected)) { numCombinations += computeNumCombinations(line); } return numCombinations; @@ -44,16 +52,24 @@ namespace platform { } return currentCombination; } - std::vector GridData::getGrid() + std::vector GridData::getGrid(const std::string& dataset) { + auto selected = decide_dataset(dataset); auto result = std::vector(); - for (json line : grid) { + for (json line : grid.at(selected)) { generateCombinations(line.begin(), line.end(), result, json({})); } return result; } - json& GridData::getInputGrid() + json& GridData::getInputGrid(const std::string& dataset) { - return grid; + auto selected = decide_dataset(dataset); + return grid.at(selected); + } + std::string GridData::decide_dataset(const std::string& dataset) + { + if (grid.find(dataset) != grid.end()) + return dataset; + return ALL_DATASETS; } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridData.h b/src/Platform/GridData.h index 3a03f03..0156453 100644 --- a/src/Platform/GridData.h +++ b/src/Platform/GridData.h @@ -7,17 +7,20 @@ namespace platform { using json = nlohmann::json; + const std::string ALL_DATASETS = "all"; class GridData { public: explicit GridData(const std::string& fileName); ~GridData() = default; - std::vector getGrid(); - int getNumCombinations(); - json& getInputGrid(); + std::vector getGrid(const std::string& dataset = ALL_DATASETS); + int getNumCombinations(const std::string& dataset = ALL_DATASETS); + json& getInputGrid(const std::string& dataset = ALL_DATASETS); + std::map& getGridFile() { return grid; } private: + std::string decide_dataset(const std::string& dataset); json generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination); int computeNumCombinations(const json& line); - json grid; + std::map grid; }; } /* namespace platform */ #endif /* GRIDDATA_H */ \ No newline at end of file diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index c056272..809bd8f 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -29,16 +29,10 @@ namespace platform { } GridSearch::GridSearch(struct ConfigGrid& config) : config(config) { - this->config.output_file = config.path + "grid_" + config.model + "_output.json"; - this->config.input_file = config.path + "grid_" + config.model + "_input.json"; - } - std::vector GridSearch::dump() - { - return GridData(config.input_file).getGrid(); } json GridSearch::getResults() { - std::ifstream file(config.output_file); + std::ifstream file(Paths::grid_output(config.model)); if (file.is_open()) { return json::parse(file); } @@ -131,7 +125,7 @@ namespace platform { if (!config.quiet) std::cout << "* Loading previous results" << std::endl; try { - std::ifstream file(config.output_file); + std::ifstream file(Paths::grid_output(config.model)); if (file.is_open()) { results = json::parse(file); } @@ -156,19 +150,18 @@ namespace platform { } // Create model std::cout << "***************** Starting Gridsearch *****************" << std::endl; - std::cout << "input file=" << config.input_file << std::endl; - auto grid = GridData(config.input_file); - auto totalComb = grid.getNumCombinations(); - std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl; + std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; + auto grid = GridData(Paths::grid_input(config.model)); // Generate hyperparameters grid & run gridsearch // Check each combination of hyperparameters for each dataset and each seed for (const auto& dataset : datasets_names) { + auto totalComb = grid.getNumCombinations(dataset); if (!config.quiet) std::cout << "- " << setw(20) << left << dataset << " " << right << flush; int num = 0; double bestScore = 0.0; json bestHyperparameters; - auto combinations = grid.getGrid(); + auto combinations = grid.getGrid(dataset); for (const auto& hyperparam_line : combinations) { if (!config.quiet) showProgressComb(++num, totalComb, Colors::CYAN()); @@ -186,7 +179,7 @@ namespace platform { results[dataset]["score"] = bestScore; results[dataset]["hyperparameters"] = bestHyperparameters; results[dataset]["date"] = get_date() + " " + get_time(); - results[dataset]["grid"] = grid.getInputGrid(); + results[dataset]["grid"] = grid.getInputGrid(dataset); // Save partial results save(results); } @@ -196,7 +189,7 @@ namespace platform { } void GridSearch::save(json& results) const { - std::ofstream file(config.output_file); + std::ofstream file(Paths::grid_output(config.model)); file << results.dump(4); } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 61bc242..d46ffb3 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -1,7 +1,7 @@ #ifndef GRIDSEARCH_H #define GRIDSEARCH_H #include -#include +#include #include #include "Datasets.h" #include "HyperParameters.h" @@ -12,9 +12,6 @@ namespace platform { struct ConfigGrid { std::string model; std::string score; - std::string path; - std::string input_file; - std::string output_file; std::string continue_from; bool quiet; bool only; // used with continue_from to only compute that dataset @@ -28,7 +25,6 @@ namespace platform { explicit GridSearch(struct ConfigGrid& config); void go(); ~GridSearch() = default; - std::vector dump(); json getResults(); private: void save(json& results) const; diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index 3f5d135..6fd61cf 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -26,6 +26,14 @@ namespace platform { } } static std::string excelResults() { return "some_results.xlsx"; } + static std::string grid_input(const std::string& model) + { + return grid() + "grid_" + model + "_input.json"; + } + static std::string grid_output(const std::string& model) + { + return grid() + "grid_" + model + "_output.json"; + } }; } #endif \ No newline at end of file diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 19f1e2c..64452ed 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -1,5 +1,7 @@ #include #include +#include +#include #include "DotEnv.h" #include "Models.h" #include "modelRegister.h" @@ -8,6 +10,7 @@ #include "Timer.h" #include "Colors.h" +using json = nlohmann::json; void manageArguments(argparse::ArgumentParser& program) { @@ -50,6 +53,72 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); } +void list_dump(std::string& model) +{ + auto data = platform::GridData(platform::Paths::grid_input(model)); + std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl; + int index = 0; + int max_hyper = 15; + int max_dataset = 7; + auto combinations = data.getGridFile(); + for (auto const& item : combinations) { + if (item.first.size() > max_dataset) { + max_dataset = item.first.size(); + } + if (item.second.dump().size() > max_hyper) { + max_hyper = item.second.dump().size(); + } + } + std::cout << Colors::GREEN() << left << " # " << left << setw(max_dataset) << "Dataset" << " #Com. " + << setw(max_hyper) << "Hyperparameters" << std::endl; + std::cout << "=== " << string(max_dataset, '=') << " ===== " << string(max_hyper, '=') << std::endl; + bool odd = true; + for (auto const& item : combinations) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + std::cout << color; + auto num_combinations = data.getNumCombinations(item.first); + std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first + << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << item.second.dump() << std::endl; + odd = !odd; + } + std::cout << Colors::RESET() << std::endl; +} +void list_results(json& results, std::string& model) +{ + std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " + << model << std::endl << std::endl; + int spaces = 0; + int hyperparameters_spaces = 0; + for (const auto& item : results.items()) { + auto key = item.key(); + auto value = item.value(); + if (key.size() > spaces) { + spaces = key.size(); + } + if (value["hyperparameters"].dump().size() > hyperparameters_spaces) { + hyperparameters_spaces = value["hyperparameters"].dump().size(); + } + } + std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " " + << setw(8) << "Score" << " " << "Hyperparameters" << std::endl; + std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " " + << string(hyperparameters_spaces, '=') << std::endl; + bool odd = true; + int index = 0; + for (const auto& item : results.items()) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + auto key = item.key(); + auto value = item.value(); + std::cout << color; + std::cout << std::setw(3) << std::right << index++ << " "; + std::cout << left << setw(spaces) << key << " " << value["date"].get() + << " " << setw(8) << setprecision(6) << fixed << right + << value["score"].get() << " " << value["hyperparameters"].dump() << std::endl; + odd = !odd; + } + std::cout << Colors::RESET() << std::endl; +} + int main(int argc, char** argv) { argparse::ArgumentParser program("b_grid"); @@ -87,74 +156,22 @@ int main(int argc, char** argv) */ auto env = platform::DotEnv(); platform::Paths::createPath(platform::Paths::grid()); - config.path = platform::Paths::grid(); auto grid_search = platform::GridSearch(config); platform::Timer timer; timer.start(); if (dump) { - auto combinations = grid_search.dump(); - auto total = combinations.size(); - int spaces = int(log(total) / log(10)) + 1; - std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl; - int index = 0; - int max = 0; - for (auto const& item : combinations) { - if (item.dump().size() > spaces) { - max = item.dump().size(); - } - } - std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces) - << "Hyperparameters" << std::endl; - std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl; - bool odd = true; - for (auto const& item : combinations) { - auto color = odd ? Colors::CYAN() : Colors::BLUE(); - std::cout << color; - std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl; - odd = !odd; - } - std::cout << Colors::RESET() << std::endl; + list_dump(config.model); } else { if (compute) { grid_search.go(); std::cout << "Process took " << timer.getDurationString() << std::endl; } else { - std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " - << config.model << std::endl << std::endl; + // List results auto results = grid_search.getResults(); if (results.empty()) { - std::cout << "No results found" << std::endl; + std::cout << "** No results found" << std::endl; } else { - int spaces = 0; - int hyperparameters_spaces = 0; - for (const auto& item : results.items()) { - auto key = item.key(); - auto value = item.value(); - if (key.size() > spaces) { - spaces = key.size(); - } - if (value["hyperparameters"].dump().size() > hyperparameters_spaces) { - hyperparameters_spaces = value["hyperparameters"].dump().size(); - } - } - std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " " - << setw(8) << "Score" << " " << "Hyperparameters" << std::endl; - std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " " - << string(hyperparameters_spaces, '=') << std::endl; - bool odd = true; - int index = 0; - for (const auto& item : results.items()) { - auto color = odd ? Colors::CYAN() : Colors::BLUE(); - auto key = item.key(); - auto value = item.value(); - std::cout << color; - std::cout << std::setw(3) << std::right << index++ << " "; - std::cout << left << setw(spaces) << key << " " << value["date"].get() - << " " << setw(8) << setprecision(6) << fixed << right - << value["score"].get() << " " << value["hyperparameters"].dump() << std::endl; - odd = !odd; - } - std::cout << Colors::RESET() << std::endl; + list_results(results, config.model); } } }