From 4fefe9a1d2d83670fbfa020d64bb86c9009ebb15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 26 Nov 2023 16:07:32 +0100 Subject: [PATCH] Add grid input info to grid output --- src/Platform/GridData.cc | 4 ++++ src/Platform/GridData.h | 1 + src/Platform/GridSearch.cc | 6 +++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Platform/GridData.cc b/src/Platform/GridData.cc index 0150ff3..b2a8149 100644 --- a/src/Platform/GridData.cc +++ b/src/Platform/GridData.cc @@ -52,4 +52,8 @@ namespace platform { } return result; } + json& GridData::getInputGrid() + { + return grid; + } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridData.h b/src/Platform/GridData.h index b68a54a..3a03f03 100644 --- a/src/Platform/GridData.h +++ b/src/Platform/GridData.h @@ -13,6 +13,7 @@ namespace platform { ~GridData() = default; std::vector getGrid(); int getNumCombinations(); + json& getInputGrid(); private: json generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination); int computeNumCombinations(const json& line); diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 6816d5d..2546131 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -109,6 +109,8 @@ namespace platform { json results; auto datasets_names = datasets.getNames(); if (config.continue_from != "No") { + // Continue previous execution: + // Load previous results & remove datasets already processed if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { throw std::invalid_argument("Dataset " + config.continue_from + " not found"); } @@ -146,7 +148,8 @@ namespace platform { int num = 0; double bestScore = 0.0; json bestHyperparameters; - for (const auto& hyperparam_line : grid.getGrid()) { + auto combinations = grid.getGrid(); + for (const auto& hyperparam_line : combinations) { if (!config.quiet) showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); @@ -163,6 +166,7 @@ namespace platform { results[dataset]["score"] = bestScore; results[dataset]["hyperparameters"] = bestHyperparameters; results[dataset]["date"] = get_date() + " " + get_time(); + results[dataset]["grid"] = grid.getInputGrid(); // Save partial results save(results); }