Refactor input grid parameters to json file
This commit is contained in:
parent
8b7b59d42b
commit
2121ba9b98
@ -1,31 +1,15 @@
|
|||||||
#include "GridData.h"
|
#include "GridData.h"
|
||||||
#include <iostream>
|
#include <fstream>
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
GridData::GridData()
|
GridData::GridData(const std::string& fileName)
|
||||||
{
|
{
|
||||||
auto boostaode = R"(
|
std::ifstream resultData(fileName);
|
||||||
[
|
if (resultData.is_open()) {
|
||||||
{
|
grid = json::parse(resultData);
|
||||||
"convergence": [true, false],
|
} else {
|
||||||
"ascending": [true, false],
|
throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
|
||||||
"repeatSparent": [true, false],
|
}
|
||||||
"select_features": ["CFS", "FCBF"],
|
|
||||||
"tolerance": [0, 3, 5],
|
|
||||||
"threshold": [1e-7]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"convergence": [true, false],
|
|
||||||
"ascending": [true, false],
|
|
||||||
"repeatSparent": [true, false],
|
|
||||||
"select_features": ["IWSS"],
|
|
||||||
"tolerance": [0, 3, 5],
|
|
||||||
"threshold": [0.5]
|
|
||||||
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)"_json;
|
|
||||||
grid["BoostAODE"] = boostaode;
|
|
||||||
}
|
}
|
||||||
int GridData::computeNumCombinations(const json& line)
|
int GridData::computeNumCombinations(const json& line)
|
||||||
{
|
{
|
||||||
@ -35,10 +19,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return numCombinations;
|
return numCombinations;
|
||||||
}
|
}
|
||||||
int GridData::getNumCombinations(const std::string& model)
|
int GridData::getNumCombinations()
|
||||||
{
|
{
|
||||||
int numCombinations = 0;
|
int numCombinations = 0;
|
||||||
for (const auto& line : grid.at(model)) {
|
for (const auto& line : grid) {
|
||||||
numCombinations += computeNumCombinations(line);
|
numCombinations += computeNumCombinations(line);
|
||||||
}
|
}
|
||||||
return numCombinations;
|
return numCombinations;
|
||||||
@ -60,10 +44,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return currentCombination;
|
return currentCombination;
|
||||||
}
|
}
|
||||||
std::vector<json> GridData::getGrid(const std::string& model)
|
std::vector<json> GridData::getGrid()
|
||||||
{
|
{
|
||||||
auto result = std::vector<json>();
|
auto result = std::vector<json>();
|
||||||
for (json line : grid.at(model)) {
|
for (json line : grid) {
|
||||||
generateCombinations(line.begin(), line.end(), result, json({}));
|
generateCombinations(line.begin(), line.end(), result, json({}));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -9,14 +9,14 @@ namespace platform {
|
|||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
class GridData {
|
class GridData {
|
||||||
public:
|
public:
|
||||||
GridData();
|
explicit GridData(const std::string& fileName);
|
||||||
~GridData() = default;
|
~GridData() = default;
|
||||||
std::vector<json> getGrid(const std::string& model);
|
std::vector<json> getGrid();
|
||||||
int getNumCombinations(const std::string& model);
|
int getNumCombinations();
|
||||||
private:
|
private:
|
||||||
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
||||||
int computeNumCombinations(const json& line);
|
int computeNumCombinations(const json& line);
|
||||||
std::map<std::string, json> grid;
|
json grid;
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
#endif /* GRIDDATA_H */
|
#endif /* GRIDDATA_H */
|
@ -10,6 +10,7 @@ namespace platform {
|
|||||||
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||||
{
|
{
|
||||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||||
|
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||||
}
|
}
|
||||||
void showProgressComb(const int num, const int total, const std::string& color)
|
void showProgressComb(const int num, const int total, const std::string& color)
|
||||||
{
|
{
|
||||||
@ -83,7 +84,9 @@ namespace platform {
|
|||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
// Create model
|
// Create model
|
||||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||||
auto totalComb = grid.getNumCombinations(config.model);
|
std::cout << "input file=" << config.input_file << std::endl;
|
||||||
|
auto grid = GridData(config.input_file);
|
||||||
|
auto totalComb = grid.getNumCombinations();
|
||||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||||
// Generate hyperparameters grid & run gridsearch
|
// Generate hyperparameters grid & run gridsearch
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
@ -92,7 +95,7 @@ namespace platform {
|
|||||||
int num = 0;
|
int num = 0;
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
json bestHyperparameters;
|
json bestHyperparameters;
|
||||||
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
|
for (const auto& hyperparam_line : grid.getGrid()) {
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
double score = processFile(dataset, datasets, hyperparameters);
|
double score = processFile(dataset, datasets, hyperparameters);
|
||||||
|
@ -30,7 +30,6 @@ namespace platform {
|
|||||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
json results;
|
json results;
|
||||||
struct ConfigGrid config;
|
struct ConfigGrid config;
|
||||||
GridData grid;
|
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
#endif /* GRIDSEARCH_H */
|
#endif /* GRIDSEARCH_H */
|
Loading…
Reference in New Issue
Block a user