From 5db4d1189ab786347ad4c8832fc7a2e074c1a15a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Mon, 3 Jul 2023 00:31:15 +0200 Subject: [PATCH] Fix compile errors --- sample/main.cc | 2 +- src/Factor.cc | 12 +++++++----- src/Factor.h | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sample/main.cc b/sample/main.cc index 4aec86f..06a22bb 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -218,7 +218,7 @@ int main(int argc, char** argv) cout << "Hello, Bayesian Networks!" << endl; showNodesInfo(network, className); showCPDS(network); - //cout << "Score: " << network.score(Xd, y) << endl; + cout << "Score: " << network.score(Xd, y) << endl; cout << "PyTorch version: " << TORCH_VERSION << endl; return 0; } \ No newline at end of file diff --git a/src/Factor.cc b/src/Factor.cc index 4af43d9..247d9ec 100644 --- a/src/Factor.cc +++ b/src/Factor.cc @@ -63,21 +63,23 @@ namespace bayesnet { torch::Tensor newValues = values.sum(0); return new Factor(newVariables, newCardinalities, newValues); } - Factor* Factor::product(vector& factors) + Factor* Factor::product(vector& factors) { vector newVariables; vector newCardinalities; for (auto factor : factors) { - for (auto variable : factor.getVariables()) { + vector variables = factor->getVariables(); + for (auto idx = 0; idx < variables.size(); ++idx) { + string variable = variables[idx]; if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) { newVariables.push_back(variable); - newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]); + newCardinalities.push_back(factor->getCardinalities()[idx]); } } } - torch::Tensor newValues = factors[0].getValues(); + torch::Tensor newValues = factors[0]->getValues(); for (int i = 1; i < factors.size(); i++) { - newValues = newValues.matmul(factors[i].getValues()); + newValues = newValues.matmul(factors[i]->getValues()); } return new Factor(newVariables, newCardinalities, newValues); } diff --git a/src/Factor.h b/src/Factor.h index c4fee52..f98dc48 100644 --- a/src/Factor.h +++ b/src/Factor.h @@ -23,7 +23,7 @@ namespace bayesnet { vector& getCardinalities(); bool contains(string&); torch::Tensor& getValues(); - static Factor* product(vector&); + static Factor* product(vector&); Factor* sumOut(string&); };