Refactor grid input hyperparameter file

This commit is contained in:
Ricardo Montañana Gómez 2023-11-29 18:24:34 +01:00
parent e3f6dc1e0b
commit dee9c674da
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 120 additions and 87 deletions

View File

@ -4,12 +4,19 @@
namespace platform { namespace platform {
GridData::GridData(const std::string& fileName) GridData::GridData(const std::string& fileName)
{ {
json grid_file;
std::ifstream resultData(fileName); std::ifstream resultData(fileName);
if (resultData.is_open()) { if (resultData.is_open()) {
grid = json::parse(resultData); grid_file = json::parse(resultData);
} else { } else {
throw std::invalid_argument("Unable to open input file. [" + fileName + "]"); throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
} }
for (const auto& item : grid_file.items()) {
auto key = item.key();
auto value = item.value();
grid[key] = value;
}
} }
int GridData::computeNumCombinations(const json& line) int GridData::computeNumCombinations(const json& line)
{ {
@ -19,10 +26,11 @@ namespace platform {
} }
return numCombinations; return numCombinations;
} }
int GridData::getNumCombinations() int GridData::getNumCombinations(const std::string& dataset)
{ {
int numCombinations = 0; int numCombinations = 0;
for (const auto& line : grid) { auto selected = decide_dataset(dataset);
for (const auto& line : grid.at(selected)) {
numCombinations += computeNumCombinations(line); numCombinations += computeNumCombinations(line);
} }
return numCombinations; return numCombinations;
@ -44,16 +52,24 @@ namespace platform {
} }
return currentCombination; return currentCombination;
} }
std::vector<json> GridData::getGrid() std::vector<json> GridData::getGrid(const std::string& dataset)
{ {
auto selected = decide_dataset(dataset);
auto result = std::vector<json>(); auto result = std::vector<json>();
for (json line : grid) { for (json line : grid.at(selected)) {
generateCombinations(line.begin(), line.end(), result, json({})); generateCombinations(line.begin(), line.end(), result, json({}));
} }
return result; return result;
} }
json& GridData::getInputGrid() json& GridData::getInputGrid(const std::string& dataset)
{ {
return grid; auto selected = decide_dataset(dataset);
return grid.at(selected);
}
std::string GridData::decide_dataset(const std::string& dataset)
{
if (grid.find(dataset) != grid.end())
return dataset;
return ALL_DATASETS;
} }
} /* namespace platform */ } /* namespace platform */

View File

@ -7,17 +7,20 @@
namespace platform { namespace platform {
using json = nlohmann::json; using json = nlohmann::json;
const std::string ALL_DATASETS = "all";
class GridData { class GridData {
public: public:
explicit GridData(const std::string& fileName); explicit GridData(const std::string& fileName);
~GridData() = default; ~GridData() = default;
std::vector<json> getGrid(); std::vector<json> getGrid(const std::string& dataset = ALL_DATASETS);
int getNumCombinations(); int getNumCombinations(const std::string& dataset = ALL_DATASETS);
json& getInputGrid(); json& getInputGrid(const std::string& dataset = ALL_DATASETS);
std::map<std::string, json>& getGridFile() { return grid; }
private: private:
std::string decide_dataset(const std::string& dataset);
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination); json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
int computeNumCombinations(const json& line); int computeNumCombinations(const json& line);
json grid; std::map<std::string, json> grid;
}; };
} /* namespace platform */ } /* namespace platform */
#endif /* GRIDDATA_H */ #endif /* GRIDDATA_H */

View File

@ -29,16 +29,10 @@ namespace platform {
} }
GridSearch::GridSearch(struct ConfigGrid& config) : config(config) GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{ {
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
}
std::vector<json> GridSearch::dump()
{
return GridData(config.input_file).getGrid();
} }
json GridSearch::getResults() json GridSearch::getResults()
{ {
std::ifstream file(config.output_file); std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) { if (file.is_open()) {
return json::parse(file); return json::parse(file);
} }
@ -131,7 +125,7 @@ namespace platform {
if (!config.quiet) if (!config.quiet)
std::cout << "* Loading previous results" << std::endl; std::cout << "* Loading previous results" << std::endl;
try { try {
std::ifstream file(config.output_file); std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) { if (file.is_open()) {
results = json::parse(file); results = json::parse(file);
} }
@ -156,19 +150,18 @@ namespace platform {
} }
// Create model // Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl; std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "input file=" << config.input_file << std::endl; std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(config.input_file); auto grid = GridData(Paths::grid_input(config.model));
auto totalComb = grid.getNumCombinations();
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
// Generate hyperparameters grid & run gridsearch // Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed // Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) { for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet) if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush; std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0; int num = 0;
double bestScore = 0.0; double bestScore = 0.0;
json bestHyperparameters; json bestHyperparameters;
auto combinations = grid.getGrid(); auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) { for (const auto& hyperparam_line : combinations) {
if (!config.quiet) if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN()); showProgressComb(++num, totalComb, Colors::CYAN());
@ -186,7 +179,7 @@ namespace platform {
results[dataset]["score"] = bestScore; results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters; results[dataset]["hyperparameters"] = bestHyperparameters;
results[dataset]["date"] = get_date() + " " + get_time(); results[dataset]["date"] = get_date() + " " + get_time();
results[dataset]["grid"] = grid.getInputGrid(); results[dataset]["grid"] = grid.getInputGrid(dataset);
// Save partial results // Save partial results
save(results); save(results);
} }
@ -196,7 +189,7 @@ namespace platform {
} }
void GridSearch::save(json& results) const void GridSearch::save(json& results) const
{ {
std::ofstream file(config.output_file); std::ofstream file(Paths::grid_output(config.model));
file << results.dump(4); file << results.dump(4);
} }
} /* namespace platform */ } /* namespace platform */

