Added ExactInference and Factor classes

This commit is contained in:
Ricardo Montañana Gómez 2023-07-02 20:39:13 +02:00
parent ad255625e8
commit 12f0e1e063
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 172 additions and 44 deletions

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc) add_library(BayesNet Network.cc Node.cc ExactInference.cc Factor.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}") target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

48
src/ExactInference.cc Normal file
View File

@ -0,0 +1,48 @@
#include "ExactInference.h"
namespace bayesnet {
ExactInference::ExactInference(Network& net) : network(net), evidence(map<string, int>()), candidates(net.getFeatures()) {}
void ExactInference::setEvidence(const map<string, int>& evidence)
{
this->evidence = evidence;
}
ExactInference::~ExactInference()
{
for (auto& factor : factors) {
delete factor;
}
}
void ExactInference::buildFactors()
{
for (auto node : network.getNodes()) {
factors.push_back(node.second->toFactor());
}
}
string ExactInference::nextCandidate()
{
string result = "";
map<string, Node*> nodes = network.getNodes();
int minFill = INT_MAX;
for (auto candidate : candidates) {
unsigned fill = nodes[candidate]->minFill();
if (fill < minFill) {
minFill = fill;
result = candidate;
}
}
return result;
}
vector<double> ExactInference::variableElimination()
{
vector<double> result;
string candidate;
buildFactors();
// Eliminate evidence
while ((candidate = nextCandidate()) != "") {
// Erase candidate from candidates (Eraseremove idiom)
candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end());
}
return result;
}
}

27
src/ExactInference.h Normal file
View File

@ -0,0 +1,27 @@
#ifndef EXACTINFERENCE_H
#define EXACTINFERENCE_H
#include "Network.h"
#include "Factor.h"
#include "Node.h"
#include <map>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
class ExactInference {
private:
Network network;
map<string, int> evidence;
vector<Factor*> factors;
vector<string> candidates; // variables to be removed
void buildFactors();
string nextCandidate(); // Return the next variable to eliminate using MinFill criterion
public:
ExactInference(Network&);
~ExactInference();
void setEvidence(const map<string, int>&);
vector<double> variableElimination();
};
}
#endif

10
src/Factor.cc Normal file
View File

@ -0,0 +1,10 @@
#include "Factor.h"
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {}
Factor::~Factor() = default;
}

27
src/Factor.h Normal file
View File

@ -0,0 +1,27 @@
#ifndef FACTOR_H
#define FACTOR_H
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
class Factor {
private:
vector<string> variables;
vector<int> cardinalities;
torch::Tensor values;
public:
Factor(vector<string>&, vector<int>&, torch::Tensor&);
~Factor();
Factor(const Factor&);
Factor& operator=(const Factor&);
void setVariables(vector<string>&);
void setCardinalities(vector<int>&);
void setValues(torch::Tensor&);
vector<string>& getVariables();
vector<int>& getCardinalities();
torch::Tensor& getValues();
};
}
#endif

View File

@ -1,4 +1,5 @@
#include "Network.h" #include "Network.h"
#include "ExactInference.h"
namespace bayesnet { namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className("") {} Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className("") {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className("") {} Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className("") {}
@ -26,6 +27,10 @@ namespace bayesnet {
root = nodes[name]; root = nodes[name];
} }
} }
vector<string> Network::getFeatures()
{
return features;
}
void Network::setRoot(string name) void Network::setRoot(string name)
{ {
if (nodes.find(name) == nodes.end()) { if (nodes.find(name) == nodes.end()) {
@ -120,37 +125,7 @@ namespace bayesnet {
node->setCPT(cpt); node->setCPT(cpt);
} }
} }
// pair<int, double> Network::predict_sample(const vector<int>& sample)
// {
// // For each possible class, calculate the posterior probability
// Node* classNode = nodes[className];
// int numClassStates = classNode->getNumStates();
// vector<double> classProbabilities(numClassStates, 0.0);
// for (int classState = 0; classState < numClassStates; ++classState) {
// // Start with the prior probability of the class
// classProbabilities[classState] = classNode->getCPT()[classState].item<double>();
// // Multiply by the likelihood of each feature given the class
// for (auto& pair : nodes) {
// if (pair.first != className) {
// Node* node = pair.second;
// int featureValue = featureValues[pair.first];
// // We use the class as the parent state to index into the CPT
// classProbabilities[classState] *= node->getCPT()[classState][featureValue].item<double>();
// }
// }
// }
// // Find the class with the maximum posterior probability
// auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
// int predictedClass = distance(classProbabilities.begin(), maxElem);
// double maxProbability = *maxElem;
// return make_pair(predictedClass, maxProbability);
// }
vector<int> Network::predict(const vector<vector<int>>& samples) vector<int> Network::predict(const vector<vector<int>>& samples)
{ {
vector<int> predictions; vector<int> predictions;
@ -195,15 +170,13 @@ namespace bayesnet {
throw invalid_argument("Sample size (" + to_string(sample.size()) + throw invalid_argument("Sample size (" + to_string(sample.size()) +
") does not match the number of features (" + to_string(features.size()) + ")"); ") does not match the number of features (" + to_string(features.size()) + ")");
} }
// Map the feature values to their corresponding nodes auto inference = ExactInference(*this);
map<string, int> featureValues; map<string, int> evidence;
for (int i = 0; i < features.size(); ++i) { for (int i = 0; i < sample.size(); ++i) {
featureValues[features[i]] = sample[i]; evidence[features[i]] = sample[i];
} }
inference.setEvidence(evidence);
// For each possible class, calculate the posterior probability vector<double> classProbabilities = inference.variableElimination();
Network network = *this;
vector<double> classProbabilities = eliminateVariables(network, featureValues);
// Normalize the probabilities to sum to 1 // Normalize the probabilities to sum to 1
double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0); double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0);
@ -217,8 +190,4 @@ namespace bayesnet {
return make_pair(predictedClass, maxProbability); return make_pair(predictedClass, maxProbability);
} }
vector<double> eliminateVariables(network, featureValues)
{
}
} }

