Fix compile errors

This commit is contained in:
Ricardo Montañana Gómez 2023-07-03 00:31:15 +02:00
parent ec76057e64
commit 5db4d1189a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 9 additions and 7 deletions

View File

@ -218,7 +218,7 @@ int main(int argc, char** argv)
cout << "Hello, Bayesian Networks!" << endl; cout << "Hello, Bayesian Networks!" << endl;
showNodesInfo(network, className); showNodesInfo(network, className);
showCPDS(network); showCPDS(network);
//cout << "Score: " << network.score(Xd, y) << endl; cout << "Score: " << network.score(Xd, y) << endl;
cout << "PyTorch version: " << TORCH_VERSION << endl; cout << "PyTorch version: " << TORCH_VERSION << endl;
return 0; return 0;
} }

View File

@ -63,21 +63,23 @@ namespace bayesnet {
torch::Tensor newValues = values.sum(0); torch::Tensor newValues = values.sum(0);
return new Factor(newVariables, newCardinalities, newValues); return new Factor(newVariables, newCardinalities, newValues);
} }
Factor* Factor::product(vector<Factor>& factors) Factor* Factor::product(vector<Factor*>& factors)
{ {
vector<string> newVariables; vector<string> newVariables;
vector<int> newCardinalities; vector<int> newCardinalities;
for (auto factor : factors) { for (auto factor : factors) {
for (auto variable : factor.getVariables()) { vector<string> variables = factor->getVariables();
for (auto idx = 0; idx < variables.size(); ++idx) {
string variable = variables[idx];
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) { if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
newVariables.push_back(variable); newVariables.push_back(variable);
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]); newCardinalities.push_back(factor->getCardinalities()[idx]);
} }
} }
} }
torch::Tensor newValues = factors[0].getValues(); torch::Tensor newValues = factors[0]->getValues();
for (int i = 1; i < factors.size(); i++) { for (int i = 1; i < factors.size(); i++) {
newValues = newValues.matmul(factors[i].getValues()); newValues = newValues.matmul(factors[i]->getValues());
} }
return new Factor(newVariables, newCardinalities, newValues); return new Factor(newVariables, newCardinalities, newValues);
} }

View File

@ -23,7 +23,7 @@ namespace bayesnet {
vector<int>& getCardinalities(); vector<int>& getCardinalities();
bool contains(string&); bool contains(string&);
torch::Tensor& getValues(); torch::Tensor& getValues();
static Factor* product(vector<Factor>&); static Factor* product(vector<Factor*>&);
Factor* sumOut(string&); Factor* sumOut(string&);
}; };