From 37316a54e0d558555ae02ae95c8bb083ec063874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 8 Jan 2024 17:41:54 +0100 Subject: [PATCH] Mark inline methods to support multiple Translation Units --- folding.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/folding.hpp b/folding.hpp index 0e49b25..9c8f264 100644 --- a/folding.hpp +++ b/folding.hpp @@ -12,7 +12,7 @@ namespace folding { int seed; std::default_random_engine random_seed; public: - Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) + inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) { std::random_device rd; random_seed = std::default_random_engine(seed == -1 ? rd() : seed); @@ -26,12 +26,12 @@ namespace folding { private: std::vector indices; public: - KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector(n)) + inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector(n)) { std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 shuffle(indices.begin(), indices.end(), random_seed); } - std::pair, std::vector> getFold(int nFold) override + inline std::pair, std::vector> getFold(int nFold) override { if (nFold >= k || nFold < 0) { throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); @@ -100,20 +100,20 @@ namespace folding { } bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. public: - StratifiedKFold(int k, const std::vector& y, int seed = -1) : Fold(k, y.size(), seed) + inline StratifiedKFold(int k, const std::vector& y, int seed = -1) : Fold(k, y.size(), seed) { this->y = y; n = y.size(); build(); } - StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) + inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) { n = y.numel(); this->y = std::vector(y.data_ptr(), y.data_ptr() + n); build(); } - std::pair, std::vector> getFold(int nFold) override + inline std::pair, std::vector> getFold(int nFold) override { if (nFold >= k || nFold < 0) { throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); @@ -126,6 +126,6 @@ namespace folding { } return { train_indices, test_indices }; } - bool isFaulty() { return faulty; } + inline bool isFaulty() { return faulty; } }; } \ No newline at end of file