diff --git a/src/ExactInference.cc b/src/ExactInference.cc index bb8731b..0040447 100644 --- a/src/ExactInference.cc +++ b/src/ExactInference.cc @@ -41,6 +41,24 @@ namespace bayesnet { while ((candidate = nextCandidate()) != "") { // Erase candidate from candidates (Erase–remove idiom) candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end()); + // sum-product variable elimination algorithm as explained in the book probabilistic graphical models + // 1. Multiply all factors containing the variable + vector factorsToMultiply; + for (auto factor : factors) { + if (factor->contains(candidate)) { + factorsToMultiply.push_back(factor); + } + } + Factor* product = Factor::product(factorsToMultiply); + // 2. Sum out the variable + Factor* sum = product->sumOut(candidate); + // 3. Remove factors containing the variable + for (auto factor : factorsToMultiply) { + factors.erase(remove(factors.begin(), factors.end(), factor), factors.end()); + delete factor; + } + // 4. Add the resulting factor to the list of factors + factors.push_back(sum); } return result; diff --git a/src/Factor.cc b/src/Factor.cc index 1e1b781..4af43d9 100644 --- a/src/Factor.cc +++ b/src/Factor.cc @@ -7,4 +7,79 @@ using namespace std; namespace bayesnet { Factor::Factor(vector& variables, vector& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {} Factor::~Factor() = default; + Factor::Factor(const Factor& other) : variables(other.variables), cardinalities(other.cardinalities), values(other.values) {} + Factor& Factor::operator=(const Factor& other) + { + if (this != &other) { + variables = other.variables; + cardinalities = other.cardinalities; + values = other.values; + } + return *this; + } + void Factor::setVariables(vector& variables) + { + this->variables = variables; + } + void Factor::setCardinalities(vector& cardinalities) + { + this->cardinalities = cardinalities; + } + void Factor::setValues(torch::Tensor& values) + { + this->values = values; + } + vector& Factor::getVariables() + { + return variables; + } + vector& Factor::getCardinalities() + { + return cardinalities; + } + torch::Tensor& Factor::getValues() + { + return values; + } + bool Factor::contains(string& variable) + { + for (int i = 0; i < variables.size(); i++) { + if (variables[i] == variable) { + return true; + } + } + return false; + } + Factor* Factor::sumOut(string& candidate) + { + vector newVariables; + vector newCardinalities; + for (int i = 0; i < variables.size(); i++) { + if (variables[i] != candidate) { + newVariables.push_back(variables[i]); + newCardinalities.push_back(cardinalities[i]); + } + } + torch::Tensor newValues = values.sum(0); + return new Factor(newVariables, newCardinalities, newValues); + } + Factor* Factor::product(vector& factors) + { + vector newVariables; + vector newCardinalities; + for (auto factor : factors) { + for (auto variable : factor.getVariables()) { + if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) { + newVariables.push_back(variable); + newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]); + } + } + } + torch::Tensor newValues = factors[0].getValues(); + for (int i = 1; i < factors.size(); i++) { + newValues = newValues.matmul(factors[i].getValues()); + } + return new Factor(newVariables, newCardinalities, newValues); + } + } \ No newline at end of file diff --git a/src/Factor.h b/src/Factor.h index 26cbbab..c4fee52 100644 --- a/src/Factor.h +++ b/src/Factor.h @@ -21,7 +21,11 @@ namespace bayesnet { void setValues(torch::Tensor&); vector& getVariables(); vector& getCardinalities(); + bool contains(string&); torch::Tensor& getValues(); + static Factor* product(vector&); + Factor* sumOut(string&); + }; } #endif \ No newline at end of file