Refactor stratified build optimizing loops
This commit is contained in:
24
folding.hpp
24
folding.hpp
@@ -103,22 +103,20 @@ namespace folding {
|
||||
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
||||
faulty = true;
|
||||
}
|
||||
int start = 0;
|
||||
for (auto fold = 0; fold < k; ++fold) {
|
||||
auto it = next(class_indices[label].begin(), num_samples_to_take);
|
||||
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold]));
|
||||
class_indices[label].erase(class_indices[label].begin(), it);
|
||||
auto it = next(class_indices[label].begin() + start, num_samples_to_take);
|
||||
move(indices.begin() + start, it, back_inserter(stratified_indices[fold]));
|
||||
start += num_samples_to_take;
|
||||
}
|
||||
auto chosen = std::vector<bool>(k, false);
|
||||
while (remainder_samples_to_take > 0) {
|
||||
int fold = (rand() % static_cast<int>(k));
|
||||
if (chosen.at(fold)) {
|
||||
continue;
|
||||
if (remainder_samples_to_take > 0) {
|
||||
auto chosen = std::vector<int>(k);
|
||||
std::iota(chosen.begin(), chosen.end(), 0);
|
||||
std::shuffle(chosen.begin(), chosen.end(), random_seed);
|
||||
chosen.resize(remainder_samples_to_take);
|
||||
for (auto fold : chosen) {
|
||||
stratified_indices[fold].push_back(indices.at(start++));
|
||||
}
|
||||
chosen[fold] = true;
|
||||
auto it = next(class_indices[label].begin(), 1);
|
||||
stratified_indices[fold].push_back(*class_indices[label].begin());
|
||||
class_indices[label].erase(class_indices[label].begin(), it);
|
||||
remainder_samples_to_take--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user