diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e2b3ce9..c27057c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc) +add_library(BayesNet Network.cc Node.cc ExactInference.cc Factor.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/ExactInference.cc b/src/ExactInference.cc new file mode 100644 index 0000000..bb8731b --- /dev/null +++ b/src/ExactInference.cc @@ -0,0 +1,48 @@ +#include "ExactInference.h" + +namespace bayesnet { + ExactInference::ExactInference(Network& net) : network(net), evidence(map()), candidates(net.getFeatures()) {} + void ExactInference::setEvidence(const map& evidence) + { + this->evidence = evidence; + } + ExactInference::~ExactInference() + { + for (auto& factor : factors) { + delete factor; + } + } + void ExactInference::buildFactors() + { + for (auto node : network.getNodes()) { + factors.push_back(node.second->toFactor()); + } + } + string ExactInference::nextCandidate() + { + string result = ""; + map nodes = network.getNodes(); + int minFill = INT_MAX; + for (auto candidate : candidates) { + unsigned fill = nodes[candidate]->minFill(); + if (fill < minFill) { + minFill = fill; + result = candidate; + } + } + return result; + } + vector ExactInference::variableElimination() + { + vector result; + string candidate; + buildFactors(); + // Eliminate evidence + while ((candidate = nextCandidate()) != "") { + // Erase candidate from candidates (Erase–remove idiom) + candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end()); + + } + return result; + } +} \ No newline at end of file diff --git a/src/ExactInference.h b/src/ExactInference.h new file mode 100644 index 0000000..87cdb6f --- /dev/null +++ b/src/ExactInference.h @@ -0,0 +1,27 @@ +#ifndef EXACTINFERENCE_H +#define EXACTINFERENCE_H +#include "Network.h" +#include "Factor.h" +#include "Node.h" +#include +#include +#include +using namespace std; + +namespace bayesnet { + class ExactInference { + private: + Network network; + map evidence; + vector factors; + vector candidates; // variables to be removed + void buildFactors(); + string nextCandidate(); // Return the next variable to eliminate using MinFill criterion + public: + ExactInference(Network&); + ~ExactInference(); + void setEvidence(const map&); + vector variableElimination(); + }; +} +#endif \ No newline at end of file diff --git a/src/Factor.cc b/src/Factor.cc new file mode 100644 index 0000000..1e1b781 --- /dev/null +++ b/src/Factor.cc @@ -0,0 +1,10 @@ +#include "Factor.h" +#include +#include + +using namespace std; + +namespace bayesnet { + Factor::Factor(vector& variables, vector& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {} + Factor::~Factor() = default; +} \ No newline at end of file diff --git a/src/Factor.h b/src/Factor.h new file mode 100644 index 0000000..26cbbab --- /dev/null +++ b/src/Factor.h @@ -0,0 +1,27 @@ +#ifndef FACTOR_H +#define FACTOR_H +#include +#include +#include +using namespace std; + +namespace bayesnet { + class Factor { + private: + vector variables; + vector cardinalities; + torch::Tensor values; + public: + Factor(vector&, vector&, torch::Tensor&); + ~Factor(); + Factor(const Factor&); + Factor& operator=(const Factor&); + void setVariables(vector&); + void setCardinalities(vector&); + void setValues(torch::Tensor&); + vector& getVariables(); + vector& getCardinalities(); + torch::Tensor& getValues(); + }; +} +#endif \ No newline at end of file diff --git a/src/Network.cc b/src/Network.cc index 7df9445..30956a1 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -1,4 +1,5 @@ #include "Network.h" +#include "ExactInference.h" namespace bayesnet { Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className("") {} Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className("") {} @@ -26,6 +27,10 @@ namespace bayesnet { root = nodes[name]; } } + vector Network::getFeatures() + { + return features; + } void Network::setRoot(string name) { if (nodes.find(name) == nodes.end()) { @@ -120,37 +125,7 @@ namespace bayesnet { node->setCPT(cpt); } } - // pair Network::predict_sample(const vector& sample) - // { - - // // For each possible class, calculate the posterior probability - // Node* classNode = nodes[className]; - // int numClassStates = classNode->getNumStates(); - // vector classProbabilities(numClassStates, 0.0); - // for (int classState = 0; classState < numClassStates; ++classState) { - // // Start with the prior probability of the class - // classProbabilities[classState] = classNode->getCPT()[classState].item(); - - // // Multiply by the likelihood of each feature given the class - // for (auto& pair : nodes) { - // if (pair.first != className) { - // Node* node = pair.second; - // int featureValue = featureValues[pair.first]; - - // // We use the class as the parent state to index into the CPT - // classProbabilities[classState] *= node->getCPT()[classState][featureValue].item(); - // } - // } - // } - - // // Find the class with the maximum posterior probability - // auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); - // int predictedClass = distance(classProbabilities.begin(), maxElem); - // double maxProbability = *maxElem; - - // return make_pair(predictedClass, maxProbability); - // } vector Network::predict(const vector>& samples) { vector predictions; @@ -195,15 +170,13 @@ namespace bayesnet { throw invalid_argument("Sample size (" + to_string(sample.size()) + ") does not match the number of features (" + to_string(features.size()) + ")"); } - // Map the feature values to their corresponding nodes - map featureValues; - for (int i = 0; i < features.size(); ++i) { - featureValues[features[i]] = sample[i]; + auto inference = ExactInference(*this); + map evidence; + for (int i = 0; i < sample.size(); ++i) { + evidence[features[i]] = sample[i]; } - - // For each possible class, calculate the posterior probability - Network network = *this; - vector classProbabilities = eliminateVariables(network, featureValues); + inference.setEvidence(evidence); + vector classProbabilities = inference.variableElimination(); // Normalize the probabilities to sum to 1 double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0); @@ -217,8 +190,4 @@ namespace bayesnet { return make_pair(predictedClass, maxProbability); } - vector eliminateVariables(network, featureValues) - { - - } } diff --git a/src/Network.h b/src/Network.h index 568876c..aa0285b 100644 --- a/src/Network.h +++ b/src/Network.h @@ -16,7 +16,6 @@ namespace bayesnet { int laplaceSmoothing; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); pair predict_sample(const vector&); - vector eliminateVariables(Network&, const map&); public: Network(); Network(int); @@ -25,6 +24,7 @@ namespace bayesnet { void addNode(string, int); void addEdge(const string, const string); map& getNodes(); + vector getFeatures(); void fit(const vector>&, const vector&, const vector&, const string&); void estimateParameters(); void setRoot(string); diff --git a/src/Node.cc b/src/Node.cc index 657e827..ab2e193 100644 --- a/src/Node.cc +++ b/src/Node.cc @@ -53,4 +53,47 @@ namespace bayesnet { { this->cpt = cpt; } + /* + The MinFill criterion is a heuristic for variable elimination. + The variable that minimizes the number of edges that need to be added to the graph to make it triangulated. + This is done by counting the number of edges that need to be added to the graph if the variable is eliminated. + The variable with the minimum number of edges is chosen. + Here this is done computing the length of the combinations of the node neighbors taken 2 by 2. + */ + unsigned Node::minFill() + { + set neighbors; + for (auto child : children) { + neighbors.emplace(child->getName()); + } + for (auto parent : parents) { + neighbors.emplace(parent->getName()); + } + return combinations(neighbors).size(); + } + vector Node::combinations(const set& neighbors) + { + vector source(neighbors.begin(), neighbors.end()); + vector result; + for (int i = 0; i < source.size(); ++i) { + string temp = source[i]; + for (int j = i + 1; j < source.size(); ++j) { + result.push_back(temp + source[j]); + } + } + return result; + } + Factor* Node::toFactor() + { + vector variables; + vector cardinalities; + variables.push_back(name); + cardinalities.push_back(numStates); + for (auto parent : parents) { + variables.push_back(parent->getName()); + cardinalities.push_back(parent->getNumStates()); + } + return new Factor(variables, cardinalities, cpt); + + } } \ No newline at end of file diff --git a/src/Node.h b/src/Node.h index 8bc1590..2dd2d1b 100644 --- a/src/Node.h +++ b/src/Node.h @@ -1,6 +1,7 @@ #ifndef NODE_H #define NODE_H #include +#include "Factor.h" #include #include namespace bayesnet { @@ -15,6 +16,7 @@ namespace bayesnet { torch::Tensor cpTable; int numStates; torch::Tensor cpt; + vector combinations(const set&); public: Node(const std::string&, int); void addParent(Node*); @@ -28,7 +30,9 @@ namespace bayesnet { void setCPT(const torch::Tensor&); int getNumStates() const; void setNumStates(int); + unsigned minFill(); int getId() const { return id; } + Factor* toFactor(); }; } #endif \ No newline at end of file