method defiinition inside class

This commit is contained in:
2024-01-08 10:46:16 +01:00
parent a3a2977996
commit ce0f0fa91c

View File

@@ -12,7 +12,12 @@ namespace folding {
int seed; int seed;
std::default_random_engine random_seed; std::default_random_engine random_seed;
public: public:
Fold(int k, int n, int seed = -1); Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
{
std::random_device rd;
random_seed = std::default_random_engine(seed == -1 ? rd() : seed);
std::srand(seed == -1 ? time(0) : seed);
}
virtual std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) = 0; virtual std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) = 0;
virtual ~Fold() = default; virtual ~Fold() = default;
int getNumberOfFolds() { return k; } int getNumberOfFolds() { return k; }
@@ -21,33 +26,12 @@ namespace folding {
private: private:
std::vector<int> indices; std::vector<int> indices;
public: public:
KFold(int k, int n, int seed = -1); KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override;
};
class StratifiedKFold : public Fold {
private:
std::vector<int> y;
std::vector<std::vector<int>> stratified_indices;
void build();
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
StratifiedKFold(int k, const std::vector<int>& y, int seed = -1);
StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override;
bool isFaulty() { return faulty; }
};
Fold::Fold(int k, int n, int seed) : k(k), n(n), seed(seed)
{
std::random_device rd;
random_seed = std::default_random_engine(seed == -1 ? rd() : seed);
std::srand(seed == -1 ? time(0) : seed);
}
KFold::KFold(int k, int n, int seed) : Fold(k, n, seed), indices(std::vector<int>(n))
{ {
std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 std::iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
shuffle(indices.begin(), indices.end(), random_seed); shuffle(indices.begin(), indices.end(), random_seed);
} }
std::pair<std::vector<int>, std::vector<int>> KFold::getFold(int nFold) std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{ {
if (nFold >= k || nFold < 0) { if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -64,20 +48,12 @@ namespace folding {
} }
return { train, test }; return { train, test };
} }
StratifiedKFold::StratifiedKFold(int k, torch::Tensor& y, int seed) : Fold(k, y.numel(), seed) };
{ class StratifiedKFold : public Fold {
n = y.numel(); private:
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n); std::vector<int> y;
build(); std::vector<std::vector<int>> stratified_indices;
} void build()
StratifiedKFold::StratifiedKFold(int k, const std::vector<int>& y, int seed)
: Fold(k, y.size(), seed)
{
this->y = y;
n = y.size();
build();
}
void StratifiedKFold::build()
{ {
stratified_indices = std::vector<std::vector<int>>(k); stratified_indices = std::vector<std::vector<int>>(k);
int fold_size = n / k; int fold_size = n / k;
@@ -122,7 +98,22 @@ namespace folding {
} }
} }
} }
std::pair<std::vector<int>, std::vector<int>> StratifiedKFold::getFold(int nFold) bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
{
this->y = y;
n = y.size();
build();
}
StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
{
n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
build();
}
std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{ {
if (nFold >= k || nFold < 0) { if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")"); throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
@@ -135,4 +126,6 @@ namespace folding {
} }
return { train_indices, test_indices }; return { train_indices, test_indices };
} }
bool isFaulty() { return faulty; }
};
} }