Complete Folding Test

This commit is contained in:
2023-10-07 01:23:36 +02:00
parent 1287160c47
commit 8c3864f3c8
5 changed files with 26 additions and 24 deletions

View File

@@ -29,10 +29,12 @@ namespace platform {
vector<int> y;
vector<vector<int>> stratified_indices;
void build();
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
StratifiedKFold(int k, const vector<int>& y, int seed = -1);
StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold) override;
bool isFaulty() { return faulty; }
};
}
#endif