Refactor folding change order public and private methods

This commit is contained in:
2024-05-11 13:34:50 +02:00
parent 306a9e1fc8
commit c7bcc10dfb
2 changed files with 38 additions and 38 deletions

View File

@@ -7,11 +7,6 @@
namespace folding {
const std::string FOLDING_VERSION = "1.1.0";
class Fold {
protected:
int k;
int n;
int seed;
std::mt19937 random_seed;
public:
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
{
@@ -23,10 +18,13 @@ namespace folding {
virtual ~Fold() = default;
std::string version() { return FOLDING_VERSION; }
int getNumberOfFolds() { return k; }
protected:
int k;
int n;
int seed;
std::mt19937 random_seed;
};
class KFold : public Fold {
private:
std::vector<int> indices;
public:
inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
{
@@ -50,11 +48,42 @@ namespace folding {
}
return { train, test };
}
private:
std::vector<int> indices;
};
class StratifiedKFold : public Fold {
public:
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
{
this->y = y;
n = y.size();
build();
}
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
{
n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
build();
}
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{
if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
}
std::vector<int> test_indices = stratified_indices[nFold];
std::vector<int> train_indices;
for (int i = 0; i < k; ++i) {
if (i == nFold) continue;
train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end());
}
return { train_indices, test_indices };
}
inline bool isFaulty() { return faulty; }
private:
std::vector<int> y;
std::vector<std::vector<int>> stratified_indices;
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
void build()
{
stratified_indices = std::vector<std::vector<int>>(k);
@@ -99,34 +128,5 @@ namespace folding {
}
}
}
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
public:
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
{
this->y = y;
n = y.size();
build();
}
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
{
n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
build();
}
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
{
if (nFold >= k || nFold < 0) {
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
}
std::vector<int> test_indices = stratified_indices[nFold];
std::vector<int> train_indices;
for (int i = 0; i < k; ++i) {
if (i == nFold) continue;
train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end());
}
return { train_indices, test_indices };
}
inline bool isFaulty() { return faulty; }
};
}