From c7bcc10dfb0179ae87b862d6b1c9d6e2a6518296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 11 May 2024 13:34:50 +0200 Subject: [PATCH] Refactor folding change order public and private methods --- folding.hpp | 72 ++++++++++++++++++++++---------------------- tests/TestFolding.cc | 4 +-- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/folding.hpp b/folding.hpp index 749b66a..b4152ac 100644 --- a/folding.hpp +++ b/folding.hpp @@ -7,11 +7,6 @@ namespace folding { const std::string FOLDING_VERSION = "1.1.0"; class Fold { - protected: - int k; - int n; - int seed; - std::mt19937 random_seed; public: inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) { @@ -23,10 +18,13 @@ namespace folding { virtual ~Fold() = default; std::string version() { return FOLDING_VERSION; } int getNumberOfFolds() { return k; } + protected: + int k; + int n; + int seed; + std::mt19937 random_seed; }; class KFold : public Fold { - private: - std::vector indices; public: inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector(n)) { @@ -50,11 +48,42 @@ namespace folding { } return { train, test }; } + private: + std::vector indices; }; class StratifiedKFold : public Fold { + public: + inline StratifiedKFold(int k, const std::vector& y, int seed = -1) : Fold(k, y.size(), seed) + { + this->y = y; + n = y.size(); + build(); + } + inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) + { + n = y.numel(); + this->y = std::vector(y.data_ptr(), y.data_ptr() + n); + build(); + } + + inline std::pair, std::vector> getFold(int nFold) override + { + if (nFold >= k || nFold < 0) { + throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); + } + std::vector test_indices = stratified_indices[nFold]; + std::vector train_indices; + for (int i = 0; i < k; ++i) { + if (i == nFold) continue; + train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end()); + } + return { train_indices, test_indices }; + } + inline bool isFaulty() { return faulty; } private: std::vector y; std::vector> stratified_indices; + bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. void build() { stratified_indices = std::vector>(k); @@ -99,34 +128,5 @@ namespace folding { } } } - bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. - public: - inline StratifiedKFold(int k, const std::vector& y, int seed = -1) : Fold(k, y.size(), seed) - { - this->y = y; - n = y.size(); - build(); - } - inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) - { - n = y.numel(); - this->y = std::vector(y.data_ptr(), y.data_ptr() + n); - build(); - } - - inline std::pair, std::vector> getFold(int nFold) override - { - if (nFold >= k || nFold < 0) { - throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); - } - std::vector test_indices = stratified_indices[nFold]; - std::vector train_indices; - for (int i = 0; i < k; ++i) { - if (i == nFold) continue; - train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end()); - } - return { train_indices, test_indices }; - } - inline bool isFaulty() { return faulty; } }; } \ No newline at end of file diff --git a/tests/TestFolding.cc b/tests/TestFolding.cc index 99efe0c..844201c 100644 --- a/tests/TestFolding.cc +++ b/tests/TestFolding.cc @@ -17,7 +17,7 @@ TEST_CASE("Version Test", "[Folding]") TEST_CASE("KFold Test", "[Folding]") { // Initialize a KFold object with k=3,5,7,10 and a seed of 19. - std::string file_name = GENERATE("iris", "diabetes", "glass", "mfeat-fourier"); + std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier"); auto raw = RawDatasets(file_name, true); INFO("File Name: " << file_name); int nFolds = GENERATE(3, 5, 7, 10); @@ -66,7 +66,7 @@ TEST_CASE("KFold Test", "[Folding]") TEST_CASE("StratifiedKFold Test", "[Folding]") { // Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17. - std::string file_name = GENERATE("iris", "diabetes", "glass", "mfeat-fourier"); + std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier"); INFO("File Name: " << file_name); int nFolds = GENERATE(3, 5, 7, 10); INFO("Number of Folds: " << nFolds);