Add tests for the quiet parameter and fix initialization mistake
This commit is contained in:
10
folding.hpp
10
folding.hpp
@@ -11,7 +11,7 @@
|
||||
#include <random>
|
||||
#include <vector>
|
||||
namespace folding {
|
||||
const std::string FOLDING_VERSION = "1.1.0";
|
||||
const std::string FOLDING_VERSION = "1.1.1";
|
||||
class Fold {
|
||||
public:
|
||||
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
||||
@@ -63,12 +63,14 @@ namespace folding {
|
||||
{
|
||||
this->y = y;
|
||||
n = y.size();
|
||||
this->quiet = quiet;
|
||||
build();
|
||||
}
|
||||
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
||||
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1, bool quiet = true) : Fold(k, y.numel(), seed)
|
||||
{
|
||||
n = y.numel();
|
||||
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
||||
this->quiet = quiet;
|
||||
build();
|
||||
}
|
||||
|
||||
@@ -90,6 +92,7 @@ namespace folding {
|
||||
std::vector<int> y;
|
||||
std::vector<std::vector<int>> stratified_indices;
|
||||
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
||||
bool quiet = true; // Enable or disable warning messages
|
||||
void build()
|
||||
{
|
||||
stratified_indices = std::vector<std::vector<int>>(k);
|
||||
@@ -105,7 +108,8 @@ namespace folding {
|
||||
int num_samples_to_take = num_samples / k;
|
||||
int remainder_samples_to_take = num_samples % k;
|
||||
if (num_samples_to_take == 0) {
|
||||
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
|
||||
if (!quiet)
|
||||
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
|
||||
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
||||
faulty = true;
|
||||
}
|
||||
|
Reference in New Issue
Block a user