Refactor grid input hyperparameter file

This commit is contained in:
2023-11-29 18:24:34 +01:00
parent e3f6dc1e0b
commit dee9c674da
6 changed files with 120 additions and 87 deletions

View File

@@ -4,12 +4,19 @@
namespace platform {
GridData::GridData(const std::string& fileName)
{
json grid_file;
std::ifstream resultData(fileName);
if (resultData.is_open()) {
grid = json::parse(resultData);
grid_file = json::parse(resultData);
} else {
throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
}
for (const auto& item : grid_file.items()) {
auto key = item.key();
auto value = item.value();
grid[key] = value;
}
}
int GridData::computeNumCombinations(const json& line)
{
@@ -19,10 +26,11 @@ namespace platform {
}
return numCombinations;
}
int GridData::getNumCombinations()
int GridData::getNumCombinations(const std::string& dataset)
{
int numCombinations = 0;
for (const auto& line : grid) {
auto selected = decide_dataset(dataset);
for (const auto& line : grid.at(selected)) {
numCombinations += computeNumCombinations(line);
}
return numCombinations;
@@ -44,16 +52,24 @@ namespace platform {
}
return currentCombination;
}
std::vector<json> GridData::getGrid()
std::vector<json> GridData::getGrid(const std::string& dataset)
{
auto selected = decide_dataset(dataset);
auto result = std::vector<json>();
for (json line : grid) {
for (json line : grid.at(selected)) {
generateCombinations(line.begin(), line.end(), result, json({}));
}
return result;
}
json& GridData::getInputGrid()
json& GridData::getInputGrid(const std::string& dataset)
{
return grid;
auto selected = decide_dataset(dataset);
return grid.at(selected);
}
std::string GridData::decide_dataset(const std::string& dataset)
{
if (grid.find(dataset) != grid.end())
return dataset;
return ALL_DATASETS;
}
} /* namespace platform */