Complete Stratified K Fold

This commit is contained in:
2023-07-22 11:23:35 +02:00
parent f6e154bc6e
commit 41cceece20
5 changed files with 88 additions and 86 deletions

View File

@@ -34,44 +34,44 @@ pair<vector<int>, vector<int>> KFold::getFold(int nFold)
StratifiedKFold::StratifiedKFold(int k, const vector<int>& y, int seed) :
k(k), seed(seed)
{
// n = y.size();
// map<int, vector<int>> class_to_indices;
// for (int i = 0; i < n; ++i) {
// class_to_indices[y[i]].push_back(i);
// }
// random_device rd;
// default_random_engine random_seed(seed == -1 ? rd() : seed);
// for (auto& [cls, indices] : class_to_indices) {
// shuffle(indices.begin(), indices.end(), random_seed);
// int fold_size = n / k;
// for (int i = 0; i < k; ++i) {
// int start = i * fold_size;
// int end = (i == k - 1) ? indices.size() : (i + 1) * fold_size;
// stratified_indices.emplace_back(indices.begin() + start, indices.begin() + end);
// }
// }
n = y.size();
stratified_indices.resize(k);
stratified_indices = vector<vector<int>>(k);
int fold_size = n / k;
int remainder = n % k;
// Compute class counts and indices
auto class_indices = map<int, vector<int>>();
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
for (auto i = 0; i < n; ++i) {
class_counts[y[i]]++;
class_indices[y[i]].push_back(i);
}
vector<int> class_starts(class_counts.size());
partial_sum(class_counts.begin(), class_counts.end() - 1, class_starts.begin() + 1);
vector<int> indices(n);
for (auto i = 0; i < n; ++i) {
int label = y[i];
stratified_indices[class_starts[label]] = i;
class_starts[label]++;
// Shuffle class indices
random_device rd;
default_random_engine random_seed(seed == -1 ? rd() : seed);
for (auto& [cls, indices] : class_indices) {
shuffle(indices.begin(), indices.end(), random_seed);
}
int fold_size = n / k;
int remainder = n % k;
int start = 0;
for (auto i = 0; i < k; ++i) {
int fold_length = fold_size + (i < remainder ? 1 : 0);
stratified_indices[i].resize(fold_length);
copy(indices.begin() + start, indices.begin() + start + fold_length, stratified_indices[i].begin());
start += fold_length;
// Assign indices to folds
for (auto label = 0; label < class_counts.size(); ++label) {
auto num_samples_to_take = class_counts[label] / k;
if (num_samples_to_take == 0)
continue;
auto remainder_samples_to_take = class_counts[label] % k;
for (auto fold = 0; fold < k; ++fold) {
auto it = next(class_indices[label].begin(), num_samples_to_take);
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
class_indices[label].erase(class_indices[label].begin(), it);
}
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
if (stratified_indices[fold].size() == fold_size) {
continue;
}
auto it = next(class_indices[label].begin(), 1);
stratified_indices[fold].push_back(*class_indices[label].begin());
class_indices[label].erase(class_indices[label].begin(), it);
remainder_samples_to_take--;
}
}
}
pair<vector<int>, vector<int>> StratifiedKFold::getFold(int nFold)