Added ExactInference and Factor classes

This commit is contained in:
Ricardo Montañana Gómez 2023-07-02 20:39:13 +02:00
parent ad255625e8
commit 12f0e1e063
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 172 additions and 44 deletions

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc)
add_library(BayesNet Network.cc Node.cc ExactInference.cc Factor.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

48
src/ExactInference.cc Normal file
View File

@ -0,0 +1,48 @@
#include "ExactInference.h"
namespace bayesnet {
// Build an inference engine over `net`: no evidence is set yet, and every
// feature of the network starts out as a candidate for elimination.
ExactInference::ExactInference(Network& net) : network(net), evidence(), candidates(net.getFeatures()) {}
void ExactInference::setEvidence(const map<string, int>& evidence)
{
this->evidence = evidence;
}
// `factors` owns the raw Factor pointers produced by buildFactors();
// release every one of them here.
ExactInference::~ExactInference()
{
    for (Factor* factor : factors) {
        delete factor;
    }
}
// Convert every node of the network into a Factor (one per CPT).
// Each toFactor() call allocates a new Factor that this class owns and
// releases in the destructor.
void ExactInference::buildFactors()
{
    // Iterate by const reference: `auto node` copied the map's key/value
    // pair (including the string key) on every iteration.
    for (const auto& node : network.getNodes()) {
        factors.push_back(node.second->toFactor());
    }
}
// Return the next variable to eliminate, chosen by the MinFill heuristic
// (the candidate whose node reports the smallest minFill() score).
// Returns "" when there are no candidates left.
string ExactInference::nextCandidate()
{
    string result;
    // Bind by reference: the original copied the whole map<string, Node*>
    // returned by getNodes() on every call.
    auto& nodes = network.getNodes();
    bool first = true;
    unsigned best = 0;
    for (const auto& candidate : candidates) {
        // at() throws on an unknown candidate instead of operator[]'s
        // silent nullptr insertion followed by a null dereference.
        unsigned fill = nodes.at(candidate)->minFill();
        if (first || fill < best) {
            first = false;
            best = fill;
            result = candidate;
        }
    }
    return result;
}
// Run variable elimination over the network's factors.
// TODO(review): factor multiplication/marginalization is not implemented in
// this commit, so the returned vector is always empty.
vector<double> ExactInference::variableElimination()
{
    vector<double> result;
    buildFactors();
    // Repeatedly pick the best variable by MinFill and drop it from the
    // pending candidates (erase-remove idiom) until none remain.
    for (string candidate = nextCandidate(); !candidate.empty(); candidate = nextCandidate()) {
        candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end());
    }
    return result;
}
}

27
src/ExactInference.h Normal file
View File

@ -0,0 +1,27 @@
#ifndef EXACTINFERENCE_H
#define EXACTINFERENCE_H
#include "Network.h"
#include "Factor.h"
#include "Node.h"
#include <map>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
    // Exact inference (variable elimination) over a BayesNet Network.
    class ExactInference {
    private:
        // Observe the network by reference: the previous by-value member
        // copied the whole Network (shallow-copying its Node* map) on every
        // construction. The referenced Network must outlive this object.
        Network& network;
        map<string, int> evidence;
        // Owning raw pointers, created by buildFactors() and deleted in the
        // destructor. NOTE(review): copying an ExactInference would double
        // delete these — consider vector<unique_ptr<Factor>>.
        vector<Factor*> factors;
        vector<string> candidates; // variables still to be eliminated
        void buildFactors();
        string nextCandidate(); // Return the next variable to eliminate using MinFill criterion
    public:
        ExactInference(Network&);
        ~ExactInference();
        void setEvidence(const map<string, int>&);
        vector<double> variableElimination();
    };
}
#endif

10
src/Factor.cc Normal file
View File

@ -0,0 +1,10 @@
#include "Factor.h"
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
    // A factor over a set of discrete variables: variable names, their
    // cardinalities, and a tensor of factor values (torch tensors share
    // underlying storage when copied — presumably intended; verify).
    Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values)
        : variables(variables)
        , cardinalities(cardinalities)
        , values(values)
    {
    }
    // All members manage their own storage; nothing to release manually.
    Factor::~Factor() = default;
}

27
src/Factor.h Normal file
View File

@ -0,0 +1,27 @@
#ifndef FACTOR_H
#define FACTOR_H
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
    // A factor over a set of discrete variables: names, cardinalities and a
    // tensor of values. Used by ExactInference for variable elimination.
    class Factor {
    private:
        vector<string> variables;
        vector<int> cardinalities;
        torch::Tensor values;
    public:
        Factor(vector<string>&, vector<int>&, torch::Tensor&);
        ~Factor();
        // Defaulted: the previous declarations had no definition anywhere
        // (Factor.cc only defines the ctor/dtor), which would fail at link
        // time on first use. All members are copyable, so default copies
        // are correct.
        Factor(const Factor&) = default;
        Factor& operator=(const Factor&) = default;
        void setVariables(vector<string>&);
        void setCardinalities(vector<int>&);
        void setValues(torch::Tensor&);
        vector<string>& getVariables();
        vector<int>& getCardinalities();
        torch::Tensor& getValues();
    };
}
#endif

View File

