Refactor grid input hyperparameter file

This commit is contained in:
Ricardo Montañana Gómez 2023-11-29 18:24:34 +01:00
parent e3f6dc1e0b
commit dee9c674da
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 120 additions and 87 deletions

View File

@ -4,12 +4,19 @@
namespace platform {
GridData::GridData(const std::string& fileName)
{
json grid_file;
std::ifstream resultData(fileName);
if (resultData.is_open()) {
grid = json::parse(resultData);
grid_file = json::parse(resultData);
} else {
throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
}
for (const auto& item : grid_file.items()) {
auto key = item.key();
auto value = item.value();
grid[key] = value;
}
}
int GridData::computeNumCombinations(const json& line)
{
@ -19,10 +26,11 @@ namespace platform {
}
return numCombinations;
}
int GridData::getNumCombinations()
int GridData::getNumCombinations(const std::string& dataset)
{
int numCombinations = 0;
for (const auto& line : grid) {
auto selected = decide_dataset(dataset);
for (const auto& line : grid.at(selected)) {
numCombinations += computeNumCombinations(line);
}
return numCombinations;
@ -44,16 +52,24 @@ namespace platform {
}
return currentCombination;
}
std::vector<json> GridData::getGrid()
std::vector<json> GridData::getGrid(const std::string& dataset)
{
auto selected = decide_dataset(dataset);
auto result = std::vector<json>();
for (json line : grid) {
for (json line : grid.at(selected)) {
generateCombinations(line.begin(), line.end(), result, json({}));
}
return result;
}
json& GridData::getInputGrid()
json& GridData::getInputGrid(const std::string& dataset)
{
return grid;
auto selected = decide_dataset(dataset);
return grid.at(selected);
}
std::string GridData::decide_dataset(const std::string& dataset)
{
if (grid.find(dataset) != grid.end())
return dataset;
return ALL_DATASETS;
}
} /* namespace platform */

View File

@ -7,17 +7,20 @@
namespace platform {
using json = nlohmann::json;
const std::string ALL_DATASETS = "all";
class GridData {
public:
explicit GridData(const std::string& fileName);
~GridData() = default;
std::vector<json> getGrid();
int getNumCombinations();
json& getInputGrid();
std::vector<json> getGrid(const std::string& dataset = ALL_DATASETS);
int getNumCombinations(const std::string& dataset = ALL_DATASETS);
json& getInputGrid(const std::string& dataset = ALL_DATASETS);
std::map<std::string, json>& getGridFile() { return grid; }
private:
std::string decide_dataset(const std::string& dataset);
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
int computeNumCombinations(const json& line);
json grid;
std::map<std::string, json> grid;
};
} /* namespace platform */
#endif /* GRIDDATA_H */

View File

@ -29,16 +29,10 @@ namespace platform {
}
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
}
std::vector<json> GridSearch::dump()
{
return GridData(config.input_file).getGrid();
}
json GridSearch::getResults()
{
std::ifstream file(config.output_file);
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
return json::parse(file);
}
@ -131,7 +125,7 @@ namespace platform {
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(config.output_file);
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
@ -156,19 +150,18 @@ namespace platform {
}
// Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "input file=" << config.input_file << std::endl;
auto grid = GridData(config.input_file);
auto totalComb = grid.getNumCombinations();
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid();
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
@ -186,7 +179,7 @@ namespace platform {
results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters;
results[dataset]["date"] = get_date() + " " + get_time();
results[dataset]["grid"] = grid.getInputGrid();
results[dataset]["grid"] = grid.getInputGrid(dataset);
// Save partial results
save(results);
}
@ -196,7 +189,7 @@ namespace platform {
}
void GridSearch::save(json& results) const
{
std::ofstream file(config.output_file);
std::ofstream file(Paths::grid_output(config.model));
file << results.dump(4);
}
} /* namespace platform */

View File

@ -1,7 +1,7 @@
#ifndef GRIDSEARCH_H
#define GRIDSEARCH_H
#include <string>
#include <vector>
#include <map>
#include <nlohmann/json.hpp>
#include "Datasets.h"
#include "HyperParameters.h"
@ -12,9 +12,6 @@ namespace platform {
struct ConfigGrid {
std::string model;
std::string score;
std::string path;
std::string input_file;
std::string output_file;
std::string continue_from;
bool quiet;
bool only; // used with continue_from to only compute that dataset
@ -28,7 +25,6 @@ namespace platform {
explicit GridSearch(struct ConfigGrid& config);
void go();
~GridSearch() = default;
std::vector<json> dump();
json getResults();
private:
void save(json& results) const;

View File

@ -26,6 +26,14 @@ namespace platform {
}
}
static std::string excelResults() { return "some_results.xlsx"; }
static std::string grid_input(const std::string& model)
{
return grid() + "grid_" + model + "_input.json";
}
static std::string grid_output(const std::string& model)
{
return grid() + "grid_" + model + "_output.json";
}
};
}
#endif

View File

@ -1,5 +1,7 @@
#include <iostream>
#include <argparse/argparse.hpp>
#include <map>
#include <nlohmann/json.hpp>
#include "DotEnv.h"
#include "Models.h"
#include "modelRegister.h"
@ -8,6 +10,7 @@
#include "Timer.h"
#include "Colors.h"
using json = nlohmann::json;
void manageArguments(argparse::ArgumentParser& program)
{
@ -50,6 +53,72 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
}
void list_dump(std::string& model)
{
auto data = platform::GridData(platform::Paths::grid_input(model));
std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl;
int index = 0;
int max_hyper = 15;
int max_dataset = 7;
auto combinations = data.getGridFile();
for (auto const& item : combinations) {
if (item.first.size() > max_dataset) {
max_dataset = item.first.size();
}
if (item.second.dump().size() > max_hyper) {
max_hyper = item.second.dump().size();
}
}
std::cout << Colors::GREEN() << left << " # " << left << setw(max_dataset) << "Dataset" << " #Com. "
<< setw(max_hyper) << "Hyperparameters" << std::endl;
std::cout << "=== " << string(max_dataset, '=') << " ===== " << string(max_hyper, '=') << std::endl;
bool odd = true;
for (auto const& item : combinations) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
std::cout << color;
auto num_combinations = data.getNumCombinations(item.first);
std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first
<< " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << item.second.dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
}
void list_results(json& results, std::string& model)
{
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
<< model << std::endl << std::endl;
int spaces = 0;
int hyperparameters_spaces = 0;
for (const auto& item : results.items()) {
auto key = item.key();
auto value = item.value();
if (key.size() > spaces) {
spaces = key.size();
}
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
hyperparameters_spaces = value["hyperparameters"].dump().size();
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results.items()) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
auto key = item.key();
auto value = item.value();
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
}
int main(int argc, char** argv)
{
argparse::ArgumentParser program("b_grid");
@ -87,74 +156,22 @@ int main(int argc, char** argv)
*/
auto env = platform::DotEnv();
platform::Paths::createPath(platform::Paths::grid());
config.path = platform::Paths::grid();
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
timer.start();
if (dump) {
auto combinations = grid_search.dump();
auto total = combinations.size();
int spaces = int(log(total) / log(10)) + 1;
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
int index = 0;
int max = 0;
for (auto const& item : combinations) {
if (item.dump().size() > spaces) {
max = item.dump().size();
}
}
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
<< "Hyperparameters" << std::endl;
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
bool odd = true;
for (auto const& item : combinations) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
std::cout << color;
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
list_dump(config.model);
} else {
if (compute) {
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
} else {
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
<< config.model << std::endl << std::endl;
// List results
auto results = grid_search.getResults();
if (results.empty()) {
std::cout << "No results found" << std::endl;
std::cout << "** No results found" << std::endl;
} else {
int spaces = 0;
int hyperparameters_spaces = 0;
for (const auto& item : results.items()) {
auto key = item.key();
auto value = item.value();
if (key.size() > spaces) {
spaces = key.size();
}
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
hyperparameters_spaces = value["hyperparameters"].dump().size();
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results.items()) {
auto color = odd ? Colors::CYAN() : Colors::BLUE();
auto key = item.key();
auto value = item.value();
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
list_results(results, config.model);
}
}
}