Begin gridsearch implementation

This commit is contained in:
2023-11-22 12:22:30 +01:00
parent b657762c0c
commit fb347ed5b9
7 changed files with 110 additions and 106 deletions

View File

@@ -30,37 +30,41 @@ namespace platform {
int GridData::computeNumCombinations(const json& line)
{
int numCombinations = 1;
for (const auto& item : line) {
for (const auto& hyperparam : item.items()) {
numCombinations *= item.size();
}
for (const auto& item : line.items()) {
numCombinations *= item.value().size();
}
return numCombinations;
}
std::vector<json> GridData::doCombination(const std::string& model)
int GridData::getNumCombinations(const std::string& model)
{
int numTotal = 0;
for (const auto& item : grid[model]) {
numTotal += computeNumCombinations(item);
int numCombinations = 0;
for (const auto& line : grid.at(model)) {
numCombinations += computeNumCombinations(line);
}
auto result = std::vector<json>(numTotal);
int base = 0;
for (const auto& item : grid[model]) {
int numCombinations = computeNumCombinations(item);
int line = 0;
for (const auto& hyperparam : item.items()) {
int numValues = hyperparam.value().size();
for (const auto& value : hyperparam.value()) {
for (int i = 0; i < numCombinations / numValues; i++) {
result[base + line++][hyperparam.key()] = value;
//std::cout << "line=" << base + line << " " << hyperparam.key() << "=" << value << std::endl;
}
}
}
base += numCombinations;
return numCombinations;
}
json GridData::generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination)
{
if (index == last) {
// If we reached the end of input, store the current combination
output.push_back(currentCombination);
return currentCombination;
}
for (const auto& item : result) {
std::cout << item.dump() << std::endl;
const auto& key = index.key();
const auto& values = index.value();
for (const auto& value : values) {
auto combination = currentCombination;
combination[key] = value;
json::iterator nextIndex = index;
generateCombinations(++nextIndex, last, output, combination);
}
return currentCombination;
}
std::vector<json> GridData::getGrid(const std::string& model)
{
auto result = std::vector<json>();
for (json line : grid.at(model)) {
generateCombinations(line.begin(), line.end(), result, json({}));
}
return result;
}