Refactor input grid parameters to json file
This commit is contained in:
@@ -10,6 +10,7 @@ namespace platform {
|
||||
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||
{
|
||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||
}
|
||||
void showProgressComb(const int num, const int total, const std::string& color)
|
||||
{
|
||||
@@ -83,7 +84,9 @@ namespace platform {
|
||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||
// Create model
|
||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||
auto totalComb = grid.getNumCombinations(config.model);
|
||||
std::cout << "input file=" << config.input_file << std::endl;
|
||||
auto grid = GridData(config.input_file);
|
||||
auto totalComb = grid.getNumCombinations();
|
||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||
// Generate hyperparameters grid & run gridsearch
|
||||
// Check each combination of hyperparameters for each dataset and each seed
|
||||
@@ -92,7 +95,7 @@ namespace platform {
|
||||
int num = 0;
|
||||
double bestScore = 0.0;
|
||||
json bestHyperparameters;
|
||||
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
|
||||
for (const auto& hyperparam_line : grid.getGrid()) {
|
||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||
double score = processFile(dataset, datasets, hyperparameters);
|
||||
|
Reference in New Issue
Block a user