Complete Folding Test

This commit is contained in:
2023-10-07 01:23:36 +02:00
parent 1287160c47
commit 8c3864f3c8
5 changed files with 26 additions and 24 deletions

View File

@@ -47,7 +47,7 @@ namespace platform {
{
stratified_indices = vector<vector<int>>(k);
int fold_size = n / k;
cout << "Fold SIZE: " << fold_size << endl;
// Compute class counts and indices
auto class_indices = map<int, vector<int>>();
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
@@ -61,11 +61,14 @@ namespace platform {
}
// Assign indices to folds
for (auto label = 0; label < class_counts.size(); ++label) {
auto num_samples_to_take = class_counts[label] / k;
if (num_samples_to_take == 0)
auto num_samples_to_take = class_counts.at(label) / k;
if (num_samples_to_take == 0) {
cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label)
<< ") is less than the number of folds (" << k << ")." << endl;
faulty = true;
continue;
}
auto remainder_samples_to_take = class_counts[label] % k;
cout << "Remainder samples to take: " << remainder_samples_to_take << endl;
for (auto fold = 0; fold < k; ++fold) {
auto it = next(class_indices[label].begin(), num_samples_to_take);
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
@@ -74,12 +77,10 @@ namespace platform {
auto chosen = vector<bool>(k, false);
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
cout << "-candidate: " << fold << endl;
if (chosen.at(fold)) {
continue;
}
chosen[fold] = true;
cout << "One goes to fold " << fold << " that had " << stratified_indices[fold].size() << " elements before" << endl;
auto it = next(class_indices[label].begin(), 1);
stratified_indices[fold].push_back(*class_indices[label].begin());
class_indices[label].erase(class_indices[label].begin(), it);

View File

@@ -29,10 +29,12 @@ namespace platform {
vector<int> y;
vector<vector<int>> stratified_indices;
void build();
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
StratifiedKFold(int k, const vector<int>& y, int seed = -1);
StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold) override;
// Returns the `faulty` flag. The flag starts out false and is raised when some
// class has fewer samples than the number of folds; such a class is reported
// on cerr and skipped when indices are assigned to folds.
bool isFaulty() { return faulty; }
};
}
#endif