Refactor stratified build optimizing loops

This commit is contained in:
2024-05-11 14:10:00 +02:00
parent 3fac7b95f8
commit 38bc00bb05
2 changed files with 26 additions and 18 deletions

View File

@@ -103,22 +103,20 @@ namespace folding {
<< ") is less than the number of folds (" << k << ")." << std::endl;
faulty = true;
}
int start = 0;
for (auto fold = 0; fold < k; ++fold) {
auto it = next(class_indices[label].begin(), num_samples_to_take);
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold]));
class_indices[label].erase(class_indices[label].begin(), it);
auto it = next(class_indices[label].begin() + start, num_samples_to_take);
move(indices.begin() + start, it, back_inserter(stratified_indices[fold]));
start += num_samples_to_take;
}
auto chosen = std::vector<bool>(k, false);
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
if (chosen.at(fold)) {
continue;
if (remainder_samples_to_take > 0) {
auto chosen = std::vector<int>(k);
std::iota(chosen.begin(), chosen.end(), 0);
std::shuffle(chosen.begin(), chosen.end(), random_seed);
chosen.resize(remainder_samples_to_take);
for (auto fold : chosen) {
stratified_indices[fold].push_back(indices.at(start++));
}
chosen[fold] = true;
auto it = next(class_indices[label].begin(), 1);
stratified_indices[fold].push_back(*class_indices[label].begin());
class_indices[label].erase(class_indices[label].begin(), it);
remainder_samples_to_take--;
}
}
}