First try to link with bayesnet

2023-07-07 00:23:47 +02:00
parent ea473fc604
commit 61e4c176eb
9 changed files with 5800 additions and 2 deletions


@@ -1,2 +1,4 @@
include README.md LICENSE
include bayesclass/FeatureSelect.h
include bayesclass/Node.h
include bayesclass/Network.h

5305
bayesclass/BayesNetwork.cpp Normal file

File diff suppressed because it is too large


@@ -0,0 +1,52 @@
# distutils: language = c++
# cython: language_level = 3
from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.pair cimport pair
from libcpp cimport bool
cdef extern from "Node.h" namespace "bayesnet":
    cdef cppclass Node:
        pass
cdef extern from "Network.h" namespace "bayesnet":
    cdef cppclass Network:
        Network(float, int) except +
        void fit(vector[vector[int]], vector[int], vector[string], string)
        vector[int] predict(vector[vector[int]])
        vector[vector[float]] predict_proba(vector[vector[int]])
        float score(const vector[vector[int]], const vector[int])
        void addNode(string, int)
        void addEdge(string, string)
        vector[string] getFeatures()
        int getClassNumStates()
        string getClassName()
        string version()
cdef class BayesNetwork:
    cdef Network *thisptr
    def __cinit__(self, maxThreads=0.8, laplaceSmooth=1):
        self.thisptr = new Network(maxThreads, laplaceSmooth)
    def __dealloc__(self):
        del self.thisptr
    def fit(self, X, y, features, className):
        self.thisptr.fit(X, y, features, className)
        return self
    def predict(self, X):
        return self.thisptr.predict(X)
    def predict_proba(self, X):
        return self.thisptr.predict_proba(X)
    def score(self, X, y):
        return self.thisptr.score(X, y)
    def addNode(self, name, states):
        self.thisptr.addNode(name, states)
    def addEdge(self, source, destination):
        self.thisptr.addEdge(source, destination)
    def getFeatures(self):
        return self.thisptr.getFeatures()
    def getClassName(self):
        return self.thisptr.getClassName()
    def getClassNumStates(self):
        return self.thisptr.getClassNumStates()
    def __reduce__(self):
        return (BayesNetwork, ())
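For orientation, here is a minimal usage sketch of the wrapper. The feature names and data are invented for illustration; it assumes the extension is built as bayesclass.cppBayesNetwork (as declared in setup.py) and that std::string parameters receive bytes, since no c_string_type directive is set.

from bayesclass.cppBayesNetwork import BayesNetwork

# two binary features stored feature-major (one inner list per feature) and binary labels
X = [[0, 1, 1, 0],
     [1, 1, 0, 0]]
y = [0, 1, 1, 0]
net = BayesNetwork()
net.addNode(b"f1", 2)
net.addNode(b"f2", 2)
net.addNode(b"class", 2)
net.addEdge(b"class", b"f1")    # naive-Bayes-like structure: the class is the parent of every feature
net.addEdge(b"class", b"f2")
net.fit(X, y, [b"f1", b"f2"], b"class")
print(net.predict(X))           # one predicted class label per sample
print(net.predict_proba(X))     # per-sample posterior over the two class states
print(net.score(X, y))          # fraction of correctly predicted samples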

235
bayesclass/Network.cc Normal file

@@ -0,0 +1,235 @@
#include <thread>
#include <mutex>
#include <condition_variable>
#include <algorithm>
#include <numeric>
#include "Network.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
{
for (auto& pair : other.nodes) {
nodes[pair.first] = new Node(*pair.second);
}
}
Network::~Network()
{
for (auto& pair : nodes) {
delete pair.second;
}
}
float Network::getmaxThreads()
{
return maxThreads;
}
void Network::addNode(string name, int numStates)
{
if (nodes.find(name) != nodes.end()) {
// if node exists update its number of states
nodes[name]->setNumStates(numStates);
return;
}
nodes[name] = new Node(name, numStates);
}
vector<string> Network::getFeatures()
{
return features;
}
int Network::getClassNumStates()
{
return classNumStates;
}
string Network::getClassName()
{
return className;
}
bool Network::isCyclic(const string& nodeId, unordered_set<string>& visited, unordered_set<string>& recStack)
{
if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
{
visited.insert(nodeId);
recStack.insert(nodeId);
for (Node* child : nodes[nodeId]->getChildren()) {
if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
return true;
else if (recStack.find(child->getName()) != recStack.end())
return true;
}
}
recStack.erase(nodeId); // remove node from recursion stack before function ends
return false;
}
void Network::addEdge(const string parent, const string child)
{
if (nodes.find(parent) == nodes.end()) {
throw invalid_argument("Parent node " + parent + " does not exist");
}
if (nodes.find(child) == nodes.end()) {
throw invalid_argument("Child node " + child + " does not exist");
}
// Temporarily add edge to check for cycles
nodes[parent]->addChild(nodes[child]);
nodes[child]->addParent(nodes[parent]);
unordered_set<string> visited;
unordered_set<string> recStack;
if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
{
// remove problematic edge
nodes[parent]->removeChild(nodes[child]);
nodes[child]->removeParent(nodes[parent]);
throw invalid_argument("Adding this edge forms a cycle in the graph.");
}
}
map<string, Node*>& Network::getNodes()
{
return nodes;
}
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
{
features = featureNames;
this->className = className;
dataset.clear();
// Build dataset
for (int i = 0; i < featureNames.size(); ++i) {
dataset[featureNames[i]] = input_data[i];
}
dataset[className] = labels;
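// the labels are assumed to be encoded as consecutive integers starting at 0, so the number of class states is max(label) + 1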
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;
}
vector<thread> threads;
mutex mtx;
condition_variable cv;
int activeThreads = 0;
int nextNodeIndex = 0;
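// Ad-hoc worker pool: nextNodeIndex is a shared cursor over the nodes map. The main loop waits on the
// condition variable so that at most maxThreadsRunning workers are active at once; each worker repeatedly
// claims the next node under the mutex and computes its CPT with the lock released.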
while (nextNodeIndex < nodes.size()) {
unique_lock<mutex> lock(mtx);
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(dataset, laplaceSmoothing);
lock.lock();
nodes[pair.first] = pair.second;
lock.unlock();
}
lock_guard<mutex> lock(mtx);
--activeThreads;
cv.notify_one();
});
++activeThreads;
}
for (auto& thread : threads) {
thread.join();
}
}
vector<int> Network::predict(const vector<vector<int>>& samples)
{
vector<int> predictions;
vector<int> sample;
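// samples are stored feature-major: samples[i] holds the values of feature i for every instance,
// so one sample is assembled by taking the same row position across all feature vectors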
for (int row = 0; row < samples[0].size(); ++row) {
sample.clear();
for (int col = 0; col < samples.size(); ++col) {
sample.push_back(samples[col][row]);
}
predictions.push_back(predict_sample(sample).first);
}
return predictions;
}
vector<vector<float>> Network::predict_proba(const vector<vector<int>>& samples)
{
vector<vector<float>> predictions;
vector<int> sample;
for (int row = 0; row < samples[0].size(); ++row) {
sample.clear();
for (int col = 0; col < samples.size(); ++col) {
sample.push_back(samples[col][row]);
}
map<string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
}
// exactInference returns the normalized posterior over the class states for this sample
vector<double> classProbabilities = exactInference(evidence);
predictions.push_back(vector<float>(classProbabilities.begin(), classProbabilities.end()));
}
return predictions;
}
double Network::score(const vector<vector<int>>& samples, const vector<int>& labels)
{
vector<int> y_pred = predict(samples);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {
if (y_pred[i] == labels[i]) {
correct++;
}
}
return (double)correct / y_pred.size();
}
pair<int, double> Network::predict_sample(const vector<int>& sample)
{
// Ensure the sample size is equal to the number of features
if (sample.size() != features.size()) {
throw invalid_argument("Sample size (" + to_string(sample.size()) +
") does not match the number of features (" + to_string(features.size()) + ")");
}
map<string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
}
vector<double> classProbabilities = exactInference(evidence);
// Find the class with the maximum posterior probability
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
int predictedClass = distance(classProbabilities.begin(), maxElem);
double maxProbability = *maxElem;
return make_pair(predictedClass, maxProbability);
}
double Network::computeFactor(map<string, int>& completeEvidence)
{
double result = 1.0;
for (auto node : getNodes()) {
result *= node.second->getFactorValue(completeEvidence);
}
return result;
}
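// exactInference fixes each possible class state in turn, multiplies the CPT entries of every node under
// that complete assignment (an unnormalized joint over the sample and the class) and then normalizes the
// results so they sum to one over the class states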
vector<double> Network::exactInference(map<string, int>& evidence)
{
vector<double> result(classNumStates, 0.0);
vector<thread> threads;
mutex mtx;
for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
auto completeEvidence = map<string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
lock_guard<mutex> lock(mtx);
result[i] = factor;
});
}
for (auto& thread : threads) {
thread.join();
}
// Normalize result
double sum = accumulate(result.begin(), result.end(), 0.0);
for (double& value : result) {
value /= sum;
}
return result;
}
}

41
bayesclass/Network.h Normal file

@@ -0,0 +1,41 @@
#ifndef NETWORK_H
#define NETWORK_H
#include "Node.h"
#include <map>
#include <vector>
#include <unordered_set>
namespace bayesnet {
class Network {
private:
map<string, Node*> nodes;
map<string, vector<int>> dataset;
float maxThreads;
int classNumStates;
vector<string> features;
string className;
int laplaceSmoothing;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
pair<int, double> predict_sample(const vector<int>&);
vector<double> exactInference(map<string, int>&);
double computeFactor(map<string, int>&);
public:
Network();
Network(float, int);
Network(float);
Network(Network&);
~Network();
float getmaxThreads();
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
vector<string> getFeatures();
int getClassNumStates();
string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
vector<int> predict(const vector<vector<int>>&);
vector<vector<float>> predict_proba(const vector<vector<int>>&);
double score(const vector<vector<int>>&, const vector<int>&);
};
}
#endif

114
bayesclass/Node.cc Normal file

@@ -0,0 +1,114 @@
#include <algorithm>
#include "Node.h"
namespace bayesnet {
Node::Node(const std::string& name, int numStates)
: name(name), numStates(numStates), cpTable(torch::Tensor()), parents(vector<Node*>()), children(vector<Node*>())
{
}
string Node::getName() const
{
return name;
}
void Node::addParent(Node* parent)
{
parents.push_back(parent);
}
void Node::removeParent(Node* parent)
{
parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
}
void Node::removeChild(Node* child)
{
children.erase(std::remove(children.begin(), children.end(), child), children.end());
}
void Node::addChild(Node* child)
{
children.push_back(child);
}
vector<Node*>& Node::getParents()
{
return parents;
}
vector<Node*>& Node::getChildren()
{
return children;
}
int Node::getNumStates() const
{
return numStates;
}
void Node::setNumStates(int numStates)
{
this->numStates = numStates;
}
torch::Tensor& Node::getCPT()
{
return cpTable;
}
/*
The MinFill criterion is a heuristic used to choose a variable elimination order:
prefer the variable whose elimination would add the fewest edges to the graph,
i.e. the fewest fill-in edges needed to connect its neighbors into a clique.
Here that count is approximated by the number of pairs of neighbors of the node,
computed as the combinations of its neighbors taken two at a time.
*/
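// Example: a node whose neighbors are {A, B, C} yields the pairs AB, AC and BC, so minFill() returns 3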
unsigned Node::minFill()
{
set<string> neighbors;
for (auto child : children) {
neighbors.emplace(child->getName());
}
for (auto parent : parents) {
neighbors.emplace(parent->getName());
}
return combinations(neighbors).size();
}
vector<string> Node::combinations(const set<string>& neighbors)
{
vector<string> source(neighbors.begin(), neighbors.end());
vector<string> result;
for (int i = 0; i < source.size(); ++i) {
string temp = source[i];
for (int j = i + 1; j < source.size(); ++j) {
result.push_back(temp + source[j]);
}
}
return result;
}
void Node::computeCPT(map<string, vector<int>>& dataset, const int laplaceSmoothing)
{
// Get the dimensions of the CPT: the node variable first, then one dimension per parent
dimensions.clear();
dimensions.push_back(numStates);
for (auto father : getParents()) {
dimensions.push_back(father->getNumStates());
}
auto length = dimensions.size();
// Create a tensor of zeros with the dimensions of the CPT
cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
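// starting every cell at laplaceSmoothing applies additive (Laplace) smoothing, so parent/state
// combinations never seen in the data keep a non-zero probability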
// Fill table with counts
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
torch::List<c10::optional<torch::Tensor>> coordinates;
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
for (auto father : getParents()) {
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
}
// Increment the count of the corresponding coordinate
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
}
// Normalize the counts
cpTable = cpTable / cpTable.sum(0);
}
float Node::getFactorValue(map<string, int>& evidence)
{
torch::List<c10::optional<torch::Tensor>> coordinates;
// following predetermined order of indices in the cpTable (see Node.h)
coordinates.push_back(torch::tensor(evidence[name]));
for (auto parent : getParents()) {
coordinates.push_back(torch::tensor(evidence[parent->getName()]));
}
return cpTable.index({ coordinates }).item<float>();
}
}

35
bayesclass/Node.h Normal file

@@ -0,0 +1,35 @@
#ifndef NODE_H
#define NODE_H
#include <torch/torch.h>
//#include <torch/extension.h>
#include <vector>
#include <string>
#include <map>
#include <set>
namespace bayesnet {
using namespace std;
class Node {
private:
string name;
vector<Node*> parents;
vector<Node*> children;
int numStates; // number of states of the variable
torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ...
vector<int64_t> dimensions; // dimensions of the cpTable
vector<string> combinations(const set<string>&);
public:
Node(const std::string&, int);
void addParent(Node*);
void addChild(Node*);
void removeParent(Node*);
void removeChild(Node*);
string getName() const;
vector<Node*>& getParents();
vector<Node*>& getChildren();
torch::Tensor& getCPT();
void computeCPT(map<string, vector<int>>&, const int);
int getNumStates() const;
void setNumStates(int);
unsigned minFill();
float getFactorValue(map<string, int>&);
};
}
#endif


@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools", "setuptools-scm", "cython", "wheel"]
requires = ["setuptools", "setuptools-scm", "cython", "wheel", "torch"]
build-backend = "setuptools.build_meta"
[tool.setuptools]


@@ -5,6 +5,7 @@
"""
from setuptools import Extension, setup
from torch.utils import cpp_extension
setup(
ext_modules=[
@@ -20,5 +21,18 @@ setup(
"-std=c++17",
],
),
Extension(
name="bayesclass.cppBayesNetwork",
sources=[
"bayesclass/BayesNetwork.pyx",
"bayesclass/Network.cc",
"bayesclass/Node.cc",
],
include_dirs=cpp_extension.include_paths(),
language="c++",
extra_compile_args=[
"-std=c++17",
],
),
]
)
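To confirm that the new extension actually compiles and links against the bayesnet sources, a quick import smoke test helps. This is a sketch assuming an in-place build (for example python setup.py build_ext --inplace) succeeded and produced the bayesclass.cppBayesNetwork module declared above.

from bayesclass.cppBayesNetwork import BayesNetwork

net = BayesNetwork()            # constructs the underlying bayesnet::Network
net.addNode(b"a", 2)
print(net.getFeatures())        # [] until fit() has been called
print(net.getClassNumStates())  # 0 until fit() has set the class variable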