Add features used to selectKPairs

This commit is contained in:
2024-05-16 14:18:45 +02:00
parent cccaa6e0af
commit 677ec5613d
4 changed files with 183 additions and 101 deletions

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
}
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
}
std::vector<std::pair<int, int>> Metrics::SelectKPairs(const torch::Tensor& weights, bool ascending, unsigned k)
std::vector<std::pair<int, int>> Metrics::SelectKPairs(const torch::Tensor& weights, std::vector<int>& featuresExcluded, bool ascending, unsigned k)
{
// Return the K Best features
auto n = features.size();
@@ -39,7 +39,13 @@ namespace bayesnet {
pairsKBest.clear();
auto labels = samples.index({ -1, "..." });
for (int i = 0; i < n - 1; ++i) {
if (std::find(featuresExcluded.begin(), featuresExcluded.end(), i) != featuresExcluded.end()) {
continue;
}
for (int j = i + 1; j < n; ++j) {
if (std::find(featuresExcluded.begin(), featuresExcluded.end(), j) != featuresExcluded.end()) {
continue;
}
auto key = std::make_pair(i, j);
auto value = conditionalMutualInformation(samples.index({ i, "..." }), samples.index({ j, "..." }), labels, weights);
scoresKPairs.push_back({ key, value });
@@ -57,9 +63,10 @@ namespace bayesnet {
for (auto& [pairs, score] : scoresKPairs) {
pairsKBest.push_back(pairs);
}
if (k != 0) {
if (k != 0 && k < pairsKBest.size()) {
if (ascending) {
for (int i = 0; i < n - k; ++i) {
int limit = pairsKBest.size() - k;
for (int i = 0; i < limit; i++) {
pairsKBest.erase(pairsKBest.begin());
scoresKPairs.erase(scoresKPairs.begin());
}