View File

@ -1,7 +1,7 @@
#ifndef GRIDSEARCH_H #ifndef GRIDSEARCH_H
#define GRIDSEARCH_H #define GRIDSEARCH_H
#include <string> #include <string>
#include <vector> #include <map>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "Datasets.h" #include "Datasets.h"
#include "HyperParameters.h" #include "HyperParameters.h"
@ -12,9 +12,6 @@ namespace platform {
struct ConfigGrid { struct ConfigGrid {
std::string model; std::string model;
std::string score; std::string score;
std::string path;
std::string input_file;
std::string output_file;
std::string continue_from; std::string continue_from;
bool quiet; bool quiet;
bool only; // used with continue_from to only compute that dataset bool only; // used with continue_from to only compute that dataset
@ -28,7 +25,6 @@ namespace platform {
explicit GridSearch(struct ConfigGrid& config); explicit GridSearch(struct ConfigGrid& config);
void go(); void go();
~GridSearch() = default; ~GridSearch() = default;
std::vector<json> dump();
json getResults(); json getResults();
private: private:
void save(json& results) const; void save(json& results) const;

View File

@ -26,6 +26,14 @@ namespace platform {
} }
} }
static std::string excelResults() { return "some_results.xlsx"; } static std::string excelResults() { return "some_results.xlsx"; }
/// Name of the grid input file for `model`: the value returned by grid()
/// followed by "grid_<model>_input.json".
static std::string grid_input(const std::string& model)
{
    std::string file_name = grid();
    file_name += "grid_";
    file_name += model;
    file_name += "_input.json";
    return file_name;
}
/// Name of the grid output file for `model`: the value returned by grid()
/// followed by "grid_<model>_output.json".
static std::string grid_output(const std::string& model)
{
    std::string file_name = grid();
    file_name += "grid_";
    file_name += model;
    file_name += "_output.json";
    return file_name;
}
}; };
} }
#endif #endif

View File

@ -1,5 +1,7 @@
#include <iostream> #include <iostream>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <map>
#include <nlohmann/json.hpp>
#include "DotEnv.h" #include "DotEnv.h"
#include "Models.h" #include "Models.h"
#include "modelRegister.h" #include "modelRegister.h"
@ -8,6 +10,7 @@
#include "Timer.h" #include "Timer.h"
#include "Colors.h" #include "Colors.h"
using json = nlohmann::json;
void manageArguments(argparse::ArgumentParser& program) void manageArguments(argparse::ArgumentParser& program)
{ {
@ -50,6 +53,72 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
} }
/**
 * Prints the grid input file of a model as a table on std::cout.
 *
 * Loads the file named by platform::Paths::grid_input(model) into a
 * platform::GridData and writes one row per entry of its grid map: a running
 * index (starting at 1), the entry key, the number of combinations reported by
 * getNumCombinations() for that key, and the JSON dump of the entry.
 *
 * @param model model name used to locate the grid input file.
 * @note Any exception raised while constructing GridData is not caught here.
 */
void list_dump(std::string& model)
{
    auto data = platform::GridData(platform::Paths::grid_input(model));
    std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl;
    // Bind by reference: getGridFile() returns a reference, a plain `auto` would copy the whole map.
    const auto& combinations = data.getGridFile();
    // Minimum column widths are the lengths of the header labels
    // ("Dataset" = 7, "Hyperparameters" = 15). Kept as size_t so the
    // comparisons against size() do not mix signed and unsigned values.
    std::size_t max_dataset = 7;
    std::size_t max_hyper = 15;
    for (const auto& item : combinations) {
        if (item.first.size() > max_dataset) {
            max_dataset = item.first.size();
        }
        // Serialize once per entry instead of once for the test and once for the assignment.
        const auto hyper_size = item.second.dump().size();
        if (hyper_size > max_hyper) {
            max_hyper = hyper_size;
        }
    }
    // std::setw takes an int.
    const int dataset_width = static_cast<int>(max_dataset);
    const int hyper_width = static_cast<int>(max_hyper);
    std::cout << Colors::GREEN() << std::left << " # " << std::left << std::setw(dataset_width) << "Dataset" << " #Com. "
        << std::setw(hyper_width) << "Hyperparameters" << std::endl;
    std::cout << "=== " << std::string(max_dataset, '=') << " ===== " << std::string(max_hyper, '=') << std::endl;
    int index = 0;
    bool odd = true;
    for (const auto& item : combinations) {
        // Alternate row colors to make the table easier to read.
        auto color = odd ? Colors::CYAN() : Colors::BLUE();
        std::cout << color;
        auto num_combinations = data.getNumCombinations(item.first);
        std::cout << std::setw(3) << std::fixed << std::right << ++index << std::left << " " << std::setw(dataset_width) << item.first
            << " " << std::setw(5) << std::right << num_combinations << " " << std::setw(hyper_width) << item.second.dump() << std::endl;
        odd = !odd;
    }
    std::cout << Colors::RESET() << std::endl;
}
/**
 * Prints the stored grid search results of a model as a table on std::cout.
 *
 * Iterates over the entries of `results` and writes one row per entry: a
 * running index (starting at 0), the entry key, and the entry's "date",
 * "score" and "hyperparameters" members.
 *
 * @param results JSON object whose entries are read through the keys
 *                "date" (read as string), "score" (read as double) and
 *                "hyperparameters" (dumped as JSON text).
 * @param model   model name shown in the title line.
 */
