Refactor grid input hyperparameter file
This commit is contained in:
@@ -29,16 +29,10 @@ namespace platform {
|
||||
}
|
||||
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||
{
|
||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||
}
|
||||
std::vector<json> GridSearch::dump()
|
||||
{
|
||||
return GridData(config.input_file).getGrid();
|
||||
}
|
||||
json GridSearch::getResults()
|
||||
{
|
||||
std::ifstream file(config.output_file);
|
||||
std::ifstream file(Paths::grid_output(config.model));
|
||||
if (file.is_open()) {
|
||||
return json::parse(file);
|
||||
}
|
||||
@@ -131,7 +125,7 @@ namespace platform {
|
||||
if (!config.quiet)
|
||||
std::cout << "* Loading previous results" << std::endl;
|
||||
try {
|
||||
std::ifstream file(config.output_file);
|
||||
std::ifstream file(Paths::grid_output(config.model));
|
||||
if (file.is_open()) {
|
||||
results = json::parse(file);
|
||||
}
|
||||
@@ -156,19 +150,18 @@ namespace platform {
|
||||
}
|
||||
// Create model
|
||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||
std::cout << "input file=" << config.input_file << std::endl;
|
||||
auto grid = GridData(config.input_file);
|
||||
auto totalComb = grid.getNumCombinations();
|
||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||
auto grid = GridData(Paths::grid_input(config.model));
|
||||
// Generate hyperparameters grid & run gridsearch
|
||||
// Check each combination of hyperparameters for each dataset and each seed
|
||||
for (const auto& dataset : datasets_names) {
|
||||
auto totalComb = grid.getNumCombinations(dataset);
|
||||
if (!config.quiet)
|
||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||
int num = 0;
|
||||
double bestScore = 0.0;
|
||||
json bestHyperparameters;
|
||||
auto combinations = grid.getGrid();
|
||||
auto combinations = grid.getGrid(dataset);
|
||||
for (const auto& hyperparam_line : combinations) {
|
||||
if (!config.quiet)
|
||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||
@@ -186,7 +179,7 @@ namespace platform {
|
||||
results[dataset]["score"] = bestScore;
|
||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||
results[dataset]["date"] = get_date() + " " + get_time();
|
||||
results[dataset]["grid"] = grid.getInputGrid();
|
||||
results[dataset]["grid"] = grid.getInputGrid(dataset);
|
||||
// Save partial results
|
||||
save(results);
|
||||
}
|
||||
@@ -196,7 +189,7 @@ namespace platform {
|
||||
}
|
||||
void GridSearch::save(json& results) const
|
||||
{
|
||||
std::ofstream file(config.output_file);
|
||||
std::ofstream file(Paths::grid_output(config.model));
|
||||
file << results.dump(4);
|
||||
}
|
||||
} /* namespace platform */
|
Reference in New Issue
Block a user