Fix compile errors

This commit is contained in:
Ricardo Montañana Gómez 2023-07-03 00:31:15 +02:00
parent ec76057e64
commit 5db4d1189a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 9 additions and 7 deletions

View File

@ -218,7 +218,7 @@ int main(int argc, char** argv)
cout << "Hello, Bayesian Networks!" << endl;
showNodesInfo(network, className);
showCPDS(network);
//cout << "Score: " << network.score(Xd, y) << endl;
cout << "Score: " << network.score(Xd, y) << endl;
cout << "PyTorch version: " << TORCH_VERSION << endl;
return 0;
}

View File

@ -63,21 +63,23 @@ namespace bayesnet {
torch::Tensor newValues = values.sum(0);
return new Factor(newVariables, newCardinalities, newValues);
}
Factor* Factor::product(vector<Factor>& factors)
Factor* Factor::product(vector<Factor*>& factors)
{
vector<string> newVariables;
vector<int> newCardinalities;
for (auto factor : factors) {
for (auto variable : factor.getVariables()) {
vector<string> variables = factor->getVariables();
for (auto idx = 0; idx < variables.size(); ++idx) {
string variable = variables[idx];
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
newVariables.push_back(variable);
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]);
newCardinalities.push_back(factor->getCardinalities()[idx]);
}
}
}
torch::Tensor newValues = factors[0].getValues();
torch::Tensor newValues = factors[0]->getValues();
for (int i = 1; i < factors.size(); i++) {
newValues = newValues.matmul(factors[i].getValues());
newValues = newValues.matmul(factors[i]->getValues());
}
return new Factor(newVariables, newCardinalities, newValues);
}

View File

@ -23,7 +23,7 @@ namespace bayesnet {
vector<int>& getCardinalities();
bool contains(string&);
torch::Tensor& getValues();
static Factor* product(vector<Factor>&);
static Factor* product(vector<Factor*>&);
Factor* sumOut(string&);
};