Begin Test Folding

This commit is contained in:
2023-10-06 17:08:54 +02:00
parent b9e0028e9d
commit 17e079edd5
10 changed files with 250 additions and 55 deletions

View File

@@ -47,6 +47,7 @@ namespace platform {
{
stratified_indices = vector<vector<int>>(k);
int fold_size = n / k;
cout << "Fold SIZE: " << fold_size << endl;
// Compute class counts and indices
auto class_indices = map<int, vector<int>>();
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
@@ -64,16 +65,20 @@ namespace platform {
if (num_samples_to_take == 0)
continue;
auto remainder_samples_to_take = class_counts[label] % k;
cout << "Remainder samples to take: " << remainder_samples_to_take << endl;
for (auto fold = 0; fold < k; ++fold) {
auto it = next(class_indices[label].begin(), num_samples_to_take);
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
class_indices[label].erase(class_indices[label].begin(), it);
}
auto chosen = vector<bool>(k, false);
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
if (stratified_indices[fold].size() == fold_size + 1) {
if (chosen.at(fold)) {
continue;
}
chosen[k] = true;
cout << "One goes to fold " << fold << " that had " << stratified_indices[fold].size() << " elements before" << endl;
auto it = next(class_indices[label].begin(), 1);
stratified_indices[fold].push_back(*class_indices[label].begin());
class_indices[label].erase(class_indices[label].begin(), it);