Inference working
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
#include "Network.h"
|
||||
#include "ExactInference.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className("") {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className("") {}
|
||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className)
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates())
|
||||
{
|
||||
for (auto& pair : other.nodes) {
|
||||
nodes[pair.first] = new Node(*pair.second);
|
||||
@@ -31,6 +31,14 @@ namespace bayesnet {
|
||||
{
|
||||
return features;
|
||||
}
|
||||
int Network::getClassNumStates()
|
||||
{
|
||||
return classNumStates;
|
||||
}
|
||||
string Network::getClassName()
|
||||
{
|
||||
return className;
|
||||
}
|
||||
void Network::setRoot(string name)
|
||||
{
|
||||
if (nodes.find(name) == nodes.end()) {
|
||||
@@ -93,6 +101,7 @@ namespace bayesnet {
|
||||
this->dataset[featureNames[i]] = dataset[i];
|
||||
}
|
||||
this->dataset[className] = labels;
|
||||
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1;
|
||||
estimateParameters();
|
||||
}
|
||||
|
||||
@@ -100,29 +109,7 @@ namespace bayesnet {
|
||||
{
|
||||
auto dimensions = vector<int64_t>();
|
||||
for (auto [name, node] : nodes) {
|
||||
// Get dimensions of the CPT
|
||||
dimensions.clear();
|
||||
dimensions.push_back(node->getNumStates());
|
||||
for (auto father : node->getParents()) {
|
||||
dimensions.push_back(father->getNumStates());
|
||||
}
|
||||
auto length = dimensions.size();
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
||||
// Fill table with counts
|
||||
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||
for (auto father : node->getParents()) {
|
||||
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
||||
}
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
|
||||
}
|
||||
// Normalize the counts
|
||||
cpt = cpt / cpt.sum(0);
|
||||
// store thre resulting cpt in the node
|
||||
node->setCPT(cpt);
|
||||
node->computeCPT(dataset, laplaceSmoothing);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,14 +162,8 @@ namespace bayesnet {
|
||||
for (int i = 0; i < sample.size(); ++i) {
|
||||
evidence[features[i]] = sample[i];
|
||||
}
|
||||
inference.setEvidence(evidence);
|
||||
vector<double> classProbabilities = inference.variableElimination();
|
||||
vector<double> classProbabilities = inference.variableElimination(evidence);
|
||||
|
||||
// Normalize the probabilities to sum to 1
|
||||
double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0);
|
||||
for (double& prob : classProbabilities) {
|
||||
prob /= sum;
|
||||
}
|
||||
// Find the class with the maximum posterior probability
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
int predictedClass = distance(classProbabilities.begin(), maxElem);
|
||||
|
Reference in New Issue
Block a user