Implement some needed Factor methods
This commit is contained in:
parent
12f0e1e063
commit
ec76057e64
@ -41,6 +41,24 @@ namespace bayesnet {
|
|||||||
while ((candidate = nextCandidate()) != "") {
|
while ((candidate = nextCandidate()) != "") {
|
||||||
// Erase candidate from candidates (Erase–remove idiom)
|
// Erase candidate from candidates (Erase–remove idiom)
|
||||||
candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end());
|
candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end());
|
||||||
|
// sum-product variable elimination algorithm as explained in the book probabilistic graphical models
|
||||||
|
// 1. Multiply all factors containing the variable
|
||||||
|
vector<Factor*> factorsToMultiply;
|
||||||
|
for (auto factor : factors) {
|
||||||
|
if (factor->contains(candidate)) {
|
||||||
|
factorsToMultiply.push_back(factor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor* product = Factor::product(factorsToMultiply);
|
||||||
|
// 2. Sum out the variable
|
||||||
|
Factor* sum = product->sumOut(candidate);
|
||||||
|
// 3. Remove factors containing the variable
|
||||||
|
for (auto factor : factorsToMultiply) {
|
||||||
|
factors.erase(remove(factors.begin(), factors.end(), factor), factors.end());
|
||||||
|
delete factor;
|
||||||
|
}
|
||||||
|
// 4. Add the resulting factor to the list of factors
|
||||||
|
factors.push_back(sum);
|
||||||
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -7,4 +7,79 @@ using namespace std;
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {}
|
Factor::Factor(vector<string>& variables, vector<int>& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {}
|
||||||
Factor::~Factor() = default;
|
Factor::~Factor() = default;
|
||||||
|
Factor::Factor(const Factor& other) : variables(other.variables), cardinalities(other.cardinalities), values(other.values) {}
|
||||||
|
Factor& Factor::operator=(const Factor& other)
|
||||||
|
{
|
||||||
|
if (this != &other) {
|
||||||
|
variables = other.variables;
|
||||||
|
cardinalities = other.cardinalities;
|
||||||
|
values = other.values;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
void Factor::setVariables(vector<string>& variables)
|
||||||
|
{
|
||||||
|
this->variables = variables;
|
||||||
|
}
|
||||||
|
void Factor::setCardinalities(vector<int>& cardinalities)
|
||||||
|
{
|
||||||
|
this->cardinalities = cardinalities;
|
||||||
|
}
|
||||||
|
void Factor::setValues(torch::Tensor& values)
|
||||||
|
{
|
||||||
|
this->values = values;
|
||||||
|
}
|
||||||
|
vector<string>& Factor::getVariables()
|
||||||
|
{
|
||||||
|
return variables;
|
||||||
|
}
|
||||||
|
vector<int>& Factor::getCardinalities()
|
||||||
|
{
|
||||||
|
return cardinalities;
|
||||||
|
}
|
||||||
|
torch::Tensor& Factor::getValues()
|
||||||
|
{
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
bool Factor::contains(string& variable)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < variables.size(); i++) {
|
||||||
|
if (variables[i] == variable) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Factor* Factor::sumOut(string& candidate)
|
||||||
|
{
|
||||||
|
vector<string> newVariables;
|
||||||
|
vector<int> newCardinalities;
|
||||||
|
for (int i = 0; i < variables.size(); i++) {
|
||||||
|
if (variables[i] != candidate) {
|
||||||
|
newVariables.push_back(variables[i]);
|
||||||
|
newCardinalities.push_back(cardinalities[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
torch::Tensor newValues = values.sum(0);
|
||||||
|
return new Factor(newVariables, newCardinalities, newValues);
|
||||||
|
}
|
||||||
|
Factor* Factor::product(vector<Factor>& factors)
|
||||||
|
{
|
||||||
|
vector<string> newVariables;
|
||||||
|
vector<int> newCardinalities;
|
||||||
|
for (auto factor : factors) {
|
||||||
|
for (auto variable : factor.getVariables()) {
|
||||||
|
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
|
||||||
|
newVariables.push_back(variable);
|
||||||
|
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
torch::Tensor newValues = factors[0].getValues();
|
||||||
|
for (int i = 1; i < factors.size(); i++) {
|
||||||
|
newValues = newValues.matmul(factors[i].getValues());
|
||||||
|
}
|
||||||
|
return new Factor(newVariables, newCardinalities, newValues);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -21,7 +21,11 @@ namespace bayesnet {
|
|||||||
void setValues(torch::Tensor&);
|
void setValues(torch::Tensor&);
|
||||||
vector<string>& getVariables();
|
vector<string>& getVariables();
|
||||||
vector<int>& getCardinalities();
|
vector<int>& getCardinalities();
|
||||||
|
bool contains(string&);
|
||||||
torch::Tensor& getValues();
|
torch::Tensor& getValues();
|
||||||
|
static Factor* product(vector<Factor>&);
|
||||||
|
Factor* sumOut(string&);
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
Loading…
Reference in New Issue
Block a user