Refactor folding change order public and private methods
This commit is contained in:
72
folding.hpp
72
folding.hpp
@@ -7,11 +7,6 @@
|
|||||||
namespace folding {
|
namespace folding {
|
||||||
const std::string FOLDING_VERSION = "1.1.0";
|
const std::string FOLDING_VERSION = "1.1.0";
|
||||||
class Fold {
|
class Fold {
|
||||||
protected:
|
|
||||||
int k;
|
|
||||||
int n;
|
|
||||||
int seed;
|
|
||||||
std::mt19937 random_seed;
|
|
||||||
public:
|
public:
|
||||||
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
||||||
{
|
{
|
||||||
@@ -23,10 +18,13 @@ namespace folding {
|
|||||||
virtual ~Fold() = default;
|
virtual ~Fold() = default;
|
||||||
std::string version() { return FOLDING_VERSION; }
|
std::string version() { return FOLDING_VERSION; }
|
||||||
int getNumberOfFolds() { return k; }
|
int getNumberOfFolds() { return k; }
|
||||||
|
protected:
|
||||||
|
int k;
|
||||||
|
int n;
|
||||||
|
int seed;
|
||||||
|
std::mt19937 random_seed;
|
||||||
};
|
};
|
||||||
class KFold : public Fold {
|
class KFold : public Fold {
|
||||||
private:
|
|
||||||
std::vector<int> indices;
|
|
||||||
public:
|
public:
|
||||||
inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
|
inline KFold(int k, int n, int seed = -1) : Fold(k, n, seed), indices(std::vector<int>(n))
|
||||||
{
|
{
|
||||||
@@ -50,11 +48,42 @@ namespace folding {
|
|||||||
}
|
}
|
||||||
return { train, test };
|
return { train, test };
|
||||||
}
|
}
|
||||||
|
private:
|
||||||
|
std::vector<int> indices;
|
||||||
};
|
};
|
||||||
class StratifiedKFold : public Fold {
|
class StratifiedKFold : public Fold {
|
||||||
|
public:
|
||||||
|
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
|
||||||
|
{
|
||||||
|
this->y = y;
|
||||||
|
n = y.size();
|
||||||
|
build();
|
||||||
|
}
|
||||||
|
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
||||||
|
{
|
||||||
|
n = y.numel();
|
||||||
|
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
||||||
|
build();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
||||||
|
{
|
||||||
|
if (nFold >= k || nFold < 0) {
|
||||||
|
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
|
||||||
|
}
|
||||||
|
std::vector<int> test_indices = stratified_indices[nFold];
|
||||||
|
std::vector<int> train_indices;
|
||||||
|
for (int i = 0; i < k; ++i) {
|
||||||
|
if (i == nFold) continue;
|
||||||
|
train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end());
|
||||||
|
}
|
||||||
|
return { train_indices, test_indices };
|
||||||
|
}
|
||||||
|
inline bool isFaulty() { return faulty; }
|
||||||
private:
|
private:
|
||||||
std::vector<int> y;
|
std::vector<int> y;
|
||||||
std::vector<std::vector<int>> stratified_indices;
|
std::vector<std::vector<int>> stratified_indices;
|
||||||
|
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
||||||
void build()
|
void build()
|
||||||
{
|
{
|
||||||
stratified_indices = std::vector<std::vector<int>>(k);
|
stratified_indices = std::vector<std::vector<int>>(k);
|
||||||
@@ -99,34 +128,5 @@ namespace folding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
|
||||||
public:
|
|
||||||
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed)
|
|
||||||
{
|
|
||||||
this->y = y;
|
|
||||||
n = y.size();
|
|
||||||
build();
|
|
||||||
}
|
|
||||||
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
|
||||||
{
|
|
||||||
n = y.numel();
|
|
||||||
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
|
||||||
build();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline std::pair<std::vector<int>, std::vector<int>> getFold(int nFold) override
|
|
||||||
{
|
|
||||||
if (nFold >= k || nFold < 0) {
|
|
||||||
throw std::out_of_range("nFold (" + std::to_string(nFold) + ") must be less than k (" + std::to_string(k) + ")");
|
|
||||||
}
|
|
||||||
std::vector<int> test_indices = stratified_indices[nFold];
|
|
||||||
std::vector<int> train_indices;
|
|
||||||
for (int i = 0; i < k; ++i) {
|
|
||||||
if (i == nFold) continue;
|
|
||||||
train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end());
|
|
||||||
}
|
|
||||||
return { train_indices, test_indices };
|
|
||||||
}
|
|
||||||
inline bool isFaulty() { return faulty; }
|
|
||||||
};
|
};
|
||||||
}
|
}
|
@@ -17,7 +17,7 @@ TEST_CASE("Version Test", "[Folding]")
|
|||||||
TEST_CASE("KFold Test", "[Folding]")
|
TEST_CASE("KFold Test", "[Folding]")
|
||||||
{
|
{
|
||||||
// Initialize a KFold object with k=3,5,7,10 and a seed of 19.
|
// Initialize a KFold object with k=3,5,7,10 and a seed of 19.
|
||||||
std::string file_name = GENERATE("iris", "diabetes", "glass", "mfeat-fourier");
|
std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier");
|
||||||
auto raw = RawDatasets(file_name, true);
|
auto raw = RawDatasets(file_name, true);
|
||||||
INFO("File Name: " << file_name);
|
INFO("File Name: " << file_name);
|
||||||
int nFolds = GENERATE(3, 5, 7, 10);
|
int nFolds = GENERATE(3, 5, 7, 10);
|
||||||
@@ -66,7 +66,7 @@ TEST_CASE("KFold Test", "[Folding]")
|
|||||||
TEST_CASE("StratifiedKFold Test", "[Folding]")
|
TEST_CASE("StratifiedKFold Test", "[Folding]")
|
||||||
{
|
{
|
||||||
// Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17.
|
// Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17.
|
||||||
std::string file_name = GENERATE("iris", "diabetes", "glass", "mfeat-fourier");
|
std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier");
|
||||||
INFO("File Name: " << file_name);
|
INFO("File Name: " << file_name);
|
||||||
int nFolds = GENERATE(3, 5, 7, 10);
|
int nFolds = GENERATE(3, 5, 7, 10);
|
||||||
INFO("Number of Folds: " << nFolds);
|
INFO("Number of Folds: " << nFolds);
|
||||||
|
Reference in New Issue
Block a user