@ -1,4 +1,5 @@
#include "Network.h"
#include "ExactInference.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className("") {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className("") {}
@ -26,6 +27,10 @@ namespace bayesnet {
root = nodes[name];
}
}
// Return a copy of the feature (variable) names known to the network.
vector<string> Network::getFeatures()
{
    return features;
}
void Network::setRoot(string name)
{
if (nodes.find(name) == nodes.end()) {
@ -120,37 +125,7 @@ namespace bayesnet {
node->setCPT(cpt);
}
}
// pair<int, double> Network::predict_sample(const vector<int>& sample)
// {
// // For each possible class, calculate the posterior probability
// Node* classNode = nodes[className];
// int numClassStates = classNode->getNumStates();
// vector<double> classProbabilities(numClassStates, 0.0);
// for (int classState = 0; classState < numClassStates; ++classState) {
// // Start with the prior probability of the class
// classProbabilities[classState] = classNode->getCPT()[classState].item<double>();
// // Multiply by the likelihood of each feature given the class
// for (auto& pair : nodes) {
// if (pair.first != className) {
// Node* node = pair.second;
// int featureValue = featureValues[pair.first];
// // We use the class as the parent state to index into the CPT
// classProbabilities[classState] *= node->getCPT()[classState][featureValue].item<double>();
// }
// }
// }
// // Find the class with the maximum posterior probability
// auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
// int predictedClass = distance(classProbabilities.begin(), maxElem);
// double maxProbability = *maxElem;
// return make_pair(predictedClass, maxProbability);
// }
vector<int> Network::predict(const vector<vector<int>>& samples)
{
vector<int> predictions;
@ -195,15 +170,13 @@ namespace bayesnet {
throw invalid_argument("Sample size (" + to_string(sample.size()) +
") does not match the number of features (" + to_string(features.size()) + ")");
}
// Map the feature values to their corresponding nodes
map<string, int> featureValues;
for (int i = 0; i < features.size(); ++i) {
featureValues[features[i]] = sample[i];
auto inference = ExactInference(*this);
map<string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
}
// For each possible class, calculate the posterior probability
Network network = *this;
vector<double> classProbabilities = eliminateVariables(network, featureValues);
inference.setEvidence(evidence);
vector<double> classProbabilities = inference.variableElimination();
// Normalize the probabilities to sum to 1
double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0);
@ -217,8 +190,4 @@ namespace bayesnet {
return make_pair(predictedClass, maxProbability);
}
vector<double> eliminateVariables(network, featureValues)
{
}
}

View File

@ -16,7 +16,6 @@ namespace bayesnet {
int laplaceSmoothing;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
pair<int, double> predict_sample(const vector<int>&);
vector<double> eliminateVariables(Network&, const map<string, int>&);
public:
Network();
Network(int);
@ -25,6 +24,7 @@ namespace bayesnet {
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
vector<string> getFeatures();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void estimateParameters();
void setRoot(string);

View File

@ -53,4 +53,47 @@ namespace bayesnet {
{
this->cpt = cpt;
}
/*
    MinFill elimination score for this node.
    The MinFill heuristic for variable elimination prefers the variable whose
    removal adds the fewest edges when its neighbours are connected pairwise.
    Here the score is the number of unordered pairs of distinct neighbours
    (parents and children), i.e. n * (n - 1) / 2.
    NOTE(review): true min-fill counts only the pairs NOT already connected;
    this counts all pairs — confirm the approximation is intentional.
*/
unsigned Node::minFill()
{
    // Distinct neighbours: children and parents, deduplicated by name.
    set<string> neighbors;
    for (auto child : children) {
        neighbors.emplace(child->getName());
    }
    for (auto parent : parents) {
        neighbors.emplace(parent->getName());
    }
    // Count pairs directly instead of materialising O(n^2) concatenated
    // strings via combinations() just to take their size.
    unsigned n = neighbors.size();
    return n * (n - 1) / 2;
}
// All unordered pairs {a, b} of the given names, each encoded as the
// concatenation a + b (set iteration order, so output order is sorted).
// NOTE(review): concatenation can collide ("ab"+"c" == "a"+"bc"); fine while
// callers only use the count, but risky as a pair key.
vector<string> Node::combinations(const set<string>& neighbors)
{
    vector<string> items(neighbors.begin(), neighbors.end());
    vector<string> pairs;
    for (auto first = items.begin(); first != items.end(); ++first) {
        auto second = first;
        for (++second; second != items.end(); ++second) {
            pairs.push_back(*first + *second);
        }
    }
    return pairs;
}
// Build a Factor from this node's CPT. The factor's scope is the node itself
// followed by its parents, with matching cardinalities.
// The caller takes ownership of the returned Factor (ExactInference deletes
// it in its destructor).
Factor* Node::toFactor()
{
    vector<string> scope{name};
    vector<int> cards{numStates};
    for (auto parent : parents) {
        scope.push_back(parent->getName());
        cards.push_back(parent->getNumStates());
    }
    return new Factor(scope, cards, cpt);
}
}

View File

@ -1,6 +1,7 @@
#ifndef NODE_H
#define NODE_H
#include <torch/torch.h>
#include "Factor.h"
#include <vector>
#include <string>
namespace bayesnet {
@ -15,6 +16,7 @@ namespace bayesnet {
torch::Tensor cpTable;
int numStates;
torch::Tensor cpt;
vector<string> combinations(const set<string>&);
public:
Node(const std::string&, int);
void addParent(Node*);
@ -28,7 +30,9 @@ namespace bayesnet {
void setCPT(const torch::Tensor&);
int getNumStates() const;
void setNumStates(int);
unsigned minFill();
int getId() const { return id; }
Factor* toFactor();
};
}
#endif