Complete Folding Test
This commit is contained in:
parent
1287160c47
commit
8c3864f3c8
9
Makefile
9
Makefile
@ -15,10 +15,7 @@ define ClearTests
|
|||||||
rm -f $(f_debug)/tests/$$t ; \
|
rm -f $(f_debug)/tests/$$t ; \
|
||||||
fi ; \
|
fi ; \
|
||||||
done
|
done
|
||||||
$(eval nfiles=$(find . -name "*.gcda" -print))
|
@find . -name "*.gcda" -print0 | xargs -0 rm 2>/dev/null ;
|
||||||
@if test "${nfiles}" != "" ; then \
|
|
||||||
find . -name "*.gcda" -print0 | xargs -0 rm 2>/dev/null ;\
|
|
||||||
fi ;
|
|
||||||
endef
|
endef
|
||||||
|
|
||||||
|
|
||||||
@ -106,8 +103,8 @@ testb: ## Run BayesNet tests (opt="-s") to verbose output the tests, (opt="-c='T
|
|||||||
coverage: ## Run tests and generate coverage report (build/index.html)
|
coverage: ## Run tests and generate coverage report (build/index.html)
|
||||||
@echo ">>> Building tests with coverage...";
|
@echo ">>> Building tests with coverage...";
|
||||||
@$(MAKE) test
|
@$(MAKE) test
|
||||||
@cd $(f_debug) ;
|
@cd $(f_debug) ; \
|
||||||
@gcovr --config ../gcovr.cfg ;
|
gcovr --config ../gcovr.cfg tests ;
|
||||||
@echo ">>> Done";
|
@echo ">>> Done";
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
stratified_indices = vector<vector<int>>(k);
|
stratified_indices = vector<vector<int>>(k);
|
||||||
int fold_size = n / k;
|
int fold_size = n / k;
|
||||||
cout << "Fold SIZE: " << fold_size << endl;
|
|
||||||
// Compute class counts and indices
|
// Compute class counts and indices
|
||||||
auto class_indices = map<int, vector<int>>();
|
auto class_indices = map<int, vector<int>>();
|
||||||
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
|
vector<int> class_counts(*max_element(y.begin(), y.end()) + 1, 0);
|
||||||
@ -61,11 +61,14 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
// Assign indices to folds
|
// Assign indices to folds
|
||||||
for (auto label = 0; label < class_counts.size(); ++label) {
|
for (auto label = 0; label < class_counts.size(); ++label) {
|
||||||
auto num_samples_to_take = class_counts[label] / k;
|
auto num_samples_to_take = class_counts.at(label) / k;
|
||||||
if (num_samples_to_take == 0)
|
if (num_samples_to_take == 0) {
|
||||||
|
cerr << "Warning! The number of samples in class " << label << " (" << class_counts.at(label)
|
||||||
|
<< ") is less than the number of folds (" << k << ")." << endl;
|
||||||
|
faulty = true;
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
auto remainder_samples_to_take = class_counts[label] % k;
|
auto remainder_samples_to_take = class_counts[label] % k;
|
||||||
cout << "Remainder samples to take: " << remainder_samples_to_take << endl;
|
|
||||||
for (auto fold = 0; fold < k; ++fold) {
|
for (auto fold = 0; fold < k; ++fold) {
|
||||||
auto it = next(class_indices[label].begin(), num_samples_to_take);
|
auto it = next(class_indices[label].begin(), num_samples_to_take);
|
||||||
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
|
move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ##
|
||||||
@ -74,12 +77,10 @@ namespace platform {
|
|||||||
auto chosen = vector<bool>(k, false);
|
auto chosen = vector<bool>(k, false);
|
||||||
while (remainder_samples_to_take > 0) {
|
while (remainder_samples_to_take > 0) {
|
||||||
int fold = (rand() % static_cast<int>(k));
|
int fold = (rand() % static_cast<int>(k));
|
||||||
cout << "-candidate: " << fold << endl;
|
|
||||||
if (chosen.at(fold)) {
|
if (chosen.at(fold)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
chosen[fold] = true;
|
chosen[fold] = true;
|
||||||
cout << "One goes to fold " << fold << " that had " << stratified_indices[fold].size() << " elements before" << endl;
|
|
||||||
auto it = next(class_indices[label].begin(), 1);
|
auto it = next(class_indices[label].begin(), 1);
|
||||||
stratified_indices[fold].push_back(*class_indices[label].begin());
|
stratified_indices[fold].push_back(*class_indices[label].begin());
|
||||||
class_indices[label].erase(class_indices[label].begin(), it);
|
class_indices[label].erase(class_indices[label].begin(), it);
|
||||||
|
@ -29,10 +29,12 @@ namespace platform {
|
|||||||
vector<int> y;
|
vector<int> y;
|
||||||
vector<vector<int>> stratified_indices;
|
vector<vector<int>> stratified_indices;
|
||||||
void build();
|
void build();
|
||||||
|
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
||||||
public:
|
public:
|
||||||
StratifiedKFold(int k, const vector<int>& y, int seed = -1);
|
StratifiedKFold(int k, const vector<int>& y, int seed = -1);
|
||||||
StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
|
StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
|
||||||
pair<vector<int>, vector<int>> getFold(int nFold) override;
|
pair<vector<int>, vector<int>> getFold(int nFold) override;
|
||||||
|
bool isFaulty() { return faulty; }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -22,7 +22,8 @@ TEST_CASE("Metrics Test", "[BayesNet]")
|
|||||||
{"diabetes", 0.0345470614}
|
{"diabetes", 0.0345470614}
|
||||||
};
|
};
|
||||||
map<string, vector<pair<int, int>>> resultsMST = {
|
map<string, vector<pair<int, int>>> resultsMST = {
|
||||||
{"glass", {{0,6}, {0,5}, {0,3}, {6,2}, {6,7}, {5,1}, {5,8}, {5,4}}},
|
//{"glass", {{0,6}, {0,5}, {0,3}, {6,2}, {6,7}, {5,1}, {5,8}, {5,4}}},
|
||||||
|
{"glass", {{0,6}, {0,5}, {0,3}, {5,1}, {5,8}, {5,4}, {6,2}, {6,7}}},
|
||||||
{"iris", {{0,1},{0,2},{1,3}}},
|
{"iris", {{0,1},{0,2},{1,3}}},
|
||||||
{"ecoli", {{0,1}, {0,2}, {1,5}, {1,3}, {5,6}, {5,4}}},
|
{"ecoli", {{0,1}, {0,2}, {1,5}, {1,3}, {5,6}, {5,4}}},
|
||||||
{"diabetes", {{0,7}, {0,2}, {0,6}, {2,3}, {3,4}, {3,5}, {4,1}}}
|
{"diabetes", {{0,7}, {0,2}, {0,6}, {2,3}, {3,4}, {3,5}, {4,1}}}
|
||||||
|
@ -66,27 +66,28 @@ TEST_CASE("StratifiedKFold Test", "[Platform][StratifiedKFold]")
|
|||||||
auto [train_indicesv, test_indicesv] = stratified_kfoldv.getFold(fold);
|
auto [train_indicesv, test_indicesv] = stratified_kfoldv.getFold(fold);
|
||||||
REQUIRE(train_indicest == train_indicesv);
|
REQUIRE(train_indicest == train_indicesv);
|
||||||
REQUIRE(test_indicest == test_indicesv);
|
REQUIRE(test_indicest == test_indicesv);
|
||||||
bool result = train_indicest.size() == number || train_indicest.size() == number + 1;
|
// In the worst case scenario, the number of samples in the training set is number + raw.classNumStates
|
||||||
REQUIRE(result);
|
// because in that fold can come one remainder sample from each class.
|
||||||
REQUIRE(train_indicest.size() + test_indicest.size() == raw.nSamples);
|
REQUIRE(train_indicest.size() <= number + raw.classNumStates);
|
||||||
|
// If the number of samples in any class is less than the number of folds, then the fold is faulty.
|
||||||
|
// and the number of samples in the training set + test set will be less than nSamples
|
||||||
|
if (!stratified_kfoldt.isFaulty()) {
|
||||||
|
REQUIRE(train_indicest.size() + test_indicest.size() == raw.nSamples);
|
||||||
|
} else {
|
||||||
|
REQUIRE(train_indicest.size() + test_indicest.size() <= raw.nSamples);
|
||||||
|
}
|
||||||
auto train_t = torch::tensor(train_indicest);
|
auto train_t = torch::tensor(train_indicest);
|
||||||
auto ytrain = raw.yt.index({ train_t });
|
auto ytrain = raw.yt.index({ train_t });
|
||||||
cout << "dataset=" << file_name << endl;
|
|
||||||
cout << "nSamples=" << raw.nSamples << endl;;
|
|
||||||
cout << "number=" << number << endl;
|
|
||||||
cout << "train_indices.size()=" << train_indicest.size() << endl;
|
|
||||||
cout << "test_indices.size()=" << test_indicest.size() << endl;
|
|
||||||
cout << "Class Name = " << raw.classNamet << endl;
|
|
||||||
// Check that the class labels have been equally assign to each fold
|
// Check that the class labels have been equally assign to each fold
|
||||||
for (const auto& idx : train_indicest) {
|
for (const auto& idx : train_indicest) {
|
||||||
counts[fold][ytrain[idx].item<int>()]++;
|
counts[fold][raw.yt[idx].item<int>()]++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Test the fold counting of every class
|
// Test the fold counting of every class
|
||||||
for (int fold = 0; fold < nFolds; ++fold) {
|
for (int fold = 0; fold < nFolds; ++fold) {
|
||||||
for (int j = 1; j < nFolds - 1; ++j) {
|
for (int j = 1; j < nFolds - 1; ++j) {
|
||||||
for (int k = 0; k < raw.classNumStates; ++k) {
|
for (int k = 0; k < raw.classNumStates; ++k) {
|
||||||
REQUIRE(abs(counts.at(fold).at(k) - counts.at(fold).at(j)) <= 1);
|
REQUIRE(abs(counts.at(fold).at(k) - counts.at(j).at(k)) <= 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user