View File

@ -16,7 +16,6 @@ namespace bayesnet {
int laplaceSmoothing; int laplaceSmoothing;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
pair<int, double> predict_sample(const vector<int>&); pair<int, double> predict_sample(const vector<int>&);
vector<double> eliminateVariables(Network&, const map<string, int>&);
public: public:
Network(); Network();
Network(int); Network(int);
@ -25,6 +24,7 @@ namespace bayesnet {
void addNode(string, int); void addNode(string, int);
void addEdge(const string, const string); void addEdge(const string, const string);
map<string, Node*>& getNodes(); map<string, Node*>& getNodes();
vector<string> getFeatures();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&); void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void estimateParameters(); void estimateParameters();
void setRoot(string); void setRoot(string);

View File

@ -53,4 +53,47 @@ namespace bayesnet {
{ {
this->cpt = cpt; this->cpt = cpt;
} }
/*
The MinFill criterion is a heuristic for variable elimination.
The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
The variable with the minimum number of edges is chosen.
Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
*/
unsigned Node::minFill()
{
set<string> neighbors;
for (auto child : children) {
neighbors.emplace(child->getName());
}
for (auto parent : parents) {
neighbors.emplace(parent->getName());
}
return combinations(neighbors).size();
}
vector<string> Node::combinations(const set<string>& neighbors)
{
vector<string> source(neighbors.begin(), neighbors.end());
vector<string> result;
for (int i = 0; i < source.size(); ++i) {
string temp = source[i];
for (int j = i + 1; j < source.size(); ++j) {
result.push_back(temp + source[j]);
}
}
return result;
}
Factor* Node::toFactor()
{
vector<string> variables;
vector<int> cardinalities;
variables.push_back(name);
cardinalities.push_back(numStates);
for (auto parent : parents) {
variables.push_back(parent->getName());
cardinalities.push_back(parent->getNumStates());
}
return new Factor(variables, cardinalities, cpt);
}
} }

View File

@ -1,6 +1,7 @@
#ifndef NODE_H #ifndef NODE_H
#define NODE_H #define NODE_H
#include <torch/torch.h> #include <torch/torch.h>
#include "Factor.h"
#include <vector> #include <vector>
#include <string> #include <string>
namespace bayesnet { namespace bayesnet {
@ -15,6 +16,7 @@ namespace bayesnet {
torch::Tensor cpTable; torch::Tensor cpTable;
int numStates; int numStates;
torch::Tensor cpt; torch::Tensor cpt;
vector<string> combinations(const set<string>&);
public: public:
Node(const std::string&, int); Node(const std::string&, int);
void addParent(Node*); void addParent(Node*);
@ -28,7 +30,9 @@ namespace bayesnet {
void setCPT(const torch::Tensor&); void setCPT(const torch::Tensor&);
int getNumStates() const; int getNumStates() const;
void setNumStates(int); void setNumStates(int);
unsigned minFill();
int getId() const { return id; } int getId() const { return id; }
Factor* toFactor();
}; };
} }
#endif #endif