void list_results(json& results, std::string& model)
{
    std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
        << model << std::endl << std::endl;
    // size_t widths keep the comparisons against size() free of signed/unsigned mixing.
    std::size_t spaces = 0;
    std::size_t hyperparameters_spaces = 0;
    for (const auto& item : results.items()) {
        const auto& key = item.key();
        // Deliberate copy: operator[] on a non-const json inserts a null member
        // when the key is missing, which must not leak into the caller's `results`.
        auto value = item.value();
        if (key.size() > spaces) {
            spaces = key.size();
        }
        // Serialize once per entry instead of once for the test and once for the assignment.
        const auto hyper_size = value["hyperparameters"].dump().size();
        if (hyper_size > hyperparameters_spaces) {
            hyperparameters_spaces = hyper_size;
        }
    }
    // std::setw takes an int.
    const int key_width = static_cast<int>(spaces);
    std::cout << Colors::GREEN() << " # " << std::left << std::setw(key_width) << "Dataset" << " " << std::setw(19) << "Date" << " "
        << std::setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
    std::cout << "=== " << std::string(spaces, '=') << " " << std::string(19, '=') << " " << std::string(8, '=') << " "
        << std::string(hyperparameters_spaces, '=') << std::endl;
    bool odd = true;
    int index = 0;
    for (const auto& item : results.items()) {
        // Alternate row colors to make the table easier to read.
        auto color = odd ? Colors::CYAN() : Colors::BLUE();
        const auto& key = item.key();
        auto value = item.value();  // copy for the same reason as in the width scan above
        std::cout << color;
        std::cout << std::setw(3) << std::right << index++ << " ";
        std::cout << std::left << std::setw(key_width) << key << " " << value["date"].get<std::string>()
            << " " << std::setw(8) << std::setprecision(6) << std::fixed << std::right
            << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
        odd = !odd;
    }
    std::cout << Colors::RESET() << std::endl;
}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
argparse::ArgumentParser program("b_grid"); argparse::ArgumentParser program("b_grid");
@ -87,74 +156,22 @@ int main(int argc, char** argv)
*/ */
auto env = platform::DotEnv(); auto env = platform::DotEnv();
platform::Paths::createPath(platform::Paths::grid()); platform::Paths::createPath(platform::Paths::grid());
config.path = platform::Paths::grid();
auto grid_search = platform::GridSearch(config); auto grid_search = platform::GridSearch(config);
platform::Timer timer; platform::Timer timer;
timer.start(); timer.start();
if (dump) { if (dump) {
auto combinations = grid_search.dump(); list_dump(config.model);
auto total = combinations.size();
int spaces = int(log(total) / log(10)) + 1;
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
int index = 0;
int max = 0;
for (auto const& item : combinations) {
if (item.dump().size() > spaces) {
max = item.dump().size();
}
}
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
<< "Hyperparameters" << std::endl;
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
bool odd = true;
for (auto const& item : combinations) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
std::cout << color;
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
} else { } else {
if (compute) { if (compute) {
grid_search.go(); grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl; std::cout << "Process took " << timer.getDurationString() << std::endl;
} else { } else {
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " // List results
<< config.model << std::endl << std::endl;
auto results = grid_search.getResults(); auto results = grid_search.getResults();
if (results.empty()) { if (results.empty()) {
std::cout << "No results found" << std::endl; std::cout << "** No results found" << std::endl;
} else { } else {
int spaces = 0; list_results(results, config.model);
int hyperparameters_spaces = 0;
for (const auto& item : results.items()) {
auto key = item.key();
auto value = item.value();
if (key.size() > spaces) {
spaces = key.size();
}
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
hyperparameters_spaces = value["hyperparameters"].dump().size();
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results.items()) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
auto key = item.key();
auto value = item.value();
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
} }
} }
} }