Refactor Network and create Metrics class
This commit is contained in:
parent
c7e2042c6e
commit
d1eaab6408
109
README.md
109
README.md
@ -2,111 +2,4 @@
|
||||
|
||||
Bayesian Network Classifier with libtorch from scratch
|
||||
|
||||
## Variable Elimination
|
||||
|
||||
To decide the first variable to eliminate we use the MinFill criterion, that is
|
||||
the variable that minimizes the number of edges that need to be added to the
|
||||
graph to make it triangulated.
|
||||
This is done by counting the number of edges that need to be added to the graph
|
||||
if the variable is eliminated. The variable with the minimum number of edges is
|
||||
chosen.
|
||||
In pgmpy this is done by computing the length of the combinations of the
|
||||
neighbors taken 2 by 2.
|
||||
|
||||
Once the variable to eliminate is chosen, we need to compute the factors that
|
||||
need to be multiplied to get the new factor.
|
||||
This is done by multiplying all the factors that contain the variable to
|
||||
eliminate and then marginalizing the variable out.
|
||||
|
||||
The new factor is then added to the list of factors and the variable to
|
||||
eliminate is removed from the list of variables.
|
||||
|
||||
The process is repeated until there are no more variables to eliminate.
|
||||
|
||||
## Code for combination
|
||||
|
||||
```cpp
|
||||
// Combinations of length 2
// Builds every unordered pair of names from source, concatenating the two
// names into one string (e.g. {"A","B","C"} -> {"AB","AC","BC"}).
// Takes source by const reference instead of by value to avoid copying the
// whole vector on every call; loop counters are size_t to avoid the
// signed/unsigned comparison of the original.
vector<string> combinations(const vector<string>& source)
{
    vector<string> result;
    const size_t n = source.size();
    // There are n*(n-1)/2 pairs in total; reserve once to avoid reallocations.
    result.reserve(n > 0 ? n * (n - 1) / 2 : 0);
    for (size_t i = 0; i < n; ++i) {
        const string& first = source[i];  // reference, not a per-iteration copy
        for (size_t j = i + 1; j < n; ++j) {
            result.push_back(first + source[j]);
        }
    }
    return result;
}
|
||||
```
|
||||
|
||||
## Code for Variable Elimination
|
||||
|
||||
```cpp
|
||||
// Variable Elimination
|
||||
vector<string> variableElimination(vector<string> source, map<string, vector<string>> graph)
|
||||
{
|
||||
vector<string> variables = source;
|
||||
vector<string> factors = source;
|
||||
while (variables.size() > 0) {
|
||||
string variable = minFill(variables, graph);
|
||||
vector<string> neighbors = graph[variable];
|
||||
vector<string> combinations = combinations(neighbors);
|
||||
vector<string> factorsToMultiply;
|
||||
for (int i = 0; i < factors.size(); ++i) {
|
||||
string factor = factors[i];
|
||||
for (int j = 0; j < combinations.size(); ++j) {
|
||||
string combination = combinations[j];
|
||||
if (factor.find(combination) != string::npos) {
|
||||
factorsToMultiply.push_back(factor);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
string newFactor = multiplyFactors(factorsToMultiply);
|
||||
factors.push_back(newFactor);
|
||||
variables.erase(remove(variables.begin(), variables.end(), variable), variables.end());
|
||||
}
|
||||
return factors;
|
||||
}
|
||||
```
|
||||
|
||||
## Network copy constructor
|
||||
|
||||
```cpp
|
||||
// Network copy constructor
|
||||
Network::Network(const Network& network)
|
||||
{
|
||||
this->variables = network.variables;
|
||||
this->factors = network.factors;
|
||||
this->graph = network.graph;
|
||||
}
|
||||
```
|
||||
|
||||
## Code for MinFill
|
||||
|
||||
```cpp
|
||||
// MinFill
|
||||
string minFill(vector<string> source, map<string, vector<string>> graph)
|
||||
{
|
||||
string result;
|
||||
int min = INT_MAX;
|
||||
for (int i = 0; i < source.size(); ++i) {
|
||||
string temp = source[i];
|
||||
int count = 0;
|
||||
vector<string> neighbors = graph[temp];
|
||||
vector<string> combinations = combinations(neighbors);
|
||||
for (int j = 0; j < combinations.size(); ++j) {
|
||||
string combination = combinations[j];
|
||||
if (graph[combination].size() == 0) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
if (count < min) {
|
||||
min = count;
|
||||
result = temp;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
```
|
||||
## 1. Introduction
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <getopt.h>
|
||||
#include "ArffFiles.h"
|
||||
#include "Network.h"
|
||||
#include "Metrics.hpp"
|
||||
#include "CPPFImdlp.h"
|
||||
|
||||
|
||||
@ -230,6 +231,7 @@ int main(int argc, char** argv)
|
||||
cout << "BayesNet version: " << network.version() << endl;
|
||||
unsigned int nthreads = std::thread::hardware_concurrency();
|
||||
cout << "Computer has " << nthreads << " cores." << endl;
|
||||
cout << "conditionalEdgeWeight " << endl << network.conditionalEdgeWeight() << endl;
|
||||
auto metrics = bayesnet::Metrics(network.getSamples(), features, className, network.getClassNumStates());
|
||||
cout << "conditionalEdgeWeight " << endl << metrics.conditionalEdgeWeight() << endl;
|
||||
return 0;
|
||||
}
|
@ -1,2 +1,2 @@
|
||||
add_library(BayesNet Network.cc Node.cc)
|
||||
add_library(BayesNet Network.cc Node.cc Metrics.cc)
|
||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
102
src/Metrics.cc
Normal file
102
src/Metrics.cc
Normal file
@ -0,0 +1,102 @@
|
||||
#include "Metrics.hpp"
|
||||
using namespace std;
|
||||
namespace bayesnet {
|
||||
// Construct a Metrics helper bound to a dataset.
// samples: one row per sample; conditionalEdgeWeight() reads the class label
// from the last column. features/className/samples are stored by reference,
// so the caller's objects must outlive this Metrics instance.
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
    : samples(samples), features(features), className(className), classNumStates(classNumStates)
{
}
|
||||
// Build all unordered pairs (i < j) of the names in source.
// The pairs come out in row-major upper-triangle order, which
// conditionalEdgeWeight() relies on when scattering results with
// torch::triu_indices.
// Fixes: size_t loop counters (the original compared int against size()),
// no per-iteration string copy, and a single up-front reserve.
vector<pair<string, string>> Metrics::doCombinations(const vector<string>& source)
{
    vector<pair<string, string>> result;
    const size_t n = source.size();
    result.reserve(n > 0 ? n * (n - 1) / 2 : 0);  // n*(n-1)/2 pairs in total
    for (size_t i = 0; i < n; ++i) {
        const string& first = source[i];
        for (size_t j = i + 1; j < n; ++j) {
            result.push_back({ first, source[j] });
        }
    }
    return result;
}
|
||||
// Class-conditional mutual-information weight matrix.
// Returns a symmetric (n+1)x(n+1) tensor over all variables (every feature
// plus the class) where entry (i, j) holds sum_c P(c) * I(X_i; X_j | c).
// The diagonal stays zero.
// NOTE(review): for pairs involving className, find() falls through to
// features.size(); this assumes the class occupies the last sample column —
// confirm against how samples is built.
torch::Tensor Metrics::conditionalEdgeWeight()
{
    // All variables: every feature followed by the class.
    auto allVars = vector<string>(features);
    allVars.push_back(className);
    auto pairs = doCombinations(allVars);
    // Class prior P(c), estimated from the last column of the samples.
    auto priors = torch::zeros({ classNumStates });
    for (int c = 0; c < classNumStates; ++c) {
        auto classMask = samples.index({ "...", -1 }) == c;
        priors[c] = classMask.sum().item<float>() / samples.sizes()[0];
    }
    // One weight per pair: sum over classes of P(c) * I(first; second | c).
    auto weights = vector<double>();
    for (auto [first, second] : pairs) {
        int64_t firstIdx = find(features.begin(), features.end(), first) - features.begin();
        int64_t secondIdx = find(features.begin(), features.end(), second) - features.begin();
        double weight = 0;
        for (int c = 0; c < classNumStates; ++c) {
            auto classMask = samples.index({ "...", -1 }) == c;
            auto firstSlice = samples.index({ classMask, firstIdx });
            auto secondSlice = samples.index({ classMask, secondIdx });
            auto mi = mutualInformation(firstSlice, secondSlice);
            weight += priors[c].item<float>() * mi;
        }
        weights.push_back(weight);
    }
    // Scatter the flat weight list into a symmetric matrix; doCombinations
    // emits pairs in the same row-major order as triu_indices.
    long n_vars = allVars.size();
    auto matrix = torch::zeros({ n_vars, n_vars });
    auto indices = torch::triu_indices(n_vars, n_vars, 1);
    for (size_t k = 0; k < weights.size(); ++k) {
        auto row = indices[0][k];
        auto col = indices[1][k];
        matrix[row][col] = weights[k];
        matrix[col][row] = weights[k];
    }
    return matrix;
}
|
||||
// Shannon entropy (in nats) of a discrete feature of non-negative integers.
// H(X) = -sum_x p(x) * ln p(x). Empty bins yield 0 * log(0) = NaN in the
// elementwise product, which nansum() discards.
double Metrics::entropy(torch::Tensor& feature)
{
    auto counts = feature.bincount();
    auto total = counts.sum().item<int>();
    auto probabilities = counts.to(torch::kFloat) / total;
    auto weightedLogs = probabilities * torch::log(probabilities);
    return (-weightedLogs).nansum().item<double>();
}
|
||||
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
|
||||
double Metrics::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
||||
{
|
||||
int numSamples = firstFeature.sizes()[0];
|
||||
torch::Tensor featureCounts = secondFeature.bincount();
|
||||
unordered_map<int, unordered_map<int, double>> jointCounts;
|
||||
double totalWeight = 0;
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
|
||||
totalWeight += 1;
|
||||
}
|
||||
if (totalWeight == 0)
|
||||
throw invalid_argument("Total weight should not be zero");
|
||||
double entropyValue = 0;
|
||||
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
||||
double p_f = featureCounts[value].item<double>() / totalWeight;
|
||||
double entropy_f = 0;
|
||||
for (auto& [label, jointCount] : jointCounts[value]) {
|
||||
double p_l_f = jointCount / featureCounts[value].item<double>();
|
||||
if (p_l_f > 0) {
|
||||
entropy_f -= p_l_f * log(p_l_f);
|
||||
} else {
|
||||
entropy_f = 0;
|
||||
}
|
||||
}
|
||||
entropyValue += p_f * entropy_f;
|
||||
}
|
||||
return entropyValue;
|
||||
}
|
||||
// I(X;Y) = H(Y) - H(Y|X)
|
||||
double Metrics::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
||||
{
|
||||
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
|
||||
}
|
||||
}
|
23
src/Metrics.hpp
Normal file
23
src/Metrics.hpp
Normal file
@ -0,0 +1,23 @@
|
||||
#ifndef BAYESNET_METRICS_H
#define BAYESNET_METRICS_H
#include <torch/torch.h>
#include <vector>
#include <string>
// NOTE(review): a using-directive in a header leaks std into every includer;
// consider qualifying with std:: instead (kept for source compatibility).
using namespace std;
namespace bayesnet {
    // Information-theoretic metrics over a discrete dataset stored as a
    // torch tensor: one sample per row, features in column order, and
    // (presumably) the class label in the last column — confirm with callers.
    class Metrics {
    private:
        torch::Tensor& samples;     // dataset; held by reference, not copied
        vector<string>& features;   // feature names, matching sample columns
        string& className;          // name of the class variable
        int classNumStates;         // cardinality of the class variable
        // All unordered pairs (i < j) of the given names.
        vector<pair<string, string>> doCombinations(const vector<string>&);
        // Shannon entropy H(X) of a discrete feature, in nats.
        double entropy(torch::Tensor&);
        // Conditional entropy H(Y|X), in nats.
        double conditionalEntropy(torch::Tensor&, torch::Tensor&);
        // Mutual information I(X;Y) = H(Y) - H(Y|X).
        double mutualInformation(torch::Tensor&, torch::Tensor&);
    public:
        // Stores references to all arguments: they must outlive this object.
        Metrics(torch::Tensor&, vector<string>&, string&, int);
        // Symmetric matrix of class-conditional mutual-information weights.
        torch::Tensor conditionalEdgeWeight();
    };
}
#endif
|
@ -21,6 +21,10 @@ namespace bayesnet {
|
||||
{
|
||||
return maxThreads;
|
||||
}
|
||||
// Accessor for the raw sample tensor backing this network. Returns a
// mutable reference (no copy); callers can modify the network's data.
torch::Tensor& Network::getSamples()
|
||||
{
|
||||
return samples;
|
||||
}
|
||||
void Network::addNode(string name, int numStates)
|
||||
{
|
||||
if (nodes.find(name) != nodes.end()) {
|
||||
@ -241,83 +245,5 @@ namespace bayesnet {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
// Stub: ignores both arguments and always returns 1.
// NOTE(review): presumably superseded by Metrics::mutualInformation in this
// refactor — remove once no caller depends on it.
double Network::mutual_info(torch::Tensor& first, torch::Tensor& second)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
torch::Tensor Network::conditionalEdgeWeight()
|
||||
{
|
||||
auto result = vector<double>();
|
||||
auto source = vector<string>(features);
|
||||
source.push_back(className);
|
||||
auto combinations = nodes[className]->combinations(source);
|
||||
auto margin = nodes[className]->getCPT();
|
||||
for (auto [first, second] : combinations) {
|
||||
int64_t index_first = find(features.begin(), features.end(), first) - features.begin();
|
||||
int64_t index_second = find(features.begin(), features.end(), second) - features.begin();
|
||||
double accumulated = 0;
|
||||
for (int value = 0; value < classNumStates; ++value) {
|
||||
auto mask = samples.index({ "...", -1 }) == value;
|
||||
auto first_dataset = samples.index({ mask, index_first });
|
||||
auto second_dataset = samples.index({ mask, index_second });
|
||||
auto mi = mutualInformation(first_dataset, second_dataset);
|
||||
auto pb = margin[value].item<float>();
|
||||
accumulated += pb * mi;
|
||||
}
|
||||
result.push_back(accumulated);
|
||||
}
|
||||
long n_vars = source.size();
|
||||
auto matrix = torch::zeros({ n_vars, n_vars });
|
||||
auto indices = torch::triu_indices(n_vars, n_vars, 1);
|
||||
for (auto i = 0; i < result.size(); ++i) {
|
||||
auto x = indices[0][i];
|
||||
auto y = indices[1][i];
|
||||
matrix[x][y] = result[i];
|
||||
matrix[y][x] = result[i];
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
double Network::entropy(torch::Tensor& feature)
|
||||
{
|
||||
torch::Tensor counts = feature.bincount();
|
||||
int totalWeight = counts.sum().item<int>();
|
||||
torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
|
||||
torch::Tensor logProbs = torch::log(probs);
|
||||
torch::Tensor entropy = -probs * logProbs;
|
||||
return entropy.nansum().item<double>();
|
||||
}
|
||||
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
|
||||
double Network::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
||||
{
|
||||
int numSamples = firstFeature.sizes()[0];
|
||||
torch::Tensor featureCounts = secondFeature.bincount();
|
||||
unordered_map<int, unordered_map<int, double>> jointCounts;
|
||||
double totalWeight = 0;
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
|
||||
totalWeight += 1;
|
||||
}
|
||||
if (totalWeight == 0)
|
||||
throw invalid_argument("Total weight should not be zero");
|
||||
double entropyValue = 0;
|
||||
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
||||
double p_f = featureCounts[value].item<double>() / totalWeight;
|
||||
double entropy_f = 0;
|
||||
for (auto& [label, jointCount] : jointCounts[value]) {
|
||||
double p_l_f = jointCount / featureCounts[value].item<double>();
|
||||
if (p_l_f > 0) {
|
||||
entropy_f -= p_l_f * log(p_l_f);
|
||||
} else {
|
||||
entropy_f = 0;
|
||||
}
|
||||
}
|
||||
entropyValue += p_f * entropy_f;
|
||||
}
|
||||
return entropyValue;
|
||||
}
|
||||
// I(X;Y) = H(Y) - H(Y|X)
|
||||
double Network::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
||||
{
|
||||
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ namespace bayesnet {
|
||||
vector<string> features;
|
||||
string className;
|
||||
int laplaceSmoothing;
|
||||
torch::Tensor samples;
|
||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||
vector<double> predict_sample(const vector<int>&);
|
||||
vector<double> exactInference(map<string, int>&);
|
||||
@ -24,12 +25,12 @@ namespace bayesnet {
|
||||
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
|
||||
double mutualInformation(torch::Tensor&, torch::Tensor&);
|
||||
public:
|
||||
torch::Tensor samples;
|
||||
Network();
|
||||
Network(float, int);
|
||||
Network(float);
|
||||
Network(Network&);
|
||||
~Network();
|
||||
torch::Tensor& getSamples();
|
||||
float getmaxThreads();
|
||||
void addNode(string, int);
|
||||
void addEdge(const string, const string);
|
||||
|
Loading…
Reference in New Issue
Block a user