Mark inline methods to support multiple Translation Units

This commit is contained in:
2024-01-08 17:41:54 +01:00
parent ce0f0fa91c
commit 37316a54e0

View File

@@ -12,7 +12,7 @@ namespace folding {
int seed;
std::default_random_engine random_seed;
public:
Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
{
std::random_device rd;
random_seed = std::default_random_engine(seed == -1 ? rd() : seed);
@@ -26,12 +26,12 @@ namespace folding {
private:
std::vector<int> indices;
public:
KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
{
std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
shuffle(indices.begin(), indices.end(), random_seed);
}
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{
if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -100,20 +100,20 @@ namespace folding {
}
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
{
this->y = y;
n = y.size();
build();
}
StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
{
n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
build();
}
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{
if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -126,6 +126,6 @@ namespace folding {
}
return { train_indices, test_indices };
}
bool isFaulty() { return faulty; }
inline bool isFaulty() { return faulty; }
};
}