Complete Stratified K Fold

This commit is contained in:
Ricardo Montañana Gómez 2023-07-22 11:23:35 +02:00
parent f6e154bc6e
commit 41cceece20
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 88 additions and 86 deletions

View File

@ -98,7 +98,8 @@
"queue": "cpp", "queue": "cpp",
"typeindex": "cpp", "typeindex": "cpp",
"shared_mutex": "cpp", "shared_mutex": "cpp",
"*.ipp": "cpp" "*.ipp": "cpp",
"cassert": "cpp"
}, },
"cmake.configureOnOpen": false, "cmake.configureOnOpen": false,
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"

View File

@ -34,44 +34,44 @@ pair<vector<int>, vector<int>> KFold::getFold(int nFold)
StratifiedKFold::StratifiedKFold(int k, const vector<int>& y, int seed) : StratifiedKFold::StratifiedKFold(int k, const vector<int>& y, int seed) :
k(k), seed(seed) k(k), seed(seed)
{ {
// n = y.size();
// map<int, vector<int>> class_to_indices;
// for (int i = 0; i < n; ++i) {
// class_to_indices[y[i]].push_back(i);
// }
// random_device rd;
// default_random_engine random_seed(seed == -1 ? rd() : seed);
// for (auto& [cls, indices] : class_to_indices) {
// shuffle(indices.begin(), indices.end(), random_seed);
// int fold_size = n / k;
// for (int i = 0; i < k; ++i) {
// int start = i * fold_size;
// int end = (i == k - 1) ? indices.size() : (i + 1) * fold_size;
// stratified_indices.emplace_back(indices.begin() + start, indices.begin() + end);
// }
// }
n = y.size(); n = y.size();
stratified_indices.resize(k); stratified_indices = vector<vector<int>>(k);
int fold_size = n / k;
int remainder = n % k;
// Compute class counts and indices
auto class_indices = map<int, vector<int>>();
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0); vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
for (auto i = 0; i < n; ++i) { for (auto i = 0; i < n; ++i) {
class_counts[y[i]]++; class_counts[y[i]]++;
class_indices[y[i]].push_back(i);
} }
vector<int> class_starts(class_counts.size()); // Shuffle class indices
partial_sum(class_counts.begin(), class_counts.end() - 1, class_starts.begin() + 1); random_device rd;
vector<int> indices(n); default_random_engine random_seed(seed == -1 ? rd() : seed);
for (auto i = 0; i < n; ++i) { for (auto& [cls, indices] : class_indices) {
int label = y[i]; shuffle(indices.begin(), indices.end(), random_seed);
stratified_indices[class_starts[label]] = i;
class_starts[label]++;
} }
int fold_size = n / k; // Assign indices to folds
int remainder = n % k; for (auto label = 0; label < class_counts.size(); ++label) {
int start = 0; auto num_samples_to_take = class_counts[label] / k;
for (auto i = 0; i < k; ++i) { if (num_samples_to_take == 0)
int fold_length = fold_size + (i < remainder ? 1 : 0); continue;
stratified_indices[i].resize(fold_length); auto remainder_samples_to_take = class_counts[label] % k;
copy(indices.begin() + start, indices.begin() + start + fold_length, stratified_indices[i].begin()); for (auto fold = 0; fold < k; ++fold) {
start += fold_length; auto it = next(class_indices[label].begin(), num_samples_to_take);
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
class_indices[label].erase(class_indices[label].begin(), it);
}
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
if (stratified_indices[fold].size() == fold_size) {
continue;
}
auto it = next(class_indices[label].begin(), 1);
stratified_indices[fold].push_back(*class_indices[label].begin());
class_indices[label].erase(class_indices[label].begin(), it);
remainder_samples_to_take--;
}
} }
} }
pair<vector<int>, vector<int>> StratifiedKFold::getFold(int nFold) pair<vector<int>, vector<int>> StratifiedKFold::getFold(int nFold)

View File

