Refactor grid input hyperparameter file
This commit is contained in:
parent
e3f6dc1e0b
commit
dee9c674da
@ -4,12 +4,19 @@
|
|||||||
namespace platform {
|
namespace platform {
|
||||||
/**
 * Load a grid definition file.
 *
 * The file is parsed as JSON and every top-level member is stored in `grid`
 * under its key.
 *
 * @param fileName path of the JSON grid file to read.
 * @throws std::invalid_argument if the file cannot be opened.
 */
GridData::GridData(const std::string& fileName)
{
    std::ifstream input(fileName);
    if (!input.is_open()) {
        throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
    }
    const json parsed = json::parse(input);
    // One map entry per top-level key of the file.
    for (const auto& entry : parsed.items()) {
        grid[entry.key()] = entry.value();
    }
}
|
||||||
int GridData::computeNumCombinations(const json& line)
|
int GridData::computeNumCombinations(const json& line)
|
||||||
{
|
{
|
||||||
@ -19,10 +26,11 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return numCombinations;
|
return numCombinations;
|
||||||
}
|
}
|
||||||
int GridData::getNumCombinations()
|
int GridData::getNumCombinations(const std::string& dataset)
|
||||||
{
|
{
|
||||||
int numCombinations = 0;
|
int numCombinations = 0;
|
||||||
for (const auto& line : grid) {
|
auto selected = decide_dataset(dataset);
|
||||||
|
for (const auto& line : grid.at(selected)) {
|
||||||
numCombinations += computeNumCombinations(line);
|
numCombinations += computeNumCombinations(line);
|
||||||
}
|
}
|
||||||
return numCombinations;
|
return numCombinations;
|
||||||
@ -44,16 +52,24 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return currentCombination;
|
return currentCombination;
|
||||||
}
|
}
|
||||||
std::vector<json> GridData::getGrid()
|
std::vector<json> GridData::getGrid(const std::string& dataset)
|
||||||
{
|
{
|
||||||
|
auto selected = decide_dataset(dataset);
|
||||||
auto result = std::vector<json>();
|
auto result = std::vector<json>();
|
||||||
for (json line : grid) {
|
for (json line : grid.at(selected)) {
|
||||||
generateCombinations(line.begin(), line.end(), result, json({}));
|
generateCombinations(line.begin(), line.end(), result, json({}));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
/**
 * Return the raw grid stored for a dataset.
 *
 * @param dataset dataset name; the entry actually used is chosen by decide_dataset().
 * @return reference to the stored grid entry.
 * @throws std::out_of_range (from std::map::at) if the selected entry is not in the grid.
 */
json& GridData::getInputGrid(const std::string& dataset)
{
    return grid.at(decide_dataset(dataset));
}
|
||||||
|
std::string GridData::decide_dataset(const std::string& dataset)
|
||||||
|
{
|
||||||
|
if (grid.find(dataset) != grid.end())
|
||||||
|
return dataset;
|
||||||
|
return ALL_DATASETS;
|
||||||
}
|
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -7,17 +7,20 @@
|
|||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
const std::string ALL_DATASETS = "all";
|
||||||
class GridData {
|
class GridData {
|
||||||
public:
|
public:
|
||||||
explicit GridData(const std::string& fileName);
|
explicit GridData(const std::string& fileName);
|
||||||
~GridData() = default;
|
~GridData() = default;
|
||||||
std::vector<json> getGrid();
|
std::vector<json> getGrid(const std::string& dataset = ALL_DATASETS);
|
||||||
int getNumCombinations();
|
int getNumCombinations(const std::string& dataset = ALL_DATASETS);
|
||||||
json& getInputGrid();
|
json& getInputGrid(const std::string& dataset = ALL_DATASETS);
|
||||||
|
std::map<std::string, json>& getGridFile() { return grid; }
|
||||||
private:
|
private:
|
||||||
|
std::string decide_dataset(const std::string& dataset);
|
||||||
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
||||||
int computeNumCombinations(const json& line);
|
int computeNumCombinations(const json& line);
|
||||||
json grid;
|
std::map<std::string, json> grid;
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
#endif /* GRIDDATA_H */
|
#endif /* GRIDDATA_H */
|
@ -29,16 +29,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
/**
 * Create a grid search bound to the given configuration.
 *
 * Only initialises the `config` member from the argument; the grid input and
 * output file names are obtained from Paths where they are needed.
 *
 * @param config grid search configuration.
 */
GridSearch::GridSearch(struct ConfigGrid& config)
    : config(config)
{
}
|
||||||
json GridSearch::getResults()
|
json GridSearch::getResults()
|
||||||
{
|
{
|
||||||
std::ifstream file(config.output_file);
|
std::ifstream file(Paths::grid_output(config.model));
|
||||||
if (file.is_open()) {
|
if (file.is_open()) {
|
||||||
return json::parse(file);
|
return json::parse(file);
|
||||||
}
|
}
|
||||||
@ -131,7 +125,7 @@ namespace platform {
|
|||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
std::cout << "* Loading previous results" << std::endl;
|
std::cout << "* Loading previous results" << std::endl;
|
||||||
try {
|
try {
|
||||||
std::ifstream file(config.output_file);
|
std::ifstream file(Paths::grid_output(config.model));
|
||||||
if (file.is_open()) {
|
if (file.is_open()) {
|
||||||
results = json::parse(file);
|
results = json::parse(file);
|
||||||
}
|
}
|
||||||
@ -156,19 +150,18 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
// Create model
|
// Create model
|
||||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||||
std::cout << "input file=" << config.input_file << std::endl;
|
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||||
auto grid = GridData(config.input_file);
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
auto totalComb = grid.getNumCombinations();
|
|
||||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
|
||||||
// Generate hyperparameters grid & run gridsearch
|
// Generate hyperparameters grid & run gridsearch
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
for (const auto& dataset : datasets_names) {
|
for (const auto& dataset : datasets_names) {
|
||||||
|
auto totalComb = grid.getNumCombinations(dataset);
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
int num = 0;
|
int num = 0;
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
json bestHyperparameters;
|
json bestHyperparameters;
|
||||||
auto combinations = grid.getGrid();
|
auto combinations = grid.getGrid(dataset);
|
||||||
for (const auto& hyperparam_line : combinations) {
|
for (const auto& hyperparam_line : combinations) {
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
@ -186,7 +179,7 @@ namespace platform {
|
|||||||
results[dataset]["score"] = bestScore;
|
results[dataset]["score"] = bestScore;
|
||||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
results[dataset]["date"] = get_date() + " " + get_time();
|
results[dataset]["date"] = get_date() + " " + get_time();
|
||||||
results[dataset]["grid"] = grid.getInputGrid();
|
results[dataset]["grid"] = grid.getInputGrid(dataset);
|
||||||
// Save partial results
|
// Save partial results
|
||||||
save(results);
|
save(results);
|
||||||
}
|
}
|
||||||
@ -196,7 +189,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
/**
 * Write the results to the grid output file of the configured model.
 *
 * The file name comes from Paths::grid_output(config.model) and the JSON is
 * written with an indentation of 4.
 *
 * @param results JSON document to serialise.
 */
void GridSearch::save(json& results) const
{
    const auto output_name = Paths::grid_output(config.model);
    std::ofstream output(output_name);
    output << results.dump(4);
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -1,7 +1,7 @@
|
|||||||
#ifndef GRIDSEARCH_H
|
#ifndef GRIDSEARCH_H
|
||||||
#define GRIDSEARCH_H
|
#define GRIDSEARCH_H
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <map>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "HyperParameters.h"
|
#include "HyperParameters.h"
|
||||||
@ -12,9 +12,6 @@ namespace platform {
|
|||||||
struct ConfigGrid {
|
struct ConfigGrid {
|
||||||
std::string model;
|
std::string model;
|
||||||
std::string score;
|
std::string score;
|
||||||
std::string path;
|
|
||||||
std::string input_file;
|
|
||||||
std::string output_file;
|
|
||||||
std::string continue_from;
|
std::string continue_from;
|
||||||
bool quiet;
|
bool quiet;
|
||||||
bool only; // used with continue_from to only compute that dataset
|
bool only; // used with continue_from to only compute that dataset
|
||||||
@ -28,7 +25,6 @@ namespace platform {
|
|||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void go();
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
std::vector<json> dump();
|
|
||||||
json getResults();
|
json getResults();
|
||||||
private:
|
private:
|
||||||
void save(json& results) const;
|
void save(json& results) const;
|
||||||
|
@ -26,6 +26,14 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// File name used for the Excel results workbook.
static std::string excelResults()
{
    return std::string("some_results.xlsx");
}
|
||||||
|
static std::string grid_input(const std::string& model)
|
||||||
|
{
|
||||||
|
return grid() + "grid_" + model + "_input.json";
|
||||||
|
}
|
||||||
|
static std::string grid_output(const std::string& model)
|
||||||
|
{
|
||||||
|
return grid() + "grid_" + model + "_output.json";
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -1,5 +1,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <argparse/argparse.hpp>
|
#include <argparse/argparse.hpp>
|
||||||
|
#include <map>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
#include "DotEnv.h"
|
#include "DotEnv.h"
|
||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
#include "modelRegister.h"
|
#include "modelRegister.h"
|
||||||
@ -8,6 +10,7 @@
|
|||||||
#include "Timer.h"
|
#include "Timer.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
void manageArguments(argparse::ArgumentParser& program)
|
void manageArguments(argparse::ArgumentParser& program)
|
||||||
{
|
{
|
||||||
@ -50,6 +53,72 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void list_dump(std::string& model)
|
||||||
|
{
|
||||||
|
auto data = platform::GridData(platform::Paths::grid_input(model));
|
||||||
|
std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl;
|
||||||
|
int index = 0;
|
||||||
|
int max_hyper = 15;
|
||||||
|
int max_dataset = 7;
|
||||||
|
auto combinations = data.getGridFile();
|
||||||
|
for (auto const& item : combinations) {
|
||||||
|
if (item.first.size() > max_dataset) {
|
||||||
|
max_dataset = item.first.size();
|
||||||
|
}
|
||||||
|
if (item.second.dump().size() > max_hyper) {
|
||||||
|
max_hyper = item.second.dump().size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << Colors::GREEN() << left << " # " << left << setw(max_dataset) << "Dataset" << " #Com. "
|
||||||
|
<< setw(max_hyper) << "Hyperparameters" << std::endl;
|
||||||
|
std::cout << "=== " << string(max_dataset, '=') << " ===== " << string(max_hyper, '=') << std::endl;
|
||||||
|
bool odd = true;
|
||||||
|
for (auto const& item : combinations) {
|
||||||
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
|
std::cout << color;
|
||||||
|
auto num_combinations = data.getNumCombinations(item.first);
|
||||||
|
std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first
|
||||||
|
<< " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << item.second.dump() << std::endl;
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
std::cout << Colors::RESET() << std::endl;
|
||||||
|
}
|
||||||
|
/**
 * Print the stored grid search results of a model as a table.
 *
 * One row is printed per member of `results`: a running number (starting at 0),
 * the member key, and its "date", "score" and "hyperparameters" values.
 *
 * @param results JSON object with one member per dataset; each member is read
 *                through its "date" (string), "score" (number) and
 *                "hyperparameters" keys.
 * @param model   model name, only used in the title line.
 * @throws nlohmann::json exceptions if "date" or "score" do not hold the expected type.
 */
void list_results(json& results, std::string& model)
{
    std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
        << model << std::endl << std::endl;
    int spaces = 0;
    int hyperparameters_spaces = 0;
    for (const auto& item : results.items()) {
        // Work on a copy: operator[] on a non-const json inserts a null member for a
        // missing key, and that must not leak into the caller's results.
        auto value = item.value();
        const int key_width = static_cast<int>(item.key().size());
        if (key_width > spaces) {
            spaces = key_width;
        }
        // Serialise once per entry just to measure the column width.
        const int hyper_width = static_cast<int>(value["hyperparameters"].dump().size());
        if (hyper_width > hyperparameters_spaces) {
            hyperparameters_spaces = hyper_width;
        }
    }
    std::cout << Colors::GREEN() << " # " << std::left << std::setw(spaces) << "Dataset" << " " << std::setw(19) << "Date" << " "
        << std::setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
    std::cout << "=== " << std::string(spaces, '=') << " " << std::string(19, '=') << " " << std::string(8, '=') << " "
        << std::string(hyperparameters_spaces, '=') << std::endl;
    bool odd = true;
    int index = 0;
    for (const auto& item : results.items()) {
        // Alternate row colours to make the table easier to read.
        auto color = odd ? Colors::CYAN() : Colors::BLUE();
        const auto& key = item.key();
        auto value = item.value(); // copy, see the note in the first loop
        std::cout << color;
        std::cout << std::setw(3) << std::right << index++ << " ";
        std::cout << std::left << std::setw(spaces) << key << " " << value["date"].get<std::string>()
            << " " << std::setw(8) << std::setprecision(6) << std::fixed << std::right
            << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
        odd = !odd;
    }
    std::cout << Colors::RESET() << std::endl;
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("b_grid");
|
argparse::ArgumentParser program("b_grid");
|
||||||
@ -87,74 +156,22 @@ int main(int argc, char** argv)
|
|||||||
*/
|
*/
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
platform::Paths::createPath(platform::Paths::grid());
|
platform::Paths::createPath(platform::Paths::grid());
|
||||||
config.path = platform::Paths::grid();
|
|
||||||
auto grid_search = platform::GridSearch(config);
|
auto grid_search = platform::GridSearch(config);
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
timer.start();
|
timer.start();
|
||||||
if (dump) {
|
if (dump) {
|
||||||
auto combinations = grid_search.dump();
|
list_dump(config.model);
|
||||||
auto total = combinations.size();
|
|
||||||
int spaces = int(log(total) / log(10)) + 1;
|
|
||||||
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
|
|
||||||
int index = 0;
|
|
||||||
int max = 0;
|
|
||||||
for (auto const& item : combinations) {
|
|
||||||
if (item.dump().size() > spaces) {
|
|
||||||
max = item.dump().size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
|
|
||||||
<< "Hyperparameters" << std::endl;
|
|
||||||
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
|
|
||||||
bool odd = true;
|
|
||||||
for (auto const& item : combinations) {
|
|
||||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
|
||||||
std::cout << color;
|
|
||||||
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
|
|
||||||
odd = !odd;
|
|
||||||
}
|
|
||||||
std::cout << Colors::RESET() << std::endl;
|
|
||||||
} else {
|
} else {
|
||||||
if (compute) {
|
if (compute) {
|
||||||
grid_search.go();
|
grid_search.go();
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
} else {
|
} else {
|
||||||
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
|
// List results
|
||||||
<< config.model << std::endl << std::endl;
|
|
||||||
auto results = grid_search.getResults();
|
auto results = grid_search.getResults();
|
||||||
if (results.empty()) {
|
if (results.empty()) {
|
||||||
std::cout << "No results found" << std::endl;
|
std::cout << "** No results found" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
int spaces = 0;
|
list_results(results, config.model);
|
||||||
int hyperparameters_spaces = 0;
|
|
||||||
for (const auto& item : results.items()) {
|
|
||||||
auto key = item.key();
|
|
||||||
auto value = item.value();
|
|
||||||
if (key.size() > spaces) {
|
|
||||||
spaces = key.size();
|
|
||||||
}
|
|
||||||
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
|
|
||||||
hyperparameters_spaces = value["hyperparameters"].dump().size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
|
||||||
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
|
||||||
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
|
||||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
|
||||||
bool odd = true;
|
|
||||||
int index = 0;
|
|
||||||
for (const auto& item : results.items()) {
|
|
||||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
|
||||||
auto key = item.key();
|
|
||||||
auto value = item.value();
|
|
||||||
std::cout << color;
|
|
||||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
|
||||||
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
|
||||||
<< " " << setw(8) << setprecision(6) << fixed << right
|
|
||||||
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
|
||||||
odd = !odd;
|
|
||||||
}
|
|
||||||
std::cout << Colors::RESET() << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user