BayesNet/bayesnet/utils/BayesMetrics.h

62 lines
3.4 KiB
C
Raw Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#ifndef BAYESNET_METRICS_H
#define BAYESNET_METRICS_H
#include <string>
#include <utility>
#include <vector>
2024-03-08 21:20:54 +00:00
#include <torch/torch.h>
namespace bayesnet {
class Metrics {
2024-04-07 22:13:59 +00:00
public:
Metrics() = default;
Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
2024-05-16 12:18:45 +00:00
std::vector<std::pair<int, int>> SelectKPairs(const torch::Tensor& weights, std::vector<int>& featuresExcluded, bool ascending = false, unsigned k = 0);
2024-04-07 22:13:59 +00:00
std::vector<double> getScoresKBest() const;
2024-05-16 11:46:38 +00:00
std::vector<std::pair<std::pair<int, int>, double>> getScoresKPairs() const;
2024-04-07 22:13:59 +00:00
double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
double conditionalMutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
2024-04-07 22:13:59 +00:00
torch::Tensor conditionalEdge(const torch::Tensor& weights);
std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
// Measured in nats (natural logarithm (log) base e)
// Elements of Information Theory, 2nd Edition, Thomas M. Cover, Joy A. Thomas p. 14
double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& labels, const torch::Tensor& weights);
2023-10-11 19:17:26 +00:00
protected:
2023-11-08 17:45:35 +00:00
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
std::string className;
std::vector<std::string> features;
2023-10-11 19:17:26 +00:00
template <class T>
2023-11-08 17:45:35 +00:00
std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
2023-10-13 10:29:25 +00:00
{
2023-11-08 17:45:35 +00:00
std::vector<std::pair<T, T>> result;
2024-05-16 11:46:38 +00:00
for (int i = 0; i < source.size() - 1; ++i) {
2023-10-13 10:29:25 +00:00
T temp = source[i];
for (int j = i + 1; j < source.size(); ++j) {
result.push_back({ temp, source[j] });
}
}
return result;
}
2024-05-16 11:46:38 +00:00
template <class T>
2023-11-08 17:45:35 +00:00
T pop_first(std::vector<T>& v)
{
T temp = v[0];
v.erase(v.begin());
return temp;
}
2024-04-07 22:13:59 +00:00
private:
int classNumStates = 0;
std::vector<double> scoresKBest;
std::vector<int> featuresKBest; // sorted indices of the features
2024-05-16 09:17:21 +00:00
std::vector<std::pair<int, int>> pairsKBest; // sorted indices of the pairs
2024-05-16 11:46:38 +00:00
std::vector<std::pair<std::pair<int, int>, double>> scoresKPairs;
2024-04-07 22:13:59 +00:00
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
};
}
#endif