From 2121ba9b986749381057a7cfb76e52f7db0d6b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 24 Nov 2023 09:57:29 +0100 Subject: [PATCH] Refactor input grid parameters to json file --- src/Platform/GridData.cc | 40 ++++++++++++-------------------------- src/Platform/GridData.h | 8 ++++---- src/Platform/GridSearch.cc | 7 +++++-- src/Platform/GridSearch.h | 1 - 4 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/Platform/GridData.cc b/src/Platform/GridData.cc index 5935d73..0150ff3 100644 --- a/src/Platform/GridData.cc +++ b/src/Platform/GridData.cc @@ -1,31 +1,15 @@ #include "GridData.h" -#include +#include namespace platform { - GridData::GridData() + GridData::GridData(const std::string& fileName) { - auto boostaode = R"( - [ - { - "convergence": [true, false], - "ascending": [true, false], - "repeatSparent": [true, false], - "select_features": ["CFS", "FCBF"], - "tolerance": [0, 3, 5], - "threshold": [1e-7] - }, - { - "convergence": [true, false], - "ascending": [true, false], - "repeatSparent": [true, false], - "select_features": ["IWSS"], - "tolerance": [0, 3, 5], - "threshold": [0.5] - - } - ] - )"_json; - grid["BoostAODE"] = boostaode; + std::ifstream resultData(fileName); + if (resultData.is_open()) { + grid = json::parse(resultData); + } else { + throw std::invalid_argument("Unable to open input file. [" + fileName + "]"); + } } int GridData::computeNumCombinations(const json& line) { @@ -35,10 +19,10 @@ namespace platform { } return numCombinations; } - int GridData::getNumCombinations(const std::string& model) + int GridData::getNumCombinations() { int numCombinations = 0; - for (const auto& line : grid.at(model)) { + for (const auto& line : grid) { numCombinations += computeNumCombinations(line); } return numCombinations; @@ -60,10 +44,10 @@ namespace platform { } return currentCombination; } - std::vector GridData::getGrid(const std::string& model) + std::vector GridData::getGrid() { auto result = std::vector(); - for (json line : grid.at(model)) { + for (json line : grid) { generateCombinations(line.begin(), line.end(), result, json({})); } return result; diff --git a/src/Platform/GridData.h b/src/Platform/GridData.h index 87ab74c..b68a54a 100644 --- a/src/Platform/GridData.h +++ b/src/Platform/GridData.h @@ -9,14 +9,14 @@ namespace platform { using json = nlohmann::json; class GridData { public: - GridData(); + explicit GridData(const std::string& fileName); ~GridData() = default; - std::vector getGrid(const std::string& model); - int getNumCombinations(const std::string& model); + std::vector getGrid(); + int getNumCombinations(); private: json generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination); int computeNumCombinations(const json& line); - std::map grid; + json grid; }; } /* namespace platform */ #endif /* GRIDDATA_H */ \ No newline at end of file diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 239cf50..d0b84ed 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -10,6 +10,7 @@ namespace platform { GridSearch::GridSearch(struct ConfigGrid& config) : config(config) { this->config.output_file = config.path + "grid_" + config.model + "_output.json"; + this->config.input_file = config.path + "grid_" + config.model + "_input.json"; } void showProgressComb(const int num, const int total, const std::string& color) { @@ -83,7 +84,9 @@ namespace platform { auto datasets = Datasets(config.discretize, Paths::datasets()); // Create model std::cout << "***************** Starting Gridsearch *****************" << std::endl; - auto totalComb = grid.getNumCombinations(config.model); + std::cout << "input file=" << config.input_file << std::endl; + auto grid = GridData(config.input_file); + auto totalComb = grid.getNumCombinations(); std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl; // Generate hyperparameters grid & run gridsearch // Check each combination of hyperparameters for each dataset and each seed @@ -92,7 +95,7 @@ namespace platform { int num = 0; double bestScore = 0.0; json bestHyperparameters; - for (const auto& hyperparam_line : grid.getGrid(config.model)) { + for (const auto& hyperparam_line : grid.getGrid()) { showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); double score = processFile(dataset, datasets, hyperparameters); diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 6bf9f1a..81f06b5 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -30,7 +30,6 @@ namespace platform { double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); json results; struct ConfigGrid config; - GridData grid; }; } /* namespace platform */ #endif /* GRIDSEARCH_H */ \ No newline at end of file