Refactor stratified build removing uneeded structures
This commit is contained in:
20
folding.hpp
20
folding.hpp
@@ -87,31 +87,25 @@ namespace folding {
|
|||||||
void build()
|
void build()
|
||||||
{
|
{
|
||||||
stratified_indices = std::vector<std::vector<int>>(k);
|
stratified_indices = std::vector<std::vector<int>>(k);
|
||||||
int fold_size = n / k;
|
|
||||||
|
|
||||||
// Compute class counts and indices
|
// Compute class counts and indices
|
||||||
auto class_indices = std::map<int, std::vector<int>>();
|
auto class_indices = std::map<int, std::vector<int>>();
|
||||||
std::vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
|
|
||||||
for (auto i = 0; i < n; ++i) {
|
for (auto i = 0; i < n; ++i) {
|
||||||
class_counts[y[i]]++;
|
|
||||||
class_indices[y[i]].push_back(i);
|
class_indices[y[i]].push_back(i);
|
||||||
}
|
}
|
||||||
// Shuffle class indices
|
|
||||||
for (auto& [cls, indices] : class_indices) {
|
|
||||||
shuffle(indices.begin(), indices.end(), random_seed);
|
|
||||||
}
|
|
||||||
// Assign indices to folds
|
// Assign indices to folds
|
||||||
for (auto label = 0; label < class_counts.size(); ++label) {
|
for (auto& [label, indices] : class_indices) {
|
||||||
auto num_samples_to_take = class_counts.at(label) / k;
|
shuffle(indices.begin(), indices.end(), random_seed);
|
||||||
|
int num_samples = indices.size();
|
||||||
|
int num_samples_to_take = num_samples / k;
|
||||||
|
int remainder_samples_to_take = num_samples % k;
|
||||||
if (num_samples_to_take == 0) {
|
if (num_samples_to_take == 0) {
|
||||||
std::cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label)
|
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
|
||||||
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
||||||
faulty = true;
|
faulty = true;
|
||||||
}
|
}
|
||||||
auto remainder_samples_to_take = class_counts[label] % k;
|
|
||||||
for (auto fold = 0; fold < k; ++fold) {
|
for (auto fold = 0; fold < k; ++fold) {
|
||||||
auto it = next(class_indices[label].begin(), num_samples_to_take);
|
auto it = next(class_indices[label].begin(), num_samples_to_take);
|
||||||
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
|
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold]));
|
||||||
class_indices[label].erase(class_indices[label].begin(), it);
|
class_indices[label].erase(class_indices[label].begin(), it);
|
||||||
}
|
}
|
||||||
auto chosen = std::vector<bool>(k, false);
|
auto chosen = std::vector<bool>(k, false);
|
||||||
|
@@ -17,7 +17,7 @@ TEST_CASE("Version Test", "[Folding]")
|
|||||||
TEST_CASE("KFold Test", "[Folding]")
|
TEST_CASE("KFold Test", "[Folding]")
|
||||||
{
|
{
|
||||||
// Initialize a KFold object with k=3,5,7,10 and a seed of 19.
|
// Initialize a KFold object with k=3,5,7,10 and a seed of 19.
|
||||||
std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier");
|
std::string file_name = GENERATE("iris", "diabetes", "glass"); //, "mfeat-fourier");
|
||||||
auto raw = RawDatasets(file_name, true);
|
auto raw = RawDatasets(file_name, true);
|
||||||
INFO("File Name: " << file_name);
|
INFO("File Name: " << file_name);
|
||||||
int nFolds = GENERATE(3, 5, 7, 10);
|
int nFolds = GENERATE(3, 5, 7, 10);
|
||||||
@@ -66,7 +66,7 @@ TEST_CASE("KFold Test", "[Folding]")
|
|||||||
TEST_CASE("StratifiedKFold Test", "[Folding]")
|
TEST_CASE("StratifiedKFold Test", "[Folding]")
|
||||||
{
|
{
|
||||||
// Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17.
|
// Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17.
|
||||||
std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier");
|
std::string file_name = GENERATE("iris", "diabetes", "glass"); //, "mfeat-fourier");
|
||||||
INFO("File Name: " << file_name);
|
INFO("File Name: " << file_name);
|
||||||
int nFolds = GENERATE(3, 5, 7, 10);
|
int nFolds = GENERATE(3, 5, 7, 10);
|
||||||
INFO("Number of Folds: " << nFolds);
|
INFO("Number of Folds: " << nFolds);
|
||||||
|
Reference in New Issue
Block a user