Mark inline methods to support multiple Translation Units

This commit is contained in:
2024-01-08 17:41:54 +01:00
parent ce0f0fa91c
commit 37316a54e0

View File

@@ -12,7 +12,7 @@ namespace folding {
int seed; int seed;
std::default_random_engine random_seed; std::default_random_engine random_seed;
public: public:
Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
{ {
std::random_device rd; std::random_device rd;
random_seed = std::default_random_engine(seed == -1 ? rd() : seed); random_seed = std::default_random_engine(seed == -1 ? rd() : seed);
@@ -26,12 +26,12 @@ namespace folding {
private: private:
std::vector<int> indices; std::vector<int> indices;
public: public:
KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n)) inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
{ {
std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
shuffle(indices.begin(), indices.end(), random_seed); shuffle(indices.begin(), indices.end(), random_seed);
} }
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{ {
if (nFold >= k || nFold < 0) { if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -100,20 +100,20 @@ namespace folding {
} }
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public: public:
StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed) inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
{ {
this->y = y; this->y = y;
n = y.size(); n = y.size();
build(); build();
} }
StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
{ {
n = y.numel(); n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n); this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
build(); build();
} }
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{ {
if (nFold >= k || nFold < 0) { if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -126,6 +126,6 @@ namespace folding {
} }
return { train_indices, test_indices }; return { train_indices, test_indices };
} }
bool isFaulty() { return faulty; } inline bool isFaulty() { return faulty; }
}; };
} }