diff --git a/folding.hpp b/folding.hpp index 0cf0fc7..0e49b25 100644 --- a/folding.hpp +++ b/folding.hpp @@ -12,7 +12,12 @@ namespace folding { int seed; std::default_random_engine random_seed; public: - Fold(int k, int n, int seed = -1); + Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) + { + std::random_device rd; + random_seed = std::default_random_engine(seed == -1 ? rd() : seed); + std::srand(seed == -1 ? time(0) : seed); + } virtual std::pair, std::vector> getFold(int nFold) = 0; virtual ~Fold() = default; int getNumberOfFolds() { return k; } @@ -21,118 +26,106 @@ namespace folding { private: std::vector indices; public: - KFold(int k, int n, int seed = -1); - std::pair, std::vector> getFold(int nFold) override; + KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector(n)) + { + std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 + shuffle(indices.begin(), indices.end(), random_seed); + } + std::pair, std::vector> getFold(int nFold) override + { + if (nFold >= k || nFold < 0) { + throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); + } + int nTest = n / k; + auto train = std::vector(); + auto test = std::vector(); + for (int i = 0; i < n; i++) { + if (i >= nTest * nFold && i < nTest * (nFold + 1)) { + test.push_back(indices[i]); + } else { + train.push_back(indices[i]); + } + } + return { train, test }; + } }; class StratifiedKFold : public Fold { private: std::vector y; std::vector> stratified_indices; - void build(); - bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. - public: - StratifiedKFold(int k, const std::vector& y, int seed = -1); - StratifiedKFold(int k, torch::Tensor& y, int seed = -1); - std::pair, std::vector> getFold(int nFold) override; - bool isFaulty() { return faulty; } - }; - Fold::Fold(int k, int n, int seed) : k(k), n(n), seed(seed) - { - std::random_device rd; - random_seed = std::default_random_engine(seed == -1 ? rd() : seed); - std::srand(seed == -1 ? time(0) : seed); - } - KFold::KFold(int k, int n, int seed) : Fold(k, n, seed), indices(std::vector(n)) - { - std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 - shuffle(indices.begin(), indices.end(), random_seed); - } - std::pair, std::vector> KFold::getFold(int nFold) - { - if (nFold >= k || nFold < 0) { - throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); - } - int nTest = n / k; - auto train = std::vector(); - auto test = std::vector(); - for (int i = 0; i < n; i++) { - if (i >= nTest * nFold && i < nTest * (nFold + 1)) { - test.push_back(indices[i]); - } else { - train.push_back(indices[i]); - } - } - return { train, test }; - } - StratifiedKFold::StratifiedKFold(int k, torch::Tensor& y, int seed) : Fold(k, y.numel(), seed) - { - n = y.numel(); - this->y = std::vector(y.data_ptr(), y.data_ptr() + n); - build(); - } - StratifiedKFold::StratifiedKFold(int k, const std::vector& y, int seed) - : Fold(k, y.size(), seed) - { - this->y = y; - n = y.size(); - build(); - } - void StratifiedKFold::build() - { - stratified_indices = std::vector>(k); - int fold_size = n / k; + void build() + { + stratified_indices = std::vector>(k); + int fold_size = n / k; - // Compute class counts and indices - auto class_indices = std::map>(); - std::vector class_counts(*max_element(y.begin(), y.end()) + 1, 0); - for (auto i = 0; i < n; ++i) { - class_counts[y[i]]++; - class_indices[y[i]].push_back(i); - } - // Shuffle class indices - for (auto& [cls, indices] : class_indices) { - shuffle(indices.begin(), indices.end(), random_seed); - } - // Assign indices to folds - for (auto label = 0; label < class_counts.size(); ++label) { - auto num_samples_to_take = class_counts.at(label) / k; - if (num_samples_to_take == 0) { - std::cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label) - << ") is less than the number of folds (" << k << ")." << std::endl; - faulty = true; - continue; + // Compute class counts and indices + auto class_indices = std::map>(); + std::vector class_counts(*max_element(y.begin(), y.end()) + 1, 0); + for (auto i = 0; i < n; ++i) { + class_counts[y[i]]++; + class_indices[y[i]].push_back(i); } - auto remainder_samples_to_take = class_counts[label] % k; - for (auto fold = 0; fold < k; ++fold) { - auto it = next(class_indices[label].begin(), num_samples_to_take); - move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ## - class_indices[label].erase(class_indices[label].begin(), it); + // Shuffle class indices + for (auto& [cls, indices] : class_indices) { + shuffle(indices.begin(), indices.end(), random_seed); } - auto chosen = std::vector(k, false); - while (remainder_samples_to_take > 0) { - int fold = (rand() % static_cast(k)); - if (chosen.at(fold)) { + // Assign indices to folds + for (auto label = 0; label < class_counts.size(); ++label) { + auto num_samples_to_take = class_counts.at(label) / k; + if (num_samples_to_take == 0) { + std::cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label) + << ") is less than the number of folds (" << k << ")." << std::endl; + faulty = true; continue; } - chosen[fold] = true; - auto it = next(class_indices[label].begin(), 1); - stratified_indices[fold].push_back(*class_indices[label].begin()); - class_indices[label].erase(class_indices[label].begin(), it); - remainder_samples_to_take--; + auto remainder_samples_to_take = class_counts[label] % k; + for (auto fold = 0; fold < k; ++fold) { + auto it = next(class_indices[label].begin(), num_samples_to_take); + move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ## + class_indices[label].erase(class_indices[label].begin(), it); + } + auto chosen = std::vector(k, false); + while (remainder_samples_to_take > 0) { + int fold = (rand() % static_cast(k)); + if (chosen.at(fold)) { + continue; + } + chosen[fold] = true; + auto it = next(class_indices[label].begin(), 1); + stratified_indices[fold].push_back(*class_indices[label].begin()); + class_indices[label].erase(class_indices[label].begin(), it); + remainder_samples_to_take--; + } } } - } - std::pair, std::vector> StratifiedKFold::getFold(int nFold) - { - if (nFold >= k || nFold < 0) { - throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); + bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. + public: + StratifiedKFold(int k, const std::vector& y, int seed = -1) : Fold(k, y.size(), seed) + { + this->y = y; + n = y.size(); + build(); } - std::vector test_indices = stratified_indices[nFold]; - std::vector train_indices; - for (int i = 0; i < k; ++i) { - if (i == nFold) continue; - train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end()); + StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) + { + n = y.numel(); + this->y = std::vector(y.data_ptr(), y.data_ptr() + n); + build(); } - return { train_indices, test_indices }; - } + + std::pair, std::vector> getFold(int nFold) override + { + if (nFold >= k || nFold < 0) { + throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); + } + std::vector test_indices = stratified_indices[nFold]; + std::vector train_indices; + for (int i = 0; i < k; ++i) { + if (i == nFold) continue; + train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end()); + } + return { train_indices, test_indices }; + } + bool isFaulty() { return faulty; } + }; } \ No newline at end of file