Begin Stratified KFold

This commit is contained in:
Ricardo Montañana Gómez 2023-07-21 21:49:02 +02:00
parent a2622a4fb6
commit f6e154bc6e
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 150 additions and 61 deletions

View File

@ -4,4 +4,5 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
add_executable(main Experiment.cc Folding.cc platformUtils.cc)
add_executable(testx testx.cpp Folding.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -1,21 +1,23 @@
#include "Folding.h"
#include <algorithm>
#include <map>
#include <random>
using namespace std;
KFold::KFold(int k, int n, int seed)
KFold::KFold(int k, int n, int seed) : k(k), n(n), seed(seed)
{
this->k = k;
this->n = n;
indices = vector<int>(n);
iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
shuffle(indices.begin(), indices.end(), default_random_engine(seed));
random_device rd;
default_random_engine random_seed(seed == -1 ? rd() : seed);
shuffle(indices.begin(), indices.end(), random_seed);
}
pair<vector<int>, vector<int>> KFold::getFold(int nFold)
{
if (nFold >= k || nFold < 0) {
throw invalid_argument("nFold (" + to_string(nFold) + ") must be less than k (" + to_string(k) + ")");
throw out_of_range("nFold (" + to_string(nFold) + ") must be less than k (" + to_string(k) + ")");
}
int nTest = n / k;
auto train = vector<int>();
@ -28,4 +30,60 @@ pair<vector<int>, vector<int>> KFold::getFold(int nFold)
}
}
return { train, test };
}
StratifiedKFold::StratifiedKFold(int k, const vector<int>& y, int seed) :
k(k), seed(seed)
{
// n = y.size();
// map<int, vector<int>> class_to_indices;
// for (int i = 0; i < n; ++i) {
// class_to_indices[y[i]].push_back(i);
// }
// random_device rd;
// default_random_engine random_seed(seed == -1 ? rd() : seed);
// for (auto& [cls, indices] : class_to_indices) {
// shuffle(indices.begin(), indices.end(), random_seed);
// int fold_size = n / k;
// for (int i = 0; i < k; ++i) {
// int start = i * fold_size;
// int end = (i == k - 1) ? indices.size() : (i + 1) * fold_size;
// stratified_indices.emplace_back(indices.begin() + start, indices.begin() + end);
// }
// }
n = y.size();
stratified_indices.resize(k);
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
for (auto i = 0; i < n; ++i) {
class_counts[y[i]]++;
}
vector<int> class_starts(class_counts.size());
partial_sum(class_counts.begin(), class_counts.end() - 1, class_starts.begin() + 1);
vector<int> indices(n);
for (auto i = 0; i < n; ++i) {
int label = y[i];
stratified_indices[class_starts[label]] = i;
class_starts[label]++;
}
int fold_size = n / k;
int remainder = n % k;
int start = 0;
for (auto i = 0; i < k; ++i) {
int fold_length = fold_size + (i < remainder ? 1 : 0);
stratified_indices[i].resize(fold_length);
copy(indices.begin() + start, indices.begin() + start + fold_length, stratified_indices[i].begin());
start += fold_length;
}
}
pair<vector<int>, vector<int>> StratifiedKFold::getFold(int nFold)
{
if (nFold >= k || nFold < 0) {
throw out_of_range("nFold (" + to_string(nFold) + ") must be less than k (" + to_string(k) + ")");
}
vector<int> test_indices = stratified_indices[nFold];
vector<int> train_indices;
for (int i = 0; i < k; ++i) {
if (i == nFold) continue;
train_indices.insert(train_indices.end(), stratified_indices[i].begin(), stratified_indices[i].end());
}
return { train_indices, test_indices };
}

View File

@ -7,12 +7,19 @@ private:
int k;
int n;
vector<int> indices;
int seed;
public:
KFold(int k, int n, int seed);
pair<vector<int>, vector<int>> getFold(int);
KFold(int k, int n, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold);
};
class KStratifiedFold {
class StratifiedKFold {
private:
int k;
int n;
vector<vector<int>> stratified_indices;
unsigned seed;
public:
StratifiedKFold(int k, const vector<int>& y, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold);
};
#endif

Binary file not shown.

View File

@ -1,51 +0,0 @@
#include "Folding.h"
#include <iostream>
using namespace std;
class A {
private:
int a;
public:
A(int a) : a(a) {}
int getA() { return a; }
};
class B : public A {
private:
int b;
public:
B(int a, int b) : A(a), b(b) {}
int getB() { return b; }
};
class C : public A {
private:
int c;
public:
C(int a, int c) : A(a), c(c) {}
int getC() { return c; }
};
int main()
{
auto fold = KFold(5, 100, 1);
for (int i = 0; i < 5; ++i) {
cout << "Fold: " << i << endl;
auto [train, test] = fold.getFold(i);
cout << "Train: ";
cout << "(" << train.size() << "): ";
for (auto j = 0; j < static_cast<int>(train.size()); j++)
cout << train[j] << ", ";
cout << endl;
cout << "Test: ";
cout << "(" << train.size() << "): ";
for (auto j = 0; j < static_cast<int>(test.size()); j++)
cout << test[j] << ", ";
cout << endl;
cout << "Vector poly" << endl;
auto some = vector<A>();
auto cx = C(5, 4);
auto bx = B(7, 6);
some.push_back(cx);
some.push_back(bx);
for (auto& obj : some) {
cout << "Obj :" << obj.getA() << endl;
}
}
}

74
src/Platform/testx.cpp Normal file
View File

@ -0,0 +1,74 @@
#include "Folding.h"
#include <map>
#include <iostream>
using namespace std;
class A {
private:
int a;
public:
A(int a) : a(a) {}
int getA() { return a; }
};
class B : public A {
private:
int b;
public:
B(int a, int b) : A(a), b(b) {}
int getB() { return b; }
};
class C : public A {
private:
int c;
public:
C(int a, int c) : A(a), c(c) {}
int getC() { return c; }
};
string counts(vector<int> y, vector<int> indices)
{
auto result = map<int, int>();
for (auto i = 0; i < indices.size(); ++i) {
result[y[indices[i]]]++;
}
string final_result = "";
for (auto i = 0; i < result.size(); ++i)
final_result += to_string(i) + " -> " + to_string(result[i]) + " // ";
final_result += "\n";
return final_result;
}
int main()
{
auto y = vector<int>(150);
fill(y.begin(), y.begin() + 50, 0);
fill(y.begin() + 50, y.begin() + 100, 1);
fill(y.begin() + 100, y.end(), 2);
//auto fold = KFold(5, 150);
auto fold = StratifiedKFold(5, y, 0);
for (int i = 0; i < 5; ++i) {
cout << "Fold: " << i << endl;
auto [train, test] = fold.getFold(i);
cout << "Train: ";
cout << "(" << train.size() << "): ";
for (auto j = 0; j < static_cast<int>(train.size()); j++)
cout << train[j] << ", ";
cout << endl;
cout << "Train Statistics : " << counts(y, train);
cout << "-------------------------------------------------------------------------------" << endl;
cout << "Test: ";
cout << "(" << test.size() << "): ";
for (auto j = 0; j < static_cast<int>(test.size()); j++)
cout << test[j] << ", ";
cout << endl;
cout << "Test Statistics: " << counts(y, test);
cout << "==============================================================================" << endl;
// cout << "Vector poly" << endl;
// auto some = vector<A>();
// auto cx = C(5, 4);
// auto bx = B(7, 6);
// some.push_back(cx);
// some.push_back(bx);
// for (auto& obj : some) {
// cout << "Obj :" << obj.getA() << endl;
// }
}
}