Refactor gridsearch method

This commit is contained in:
Ricardo Montañana Gómez 2023-11-30 11:01:37 +01:00
parent dee9c674da
commit c460ef46ed
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 28 additions and 19 deletions

View File

@ -109,32 +109,17 @@ namespace platform {
} }
return numItems == 0 ? 0.0 : totalScore / numItems; return numItems == 0 ? 0.0 : totalScore / numItems;
} }
void GridSearch::go() vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{ {
// Load datasets // Load datasets
auto datasets = Datasets(config.discretize, Paths::datasets());
// Load previous results
json results;
auto datasets_names = datasets.getNames(); auto datasets_names = datasets.getNames();
if (config.continue_from != "No") { if (config.continue_from != "No") {
// Continue previous execution: // Continue previous execution:
// Load previous results & remove datasets already processed // remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found"); throw std::invalid_argument("Dataset " + config.continue_from + " not found");
} }
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
}
catch (const std::exception& e) {
std::cerr << "* There were no previous results" << std::endl;
std::cerr << "* Initizalizing new results" << std::endl;
results = json();
}
// Remove datasets already processed // Remove datasets already processed
vector< string >::iterator it = datasets_names.begin(); vector< string >::iterator it = datasets_names.begin();
while (it != datasets_names.end()) { while (it != datasets_names.end()) {
@ -148,7 +133,30 @@ namespace platform {
} }
} }
} }
// Create model return datasets_names;
}
void GridSearch::go()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
// Load previous results
json results;
if (config.continue_from != "No") {
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
}
catch (const std::exception& e) {
std::cerr << "* There were no previous results" << std::endl;
std::cerr << "* Initizalizing new results" << std::endl;
results = json();
}
}
std::cout << "***************** Starting Gridsearch *****************" << std::endl; std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model)); auto grid = GridData(Paths::grid_input(config.model));

View File

@ -28,6 +28,7 @@ namespace platform {
json getResults(); json getResults();
private: private:
void save(json& results) const; void save(json& results) const;
vector<std::string> processDatasets(Datasets& datasets);
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
struct ConfigGrid config; struct ConfigGrid config;
}; };