Refactor grid input hyperparameter file

This commit is contained in:
2023-11-29 18:24:34 +01:00
parent e3f6dc1e0b
commit dee9c674da
6 changed files with 120 additions and 87 deletions

View File

@@ -29,16 +29,10 @@ namespace platform {
}
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
}
std::vector<json> GridSearch::dump()
{
return GridData(config.input_file).getGrid();
}
json GridSearch::getResults()
{
std::ifstream file(config.output_file);
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
return json::parse(file);
}
@@ -131,7 +125,7 @@ namespace platform {
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(config.output_file);
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
@@ -156,19 +150,18 @@ namespace platform {
}
// Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "input file=" << config.input_file << std::endl;
auto grid = GridData(config.input_file);
auto totalComb = grid.getNumCombinations();
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid();
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
@@ -186,7 +179,7 @@ namespace platform {
results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters;
results[dataset]["date"] = get_date() + " " + get_time();
results[dataset]["grid"] = grid.getInputGrid();
results[dataset]["grid"] = grid.getInputGrid(dataset);
// Save partial results
save(results);
}
@@ -196,7 +189,7 @@ namespace platform {
}
void GridSearch::save(json& results) const
{
std::ofstream file(config.output_file);
std::ofstream file(Paths::grid_output(config.model));
file << results.dump(4);
}
} /* namespace platform */