@ -6,8 +6,8 @@ class KFold {
private: private:
int k; int k;
int n; int n;
vector<int> indices;
int seed; int seed;
vector<int> indices;
public: public:
KFold(int k, int n, int seed = -1); KFold(int k, int n, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold); pair<vector<int>, vector<int>> getFold(int nFold);
@ -16,8 +16,8 @@ class StratifiedKFold {
private: private:
int k; int k;
int n; int n;
int seed;
vector<vector<int>> stratified_indices; vector<vector<int>> stratified_indices;
unsigned seed;
public: public:
StratifiedKFold(int k, const vector<int>& y, int seed = -1); StratifiedKFold(int k, const vector<int>& y, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold); pair<vector<int>, vector<int>> getFold(int nFold);

BIN
src/Platform/m Executable file

Binary file not shown.

View File

@ -4,71 +4,72 @@
using namespace std; using namespace std;
class A { class A {
private: private:
int a; int a;
public: public:
A(int a) : a(a) {} A(int a) : a(a) {}
int getA() { return a; } int getA() { return a; }
}; };
class B : public A { class B : public A {
private: private:
int b; int b;
public: public:
B(int a, int b) : A(a), b(b) {} B(int a, int b) : A(a), b(b) {}
int getB() { return b; } int getB() { return b; }
}; };
class C : public A { class C : public A {
private: private:
int c; int c;
public: public:
C(int a, int c) : A(a), c(c) {} C(int a, int c) : A(a), c(c) {}
int getC() { return c; } int getC() { return c; }
}; };
string counts(vector<int> y, vector<int> indices) string counts(vector<int> y, vector<int> indices)
{ {
auto result = map<int, int>(); auto result = map<int, int>();
for (auto i = 0; i < indices.size(); ++i) { for (auto i = 0; i < indices.size(); ++i) {
result[y[indices[i]]]++; result[y[indices[i]]]++;
} }
string final_result = ""; string final_result = "";
for (auto i = 0; i < result.size(); ++i) for (auto i = 0; i < result.size(); ++i)
final_result += to_string(i) + " -> " + to_string(result[i]) + " // "; final_result += to_string(i) + " -> " + to_string(result[i]) + " // ";
final_result += "\n"; final_result += "\n";
return final_result; return final_result;
} }
int main() int main()
{ {
auto y = vector<int>(150); auto y = vector<int>(153);
fill(y.begin(), y.begin() + 50, 0); fill(y.begin(), y.begin() + 50, 0);
fill(y.begin() + 50, y.begin() + 100, 1); fill(y.begin() + 50, y.begin() + 103, 1);
fill(y.begin() + 100, y.end(), 2); fill(y.begin() + 103, y.end(), 2);
//auto fold = KFold(5, 150); //auto fold = KFold(5, 150);
auto fold = StratifiedKFold(5, y, 0); auto fold = StratifiedKFold(5, y, -1);
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
cout << "Fold: " << i << endl; cout << "Fold: " << i << endl;
auto [train, test] = fold.getFold(i); auto [train, test] = fold.getFold(i);
cout << "Train: "; cout << "Train: ";
cout << "(" << train.size() << "): "; cout << "(" << train.size() << "): ";
for (auto j = 0; j < static_cast<int>(train.size()); j++) for (auto j = 0; j < static_cast<int>(train.size()); j++)
cout << train[j] << ", "; cout << train[j] << ", ";
cout << endl; cout << endl;
cout << "Train Statistics : " << counts(y, train); cout << "Train Statistics : " << counts(y, train);
cout << "-------------------------------------------------------------------------------" << endl; cout << "-------------------------------------------------------------------------------" << endl;
cout << "Test: "; cout << "Test: ";
cout << "(" << test.size() << "): "; cout << "(" << test.size() << "): ";
for (auto j = 0; j < static_cast<int>(test.size()); j++) for (auto j = 0; j < static_cast<int>(test.size()); j++)
cout << test[j] << ", "; cout << test[j] << ", ";
cout << endl; cout << endl;
cout << "Test Statistics: " << counts(y, test); cout << "Test Statistics: " << counts(y, test);
cout << "==============================================================================" << endl; cout << "==============================================================================" << endl;
// cout << "Vector poly" << endl; // cout << "Vector poly" << endl;
// auto some = vector<A>(); // auto some = vector<A>();
// auto cx = C(5, 4); // auto cx = C(5, 4);
// auto bx = B(7, 6); // auto bx = B(7, 6);
// some.push_back(cx); // some.push_back(cx);
// some.push_back(bx); // some.push_back(bx);
// for (auto& obj : some) { // for (auto& obj : some) {
// cout << "Obj :" << obj.getA() << endl; // cout << "Obj :" << obj.getA() << endl;
// } // }
} }
} }