From c460ef46ede70ef6361a22206f186e246eef47eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Thu, 30 Nov 2023 11:01:37 +0100 Subject: [PATCH] Refactor gridsearch method --- src/Platform/GridSearch.cc | 46 ++++++++++++++++++++++---------------- src/Platform/GridSearch.h | 1 + 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 809bd8f..e6b25af 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -109,32 +109,17 @@ namespace platform { } return numItems == 0 ? 0.0 : totalScore / numItems; } - void GridSearch::go() + vector GridSearch::processDatasets(Datasets& datasets) { // Load datasets - auto datasets = Datasets(config.discretize, Paths::datasets()); - // Load previous results - json results; + auto datasets_names = datasets.getNames(); if (config.continue_from != "No") { // Continue previous execution: - // Load previous results & remove datasets already processed + // remove datasets already processed if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { throw std::invalid_argument("Dataset " + config.continue_from + " not found"); } - if (!config.quiet) - std::cout << "* Loading previous results" << std::endl; - try { - std::ifstream file(Paths::grid_output(config.model)); - if (file.is_open()) { - results = json::parse(file); - } - } - catch (const std::exception& e) { - std::cerr << "* There were no previous results" << std::endl; - std::cerr << "* Initizalizing new results" << std::endl; - results = json(); - } // Remove datasets already processed vector< string >::iterator it = datasets_names.begin(); while (it != datasets_names.end()) { @@ -148,7 +133,30 @@ namespace platform { } } } - // Create model + return datasets_names; + } + + void GridSearch::go() + { + auto datasets = Datasets(config.discretize, Paths::datasets()); + auto datasets_names = processDatasets(datasets); + // Load previous results + json results; + if (config.continue_from != "No") { + if (!config.quiet) + std::cout << "* Loading previous results" << std::endl; + try { + std::ifstream file(Paths::grid_output(config.model)); + if (file.is_open()) { + results = json::parse(file); + } + } + catch (const std::exception& e) { + std::cerr << "* There were no previous results" << std::endl; + std::cerr << "* Initizalizing new results" << std::endl; + results = json(); + } + } std::cout << "***************** Starting Gridsearch *****************" << std::endl; std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; auto grid = GridData(Paths::grid_input(config.model)); diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index d46ffb3..e8bfb85 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -28,6 +28,7 @@ namespace platform { json getResults(); private: void save(json& results) const; + vector processDatasets(Datasets& datasets); double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); struct ConfigGrid config; };