Refactor gridsearch method
This commit is contained in:
parent
dee9c674da
commit
c460ef46ed
@ -109,32 +109,17 @@ namespace platform {
|
||||
}
|
||||
return numItems == 0 ? 0.0 : totalScore / numItems;
|
||||
}
|
||||
void GridSearch::go()
|
||||
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
||||
{
|
||||
// Load datasets
|
||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||
// Load previous results
|
||||
json results;
|
||||
|
||||
auto datasets_names = datasets.getNames();
|
||||
if (config.continue_from != "No") {
|
||||
// Continue previous execution:
|
||||
// Load previous results & remove datasets already processed
|
||||
// remove datasets already processed
|
||||
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
||||
}
|
||||
if (!config.quiet)
|
||||
std::cout << "* Loading previous results" << std::endl;
|
||||
try {
|
||||
std::ifstream file(Paths::grid_output(config.model));
|
||||
if (file.is_open()) {
|
||||
results = json::parse(file);
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << "* There were no previous results" << std::endl;
|
||||
std::cerr << "* Initizalizing new results" << std::endl;
|
||||
results = json();
|
||||
}
|
||||
// Remove datasets already processed
|
||||
vector< string >::iterator it = datasets_names.begin();
|
||||
while (it != datasets_names.end()) {
|
||||
@ -148,7 +133,30 @@ namespace platform {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create model
|
||||
return datasets_names;
|
||||
}
|
||||
|
||||
void GridSearch::go()
|
||||
{
|
||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||
auto datasets_names = processDatasets(datasets);
|
||||
// Load previous results
|
||||
json results;
|
||||
if (config.continue_from != "No") {
|
||||
if (!config.quiet)
|
||||
std::cout << "* Loading previous results" << std::endl;
|
||||
try {
|
||||
std::ifstream file(Paths::grid_output(config.model));
|
||||
if (file.is_open()) {
|
||||
results = json::parse(file);
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << "* There were no previous results" << std::endl;
|
||||
std::cerr << "* Initizalizing new results" << std::endl;
|
||||
results = json();
|
||||
}
|
||||
}
|
||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||
auto grid = GridData(Paths::grid_input(config.model));
|
||||
|
@ -28,6 +28,7 @@ namespace platform {
|
||||
json getResults();
|
||||
private:
|
||||
void save(json& results) const;
|
||||
vector<std::string> processDatasets(Datasets& datasets);
|
||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||
struct ConfigGrid config;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user