Add grid input info to grid output

This commit is contained in:
Ricardo Montañana Gómez 2023-11-26 16:07:32 +01:00
parent 7c12dd25e5
commit 4fefe9a1d2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 10 additions and 1 deletions

View File

@ -52,4 +52,8 @@ namespace platform {
}
return result;
}
json& GridData::getInputGrid()
{
return grid;
}
} /* namespace platform */

View File

@ -13,6 +13,7 @@ namespace platform {
~GridData() = default;
std::vector<json> getGrid();
int getNumCombinations();
json& getInputGrid();
private:
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
int computeNumCombinations(const json& line);

View File

@ -109,6 +109,8 @@ namespace platform {
json results;
auto datasets_names = datasets.getNames();
if (config.continue_from != "No") {
// Continue previous execution:
// Load previous results & remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
}
@ -146,7 +148,8 @@ namespace platform {
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid()) {
auto combinations = grid.getGrid();
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
@ -163,6 +166,7 @@ namespace platform {
results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters;
results[dataset]["date"] = get_date() + " " + get_time();
results[dataset]["grid"] = grid.getInputGrid();
// Save partial results
save(results);
}