Complete selectKPairs method & test

This commit is contained in:
2024-05-16 13:46:38 +02:00
parent 2e3e0e0fc2
commit cccaa6e0af
5 changed files with 96 additions and 35 deletions

View File

@@ -18,6 +18,7 @@ namespace bayesnet {
std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
std::vector<std::pair<int, int>> SelectKPairs(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
std::vector<double> getScoresKBest() const;
std::vector<std::pair<std::pair<int, int>, double>> getScoresKPairs() const;
double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
double conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
torch::Tensor conditionalEdge(const torch::Tensor& weights);
@@ -34,7 +35,7 @@ namespace bayesnet {
std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
{
std::vector<std::pair<T, T>> result;
for (int i = 0; i < source.size(); ++i) {
for (int i = 0; i < source.size() - 1; ++i) {
T temp = source[i];
for (int j = i + 1; j < source.size(); ++j) {
result.push_back({ temp, source[j] });
@@ -42,7 +43,7 @@ namespace bayesnet {
}
return result;
}
template <class T>
template <class T>
T pop_first(std::vector<T>& v)
{
T temp = v[0];
@@ -54,7 +55,7 @@ namespace bayesnet {
std::vector<double> scoresKBest;
std::vector<int> featuresKBest; // sorted indices of the features
std::vector<std::pair<int, int>> pairsKBest; // sorted indices of the pairs
std::map<std::pair<int, int>, double> scoresKPairs;
std::vector<std::pair<std::pair<int, int>, double>> scoresKPairs;
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
};
}