Refactor gridsearch method

This commit is contained in:
Ricardo Montañana Gómez 2023-11-30 11:01:37 +01:00
parent dee9c674da
commit c460ef46ed
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 28 additions and 19 deletions

View File

@ -109,32 +109,17 @@ namespace platform {
}
return numItems == 0 ? 0.0 : totalScore / numItems;
}
void GridSearch::go()
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
// Load datasets
auto datasets = Datasets(config.discretize, Paths::datasets());
// Load previous results
json results;
auto datasets_names = datasets.getNames();
if (config.continue_from != "No") {
// Continue previous execution:
// Load previous results & remove datasets already processed
// remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
}
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
}
catch (const std::exception& e) {
std::cerr << "* There were no previous results" << std::endl;
std::cerr << "* Initizalizing new results" << std::endl;
results = json();
}
// Remove datasets already processed
vector< string >::iterator it = datasets_names.begin();
while (it != datasets_names.end()) {
@ -148,7 +133,30 @@ namespace platform {
}
}
}
// Create model
return datasets_names;
}
void GridSearch::go()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
// Load previous results
json results;
if (config.continue_from != "No") {
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) {
results = json::parse(file);
}
}
catch (const std::exception& e) {
std::cerr << "* There were no previous results" << std::endl;
std::cerr << "* Initizalizing new results" << std::endl;
results = json();
}
}
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));

View File

@ -28,6 +28,7 @@ namespace platform {
json getResults();
private:
void save(json& results) const;
vector<std::string> processDatasets(Datasets& datasets);
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
struct ConfigGrid config;
};