Inference working

This commit is contained in:
2023-07-05 18:38:54 +02:00
parent 5db4d1189a
commit ba08b8dd3d
12 changed files with 114 additions and 250 deletions

View File

@@ -1,9 +1,9 @@
#include "Network.h"
#include "ExactInference.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className("") {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className("") {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className)
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates())
{
for (auto& pair : other.nodes) {
nodes[pair.first] = new Node(*pair.second);
@@ -31,6 +31,14 @@ namespace bayesnet {
{
return features;
}
int Network::getClassNumStates()
{
return classNumStates;
}
string Network::getClassName()
{
return className;
}
void Network::setRoot(string name)
{
if (nodes.find(name) == nodes.end()) {
@@ -93,6 +101,7 @@ namespace bayesnet {
this->dataset[featureNames[i]] = dataset[i];
}
this->dataset[className] = labels;
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1;
estimateParameters();
}
@@ -100,29 +109,7 @@ namespace bayesnet {
{
auto dimensions = vector<int64_t>();
for (auto [name, node] : nodes) {
// Get dimensions of the CPT
dimensions.clear();
dimensions.push_back(node->getNumStates());
for (auto father : node->getParents()) {
dimensions.push_back(father->getNumStates());
}
auto length = dimensions.size();
// Create a tensor of zeros with the dimensions of the CPT
torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
// Fill table with counts
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
torch::List<c10::optional<torch::Tensor>> coordinates;
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
for (auto father : node->getParents()) {
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
}
// Increment the count of the corresponding coordinate
cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
}
// Normalize the counts
cpt = cpt / cpt.sum(0);
// store thre resulting cpt in the node
node->setCPT(cpt);
node->computeCPT(dataset, laplaceSmoothing);
}
}
@@ -175,14 +162,8 @@ namespace bayesnet {
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
}
inference.setEvidence(evidence);
vector<double> classProbabilities = inference.variableElimination();
vector<double> classProbabilities = inference.variableElimination(evidence);
// Normalize the probabilities to sum to 1
double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0);
for (double& prob : classProbabilities) {
prob /= sum;
}
// Find the class with the maximum posterior probability
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
int predictedClass = distance(classProbabilities.begin(), maxElem);