Implement some needed Factor methods

This commit is contained in:
Ricardo Montañana Gómez 2023-07-02 22:12:44 +02:00
parent 12f0e1e063
commit ec76057e64
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 97 additions and 0 deletions

View File

@ -41,6 +41,24 @@ namespace bayesnet {
while ((candidate = nextCandidate()) != "") { while ((candidate = nextCandidate()) != "") {
// Erase candidate from candidates (Eraseremove idiom) // Erase candidate from candidates (Eraseremove idiom)
candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end()); candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end());
// sum-product variable elimination algorithm as explained in the book probabilistic graphical models
// 1. Multiply all factors containing the variable
vector<Factor*> factorsToMultiply;
for (auto factor : factors) {
if (factor->contains(candidate)) {
factorsToMultiply.push_back(factor);
}
}
Factor* product = Factor::product(factorsToMultiply);
// 2. Sum out the variable
Factor* sum = product->sumOut(candidate);
// 3. Remove factors containing the variable
for (auto factor : factorsToMultiply) {
factors.erase(remove(factors.begin(), factors.end(), factor), factors.end());
delete factor;
}
// 4. Add the resulting factor to the list of factors
factors.push_back(sum);
} }
return result; return result;

View File

@ -7,4 +7,79 @@ using namespace std;
namespace bayesnet { namespace bayesnet {
Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {} Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {}
Factor::~Factor() = default; Factor::~Factor() = default;
Factor::Factor(const Factor& other) : variables(other.variables), cardinalities(other.cardinalities), values(other.values) {}
Factor& Factor::operator=(const Factor& other)
{
if (this != &other) {
variables = other.variables;
cardinalities = other.cardinalities;
values = other.values;
}
return *this;
}
void Factor::setVariables(vector<string>& variables)
{
this->variables = variables;
}
void Factor::setCardinalities(vector<int>& cardinalities)
{
this->cardinalities = cardinalities;
}
void Factor::setValues(torch::Tensor& values)
{
this->values = values;
}
vector<string>& Factor::getVariables()
{
return variables;
}
vector<int>& Factor::getCardinalities()
{
return cardinalities;
}
torch::Tensor& Factor::getValues()
{
return values;
}
bool Factor::contains(string& variable)
{
for (int i = 0; i < variables.size(); i++) {
if (variables[i] == variable) {
return true;
}
}
return false;
}
Factor* Factor::sumOut(string& candidate)
{
vector<string> newVariables;
vector<int> newCardinalities;
for (int i = 0; i < variables.size(); i++) {
if (variables[i] != candidate) {
newVariables.push_back(variables[i]);
newCardinalities.push_back(cardinalities[i]);
}
}
torch::Tensor newValues = values.sum(0);
return new Factor(newVariables, newCardinalities, newValues);
}
Factor* Factor::product(vector<Factor>& factors)
{
vector<string> newVariables;
vector<int> newCardinalities;
for (auto factor : factors) {
for (auto variable : factor.getVariables()) {
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
newVariables.push_back(variable);
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]);
}
}
}
torch::Tensor newValues = factors[0].getValues();
for (int i = 1; i < factors.size(); i++) {
newValues = newValues.matmul(factors[i].getValues());
}
return new Factor(newVariables, newCardinalities, newValues);
}
} }

View File

@ -21,7 +21,11 @@ namespace bayesnet {
void setValues(torch::Tensor&); void setValues(torch::Tensor&);
vector<string>& getVariables(); vector<string>& getVariables();
vector<int>& getCardinalities(); vector<int>& getCardinalities();
bool contains(string&);
torch::Tensor& getValues(); torch::Tensor& getValues();
static Factor* product(vector<Factor>&);
Factor* sumOut(string&);
}; };
} }
#endif #endif