Refactor input grid parameters to json file

This commit is contained in:
Ricardo Montañana Gómez 2023-11-24 09:57:29 +01:00
parent 8b7b59d42b
commit 2121ba9b98
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 21 additions and 35 deletions

View File

@ -1,31 +1,15 @@
#include "GridData.h" #include "GridData.h"
#include <iostream> #include <fstream>
namespace platform { namespace platform {
GridData::GridData() GridData::GridData(const std::string& fileName)
{ {
auto boostaode = R"( std::ifstream resultData(fileName);
[ if (resultData.is_open()) {
{ grid = json::parse(resultData);
"convergence": [true, false], } else {
"ascending": [true, false], throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
"repeatSparent": [true, false], }
"select_features": ["CFS", "FCBF"],
"tolerance": [0, 3, 5],
"threshold": [1e-7]
},
{
"convergence": [true, false],
"ascending": [true, false],
"repeatSparent": [true, false],
"select_features": ["IWSS"],
"tolerance": [0, 3, 5],
"threshold": [0.5]
}
]
)"_json;
grid["BoostAODE"] = boostaode;
} }
int GridData::computeNumCombinations(const json& line) int GridData::computeNumCombinations(const json& line)
{ {
@ -35,10 +19,10 @@ namespace platform {
} }
return numCombinations; return numCombinations;
} }
int GridData::getNumCombinations(const std::string& model) int GridData::getNumCombinations()
{ {
int numCombinations = 0; int numCombinations = 0;
for (const auto& line : grid.at(model)) { for (const auto& line : grid) {
numCombinations += computeNumCombinations(line); numCombinations += computeNumCombinations(line);
} }
return numCombinations; return numCombinations;
@ -60,10 +44,10 @@ namespace platform {
} }
return currentCombination; return currentCombination;
} }
std::vector<json> GridData::getGrid(const std::string& model) std::vector<json> GridData::getGrid()
{ {
auto result = std::vector<json>(); auto result = std::vector<json>();
for (json line : grid.at(model)) { for (json line : grid) {
generateCombinations(line.begin(), line.end(), result, json({})); generateCombinations(line.begin(), line.end(), result, json({}));
} }
return result; return result;

View File

@ -9,14 +9,14 @@ namespace platform {
using json = nlohmann::json; using json = nlohmann::json;
class GridData { class GridData {
public: public:
GridData(); explicit GridData(const std::string& fileName);
~GridData() = default; ~GridData() = default;
std::vector<json> getGrid(const std::string& model); std::vector<json> getGrid();
int getNumCombinations(const std::string& model); int getNumCombinations();
private: private:
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination); json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
int computeNumCombinations(const json& line); int computeNumCombinations(const json& line);
std::map<std::string, json> grid; json grid;
}; };
} /* namespace platform */ } /* namespace platform */
#endif /* GRIDDATA_H */ #endif /* GRIDDATA_H */

View File

@ -10,6 +10,7 @@ namespace platform {
GridSearch::GridSearch(struct ConfigGrid& config) : config(config) GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{ {
this->config.output_file = config.path + "grid_" + config.model + "_output.json"; this->config.output_file = config.path + "grid_" + config.model + "_output.json";
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
} }
void showProgressComb(const int num, const int total, const std::string& color) void showProgressComb(const int num, const int total, const std::string& color)
{ {
@ -83,7 +84,9 @@ namespace platform {
auto datasets = Datasets(config.discretize, Paths::datasets()); auto datasets = Datasets(config.discretize, Paths::datasets());
// Create model // Create model
std::cout << "***************** Starting Gridsearch *****************" << std::endl; std::cout << "***************** Starting Gridsearch *****************" << std::endl;
auto totalComb = grid.getNumCombinations(config.model); std::cout << "input file=" << config.input_file << std::endl;
auto grid = GridData(config.input_file);
auto totalComb = grid.getNumCombinations();
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl; std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
// Generate hyperparameters grid & run gridsearch // Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed // Check each combination of hyperparameters for each dataset and each seed
@ -92,7 +95,7 @@ namespace platform {
int num = 0; int num = 0;
double bestScore = 0.0; double bestScore = 0.0;
json bestHyperparameters; json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid(config.model)) { for (const auto& hyperparam_line : grid.getGrid()) {
showProgressComb(++num, totalComb, Colors::CYAN()); showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFile(dataset, datasets, hyperparameters); double score = processFile(dataset, datasets, hyperparameters);

View File

@ -30,7 +30,6 @@ namespace platform {
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
json results; json results;
struct ConfigGrid config; struct ConfigGrid config;
GridData grid;
}; };
} /* namespace platform */ } /* namespace platform */
#endif /* GRIDSEARCH_H */ #endif /* GRIDSEARCH_H */