BayesNet/bayesnet/utils/BayesMetrics.cc

169 lines
7.5 KiB
C++
Raw Permalink Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "Mst.h"
2024-03-08 21:20:54 +00:00
#include "BayesMetrics.h"
namespace bayesnet {
2023-10-04 23:14:16 +00:00
//samples is n+1xm tensor used to fit the model
2023-11-08 17:45:35 +00:00
Metrics::Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates)
: samples(samples)
, features(features)
, className(className)
, classNumStates(classNumStates)
{
2023-07-11 23:05:24 +00:00
}
2024-04-07 22:13:59 +00:00
//samples is n+1xm std::vector used to fit the model
2023-11-08 17:45:35 +00:00
Metrics::Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates)
2023-07-11 23:05:24 +00:00
: features(features)
, className(className)
, classNumStates(classNumStates)
2024-04-07 22:13:59 +00:00
, samples(torch::zeros({ static_cast<int>(vsamples.size() + 1), static_cast<int>(vsamples[0].size()) }, torch::kInt32))
2023-07-11 23:05:24 +00:00
{
2023-07-12 01:07:10 +00:00
for (int i = 0; i < vsamples.size(); ++i) {
2023-08-03 18:22:33 +00:00
samples.index_put_({ i, "..." }, torch::tensor(vsamples[i], torch::kInt32));
2023-07-12 01:07:10 +00:00
}
2023-08-03 18:22:33 +00:00
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
}
2023-11-08 17:45:35 +00:00
// Rank every feature by its weighted mutual information with the class labels (last row of
// `samples`) and keep the k best ones.
// weights   : per-sample weights forwarded to mutualInformation().
// ascending : true  -> the k highest scoring features, ordered from lowest to highest score;
//             false -> the k highest scoring features, ordered from highest to lowest score.
// k         : number of features to keep; 0 or any value above the number of features keeps
//             all of them (previously k > n padded the descending result with bogus zeros).
// Returns the indices of the selected features. Side effect: featuresKBest and scoresKBest are
// rebuilt; the matching scores are available through getScoresKBest().
std::vector<int> Metrics::SelectKBestWeighted(const torch::Tensor& weights, bool ascending, unsigned k)
{
    const auto n = features.size();
    if (k == 0 || k > n) {
        k = static_cast<unsigned>(n);
    }
    // compute scores
    scoresKBest.clear();
    featuresKBest.clear();
    scoresKBest.reserve(n);
    featuresKBest.reserve(n);
    auto label = samples.index({ -1, "..." });
    // Signed counter on purpose: it is used directly as a tensor row index.
    const int numFeatures = static_cast<int>(n);
    for (int i = 0; i < numFeatures; ++i) {
        scoresKBest.push_back(mutualInformation(label, samples.index({ i, "..." }), weights));
        featuresKBest.push_back(i);
    }
    // sort & reduce scores and features
    // featuresKBest must be ordered first: its comparator looks scores up by feature index,
    // which is only valid while scoresKBest is still in feature order.
    if (ascending) {
        std::sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
            { return scoresKBest[i] < scoresKBest[j]; });
        std::sort(scoresKBest.begin(), scoresKBest.end(), std::less<double>());
        // Drop the (n - k) lowest scores from the front in a single erase.
        const auto surplus = n - k;
        featuresKBest.erase(featuresKBest.begin(), featuresKBest.begin() + surplus);
        scoresKBest.erase(scoresKBest.begin(), scoresKBest.begin() + surplus);
    } else {
        std::sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
            { return scoresKBest[i] > scoresKBest[j]; });
        std::sort(scoresKBest.begin(), scoresKBest.end(), std::greater<double>());
        // k <= n is guaranteed above, so this can only shrink the vectors.
        featuresKBest.resize(k);
        scoresKBest.resize(k);
    }
    return featuresKBest;
}
2023-11-08 17:45:35 +00:00
// Return a copy of the scores stored by the last SelectKBestWeighted() call, in the same order
// as the feature indices that call returned.
std::vector<double> Metrics::getScoresKBest() const
{
    const auto& scores = scoresKBest;
    return scores;
}
2023-10-13 10:29:25 +00:00
2023-08-13 10:56:06 +00:00
// Build the symmetric matrix of class-conditional mutual information between every pair of
// variables in {features..., className}.
// For each pair the value is sum over class states c of P(c) * I(first; second | class == c),
// where P(c) is the fraction of samples labelled c and I() is mutualInformation() evaluated on
// the samples (and weights) of that class only.
// weights : per-sample weights; indexed with the class mask, so one entry per sample column.
// Returns a (features.size() + 1) square tensor with a zero diagonal.
// NOTE(review): the matrix is filled assuming doCombinations() yields the pairs in the same
// order as torch::triu_indices(n, n, 1) — confirm against doCombinations.
torch::Tensor Metrics::conditionalEdge(const torch::Tensor& weights)
{
    auto source = std::vector<std::string>(features);
    source.push_back(className);
    auto combinations = doCombinations(source);
    // Per-class data does not depend on the feature pair, so compute it once up front instead
    // of once per pair.
    const auto labels = samples.index({ -1, "..." });
    std::vector<torch::Tensor> classMasks;
    std::vector<torch::Tensor> classWeights;
    std::vector<double> classPriors;
    for (int value = 0; value < classNumStates; ++value) {
        auto mask = labels == value;
        // Rounded to float to reproduce the values the former float tensor stored.
        const float prior = static_cast<float>(mask.sum().item<double>() / samples.size(1));
        classPriors.push_back(prior);
        classWeights.push_back(weights.index({ mask }));
        classMasks.push_back(mask);
    }
    auto result = std::vector<double>();
    for (const auto& [first, second] : combinations) {
        // className is not in `features`, so find() returns end() and the index becomes
        // features.size(): the row that holds the labels.
        const int index_first = static_cast<int>(std::find(features.begin(), features.end(), first) - features.begin());
        const int index_second = static_cast<int>(std::find(features.begin(), features.end(), second) - features.begin());
        double accumulated = 0;
        for (int value = 0; value < classNumStates; ++value) {
            auto first_dataset = samples.index({ index_first, classMasks[value] });
            auto second_dataset = samples.index({ index_second, classMasks[value] });
            auto mi = mutualInformation(first_dataset, second_dataset, classWeights[value]);
            accumulated += classPriors[value] * mi;
        }
        result.push_back(accumulated);
    }
    long n_vars = static_cast<long>(source.size());
    auto matrix = torch::zeros({ n_vars, n_vars });
    auto indices = torch::triu_indices(n_vars, n_vars, 1);
    const int numPairs = static_cast<int>(result.size());
    for (int i = 0; i < numPairs; ++i) {
        auto x = indices[0][i];
        auto y = indices[1][i];
        // Mirror each value so the matrix is symmetric.
        matrix[x][y] = result[i];
        matrix[y][x] = result[i];
    }
    return matrix;
}
2023-08-03 18:22:33 +00:00
// To use in Python
2023-11-08 17:45:35 +00:00
std::vector<float> Metrics::conditionalEdgeWeights(std::vector<float>& weights_)
2023-07-13 01:44:33 +00:00
{
2023-08-13 10:56:06 +00:00
const torch::Tensor weights = torch::tensor(weights_);
auto matrix = conditionalEdge(weights);
2023-07-11 23:05:24 +00:00
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
return v;
}
2023-08-13 10:56:06 +00:00
// Weighted entropy of a discrete feature: H = -sum_v p(v) * log(p(v)), where p(v) is the total
// weight of the samples with value v divided by the overall weight.
// feature : integer values (used as bincount bins).
// weights : per-sample weights passed to bincount.
// Values with zero weight give 0 * log(0) = NaN; nansum() ignores those terms.
double Metrics::entropy(const torch::Tensor& feature, const torch::Tensor& weights)
{
    const torch::Tensor weightedCounts = feature.bincount(weights);
    const double total = weightedCounts.sum().item<double>();
    const torch::Tensor p = weightedCounts.to(torch::kFloat) / total;
    const torch::Tensor terms = -p * torch::log(p);
    return terms.nansum().item<double>();
}
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
2023-08-13 10:56:06 +00:00
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
{
int numSamples = firstFeature.sizes()[0];
2023-08-13 10:56:06 +00:00
torch::Tensor featureCounts = secondFeature.bincount(weights);
2023-11-08 17:45:35 +00:00
std::unordered_map<int, std::unordered_map<int, double>> jointCounts;
double totalWeight = 0;
for (auto i = 0; i < numSamples; i++) {
2023-08-16 10:32:51 +00:00
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += weights[i].item<double>();
2023-08-13 10:56:06 +00:00
totalWeight += weights[i].item<float>();
}
if (totalWeight == 0)
return 0;
double entropyValue = 0;
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
double p_f = featureCounts[value].item<double>() / totalWeight;
double entropy_f = 0;
for (auto& [label, jointCount] : jointCounts[value]) {
double p_l_f = jointCount / featureCounts[value].item<double>();
if (p_l_f > 0) {
entropy_f -= p_l_f * log(p_l_f);
} else {
entropy_f = 0;
}
}
entropyValue += p_f * entropy_f;
}
return entropyValue;
}
// I(X;Y) = H(Y) - H(Y|X)
2023-08-13 10:56:06 +00:00
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
{
2023-08-13 10:56:06 +00:00
return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights);
}
2023-07-13 23:05:49 +00:00
/*
Compute the maximum spanning tree considering the weights as distances
and the indices of the weights as nodes of this square matrix using
Kruskal algorithm
*/
2023-11-08 17:45:35 +00:00
std::vector<std::pair<int, int>> Metrics::maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root)
{
auto mst = MST(features, weights, root);
return mst.maximumSpanningTree();
}
}