Add grid input info to grid output
This commit is contained in:
parent
7c12dd25e5
commit
4fefe9a1d2
@ -52,4 +52,8 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
json& GridData::getInputGrid()
|
||||||
|
{
|
||||||
|
return grid;
|
||||||
|
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -13,6 +13,7 @@ namespace platform {
|
|||||||
~GridData() = default;
|
~GridData() = default;
|
||||||
std::vector<json> getGrid();
|
std::vector<json> getGrid();
|
||||||
int getNumCombinations();
|
int getNumCombinations();
|
||||||
|
json& getInputGrid();
|
||||||
private:
|
private:
|
||||||
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
||||||
int computeNumCombinations(const json& line);
|
int computeNumCombinations(const json& line);
|
||||||
|
@ -109,6 +109,8 @@ namespace platform {
|
|||||||
json results;
|
json results;
|
||||||
auto datasets_names = datasets.getNames();
|
auto datasets_names = datasets.getNames();
|
||||||
if (config.continue_from != "No") {
|
if (config.continue_from != "No") {
|
||||||
|
// Continue previous execution:
|
||||||
|
// Load previous results & remove datasets already processed
|
||||||
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||||
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
||||||
}
|
}
|
||||||
@ -146,7 +148,8 @@ namespace platform {
|
|||||||
int num = 0;
|
int num = 0;
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
json bestHyperparameters;
|
json bestHyperparameters;
|
||||||
for (const auto& hyperparam_line : grid.getGrid()) {
|
auto combinations = grid.getGrid();
|
||||||
|
for (const auto& hyperparam_line : combinations) {
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
@ -163,6 +166,7 @@ namespace platform {
|
|||||||
results[dataset]["score"] = bestScore;
|
results[dataset]["score"] = bestScore;
|
||||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
results[dataset]["date"] = get_date() + " " + get_time();
|
results[dataset]["date"] = get_date() + " " + get_time();
|
||||||
|
results[dataset]["grid"] = grid.getInputGrid();
|
||||||
// Save partial results
|
// Save partial results
|
||||||
save(results);
|
save(results);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user