Mark inline methods to support multiple Translation Units
This commit is contained in:
14
folding.hpp
14
folding.hpp
@@ -12,7 +12,7 @@ namespace folding {
|
||||
int seed;
|
||||
std::default_random_engine random_seed;
|
||||
public:
|
||||
Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
||||
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
||||
{
|
||||
std::random_device rd;
|
||||
random_seed = std::default_random_engine(seed == -1 ? rd() : seed);
|
||||
@@ -26,12 +26,12 @@ namespace folding {
|
||||
private:
|
||||
std::vector<int> indices;
|
||||
public:
|
||||
KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
|
||||
inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
|
||||
{
|
||||
std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
|
||||
shuffle(indices.begin(), indices.end(), random_seed);
|
||||
}
|
||||
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
||||
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
||||
{
|
||||
if (nFold >= k || nFold < 0) {
|
||||
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
|
||||
@@ -100,20 +100,20 @@ namespace folding {
|
||||
}
|
||||
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
||||
public:
|
||||
StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
|
||||
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
|
||||
{
|
||||
this->y = y;
|
||||
n = y.size();
|
||||
build();
|
||||
}
|
||||
StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
||||
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
||||
{
|
||||
n = y.numel();
|
||||
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
||||
build();
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
||||
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
||||
{
|
||||
if (nFold >= k || nFold < 0) {
|
||||
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
|
||||
@@ -126,6 +126,6 @@ namespace folding {
|
||||
}
|
||||
return { train_indices, test_indices };
|
||||
}
|
||||
bool isFaulty() { return faulty; }
|
||||
inline bool isFaulty() { return faulty; }
|
||||
};
|
||||
}
|
Reference in New Issue
Block a user