diff --git a/folding.hpp b/folding.hpp index b4152ac..1e312cc 100644 --- a/folding.hpp +++ b/folding.hpp @@ -87,31 +87,25 @@ namespace folding { void build() { stratified_indices = std::vector>(k); - int fold_size = n / k; - // Compute class counts and indices auto class_indices = std::map>(); - std::vector class_counts(*max_element(y.begin(), y.end()) + 1, 0); for (auto i = 0; i < n; ++i) { - class_counts[y[i]]++; class_indices[y[i]].push_back(i); } - // Shuffle class indices - for (auto& [cls, indices] : class_indices) { - shuffle(indices.begin(), indices.end(), random_seed); - } // Assign indices to folds - for (auto label = 0; label < class_counts.size(); ++label) { - auto num_samples_to_take = class_counts.at(label) / k; + for (auto& [label, indices] : class_indices) { + shuffle(indices.begin(), indices.end(), random_seed); + int num_samples = indices.size(); + int num_samples_to_take = num_samples / k; + int remainder_samples_to_take = num_samples % k; if (num_samples_to_take == 0) { - std::cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label) + std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples << ") is less than the number of folds (" << k << ")." << std::endl; faulty = true; } - auto remainder_samples_to_take = class_counts[label] % k; for (auto fold = 0; fold < k; ++fold) { auto it = next(class_indices[label].begin(), num_samples_to_take); - move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ## + move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); class_indices[label].erase(class_indices[label].begin(), it); } auto chosen = std::vector(k, false); diff --git a/tests/TestFolding.cc b/tests/TestFolding.cc index 844201c..1f009ff 100644 --- a/tests/TestFolding.cc +++ b/tests/TestFolding.cc @@ -17,7 +17,7 @@ TEST_CASE("Version Test", "[Folding]") TEST_CASE("KFold Test", "[Folding]") { // Initialize a KFold object with k=3,5,7,10 and a seed of 19. - std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier"); + std::string file_name = GENERATE("iris", "diabetes", "glass"); //, "mfeat-fourier"); auto raw = RawDatasets(file_name, true); INFO("File Name: " << file_name); int nFolds = GENERATE(3, 5, 7, 10); @@ -66,7 +66,7 @@ TEST_CASE("KFold Test", "[Folding]") TEST_CASE("StratifiedKFold Test", "[Folding]") { // Initialize a StratifiedKFold object with k=3, using the y std::vector, and a seed of 17. - std::string file_name = GENERATE("iris", "diabetes", "glass");//, "mfeat-fourier"); + std::string file_name = GENERATE("iris", "diabetes", "glass"); //, "mfeat-fourier"); INFO("File Name: " << file_name); int nFolds = GENERATE(3, 5, 7, 10); INFO("Number of Folds: " << nFolds);