Refactor input grid parameters to json file

This commit is contained in:
2023-11-24 09:57:29 +01:00
parent 8b7b59d42b
commit 2121ba9b98
4 changed files with 21 additions and 35 deletions

View File

@@ -10,6 +10,7 @@ namespace platform {
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
}
void showProgressComb(const int num, const int total, const std::string& color)
{
@@ -83,7 +84,9 @@ namespace platform {
auto datasets = Datasets(config.discretize, Paths::datasets());
// Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
auto totalComb = grid.getNumCombinations(config.model);
std::cout << "input file=" << config.input_file << std::endl;
auto grid = GridData(config.input_file);
auto totalComb = grid.getNumCombinations();
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
@@ -92,7 +95,7 @@ namespace platform {
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
for (const auto& hyperparam_line : grid.getGrid()) {
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFile(dataset, datasets, hyperparameters);