#ifndef FOLDING_H #define FOLDING_H #include #include #include using namespace std; namespace platform { class Fold { protected: int k; int n; int seed; default_random_engine random_seed; public: Fold(int k, int n, int seed = -1); virtual pair, vector> getFold(int nFold) = 0; virtual ~Fold() = default; int getNumberOfFolds() { return k; } }; class KFold : public Fold { private: vector indices; public: KFold(int k, int n, int seed = -1); pair, vector> getFold(int nFold) override; }; class StratifiedKFold : public Fold { private: vector y; vector> stratified_indices; void build(); bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. public: StratifiedKFold(int k, const vector& y, int seed = -1); StratifiedKFold(int k, torch::Tensor& y, int seed = -1); pair, vector> getFold(int nFold) override; bool isFaulty() { return faulty; } }; } #endif