Refactor Network and create Metrics class

This commit is contained in:
2023-07-11 22:23:49 +02:00
parent c7e2042c6e
commit d1eaab6408
7 changed files with 137 additions and 190 deletions

View File

@@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc)
add_library(BayesNet Network.cc Node.cc Metrics.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

102
src/Metrics.cc Normal file
View File

@@ -0,0 +1,102 @@
#include "Metrics.hpp"
using namespace std;
namespace bayesnet {
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
: samples(samples)
, features(features)
, className(className)
, classNumStates(classNumStates)
{
}
vector<pair<string, string>> Metrics::doCombinations(const vector<string>& source)
{
vector<pair<string, string>> result;
for (int i = 0; i < source.size(); ++i) {
string temp = source[i];
for (int j = i + 1; j < source.size(); ++j) {
result.push_back({ temp, source[j] });
}
}
return result;
}
torch::Tensor Metrics::conditionalEdgeWeight()
{
auto result = vector<double>();
auto source = vector<string>(features);
source.push_back(className);
auto combinations = doCombinations(source);
// Compute class prior
auto margin = torch::zeros({ classNumStates });
for (int value = 0; value < classNumStates; ++value) {
auto mask = samples.index({ "...", -1 }) == value;
margin[value] = mask.sum().item<float>() / samples.sizes()[0];
}
for (auto [first, second] : combinations) {
int64_t index_first = find(features.begin(), features.end(), first) - features.begin();
int64_t index_second = find(features.begin(), features.end(), second) - features.begin();
double accumulated = 0;
for (int value = 0; value < classNumStates; ++value) {
auto mask = samples.index({ "...", -1 }) == value;
auto first_dataset = samples.index({ mask, index_first });
auto second_dataset = samples.index({ mask, index_second });
auto mi = mutualInformation(first_dataset, second_dataset);
auto pb = margin[value].item<float>();
accumulated += pb * mi;
}
result.push_back(accumulated);
}
long n_vars = source.size();
auto matrix = torch::zeros({ n_vars, n_vars });
auto indices = torch::triu_indices(n_vars, n_vars, 1);
for (auto i = 0; i < result.size(); ++i) {
auto x = indices[0][i];
auto y = indices[1][i];
matrix[x][y] = result[i];
matrix[y][x] = result[i];
}
return matrix;
}
double Metrics::entropy(torch::Tensor& feature)
{
torch::Tensor counts = feature.bincount();
int totalWeight = counts.sum().item<int>();
torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
torch::Tensor logProbs = torch::log(probs);
torch::Tensor entropy = -probs * logProbs;
return entropy.nansum().item<double>();
}
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
double Metrics::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
{
int numSamples = firstFeature.sizes()[0];
torch::Tensor featureCounts = secondFeature.bincount();
unordered_map<int, unordered_map<int, double>> jointCounts;
double totalWeight = 0;
for (auto i = 0; i < numSamples; i++) {
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
totalWeight += 1;
}
if (totalWeight == 0)
throw invalid_argument("Total weight should not be zero");
double entropyValue = 0;
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
double p_f = featureCounts[value].item<double>() / totalWeight;
double entropy_f = 0;
for (auto& [label, jointCount] : jointCounts[value]) {
double p_l_f = jointCount / featureCounts[value].item<double>();
if (p_l_f > 0) {
entropy_f -= p_l_f * log(p_l_f);
} else {
entropy_f = 0;
}
}
entropyValue += p_f * entropy_f;
}
return entropyValue;
}
// I(X;Y) = H(Y) - H(Y|X)
double Metrics::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
{
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
}
}

23
src/Metrics.hpp Normal file
View File

@@ -0,0 +1,23 @@
#ifndef BAYESNET_METRICS_H
#define BAYESNET_METRICS_H
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
class Metrics {
private:
torch::Tensor& samples;
vector<string>& features;
string& className;
int classNumStates;
vector<pair<string, string>> doCombinations(const vector<string>&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
double mutualInformation(torch::Tensor&, torch::Tensor&);
public:
Metrics(torch::Tensor&, vector<string>&, string&, int);
torch::Tensor conditionalEdgeWeight();
};
}
#endif

View File

@@ -21,6 +21,10 @@ namespace bayesnet {
{
return maxThreads;
}
torch::Tensor& Network::getSamples()
{
return samples;
}
void Network::addNode(string name, int numStates)
{
if (nodes.find(name) != nodes.end()) {
@@ -241,83 +245,5 @@ namespace bayesnet {
}
return result;
}
double Network::mutual_info(torch::Tensor& first, torch::Tensor& second)
{
return 1;
}
torch::Tensor Network::conditionalEdgeWeight()
{
auto result = vector<double>();
auto source = vector<string>(features);
source.push_back(className);
auto combinations = nodes[className]->combinations(source);
auto margin = nodes[className]->getCPT();
for (auto [first, second] : combinations) {
int64_t index_first = find(features.begin(), features.end(), first) - features.begin();
int64_t index_second = find(features.begin(), features.end(), second) - features.begin();
double accumulated = 0;
for (int value = 0; value < classNumStates; ++value) {
auto mask = samples.index({ "...", -1 }) == value;
auto first_dataset = samples.index({ mask, index_first });
auto second_dataset = samples.index({ mask, index_second });
auto mi = mutualInformation(first_dataset, second_dataset);
auto pb = margin[value].item<float>();
accumulated += pb * mi;
}
result.push_back(accumulated);
}
long n_vars = source.size();
auto matrix = torch::zeros({ n_vars, n_vars });
auto indices = torch::triu_indices(n_vars, n_vars, 1);
for (auto i = 0; i < result.size(); ++i) {
auto x = indices[0][i];
auto y = indices[1][i];
matrix[x][y] = result[i];
matrix[y][x] = result[i];
}
return matrix;
}
double Network::entropy(torch::Tensor& feature)
{
torch::Tensor counts = feature.bincount();
int totalWeight = counts.sum().item<int>();
torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
torch::Tensor logProbs = torch::log(probs);
torch::Tensor entropy = -probs * logProbs;
return entropy.nansum().item<double>();
}
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
double Network::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
{
int numSamples = firstFeature.sizes()[0];
torch::Tensor featureCounts = secondFeature.bincount();
unordered_map<int, unordered_map<int, double>> jointCounts;
double totalWeight = 0;
for (auto i = 0; i < numSamples; i++) {
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
totalWeight += 1;
}
if (totalWeight == 0)
throw invalid_argument("Total weight should not be zero");
double entropyValue = 0;
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
double p_f = featureCounts[value].item<double>() / totalWeight;
double entropy_f = 0;
for (auto& [label, jointCount] : jointCounts[value]) {
double p_l_f = jointCount / featureCounts[value].item<double>();
if (p_l_f > 0) {
entropy_f -= p_l_f * log(p_l_f);
} else {
entropy_f = 0;
}
}
entropyValue += p_f * entropy_f;
}
return entropyValue;
}
// I(X;Y) = H(Y) - H(Y|X)
double Network::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
{
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
}
}

View File

@@ -15,6 +15,7 @@ namespace bayesnet {
vector<string> features;
string className;
int laplaceSmoothing;
torch::Tensor samples;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
vector<double> predict_sample(const vector<int>&);
vector<double> exactInference(map<string, int>&);
@@ -24,12 +25,12 @@ namespace bayesnet {
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
double mutualInformation(torch::Tensor&, torch::Tensor&);
public:
torch::Tensor samples;
Network();
Network(float, int);
Network(float);
Network(Network&);
~Network();
torch::Tensor& getSamples();
float getmaxThreads();
void addNode(string, int);
void addEdge(const string, const string);