Fix compile errors
This commit is contained in:
parent
ec76057e64
commit
5db4d1189a
@ -218,7 +218,7 @@ int main(int argc, char** argv)
|
||||
cout << "Hello, Bayesian Networks!" << endl;
|
||||
showNodesInfo(network, className);
|
||||
showCPDS(network);
|
||||
//cout << "Score: " << network.score(Xd, y) << endl;
|
||||
cout << "Score: " << network.score(Xd, y) << endl;
|
||||
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
||||
return 0;
|
||||
}
|
@ -63,21 +63,23 @@ namespace bayesnet {
|
||||
torch::Tensor newValues = values.sum(0);
|
||||
return new Factor(newVariables, newCardinalities, newValues);
|
||||
}
|
||||
Factor* Factor::product(vector<Factor>& factors)
|
||||
Factor* Factor::product(vector<Factor*>& factors)
|
||||
{
|
||||
vector<string> newVariables;
|
||||
vector<int> newCardinalities;
|
||||
for (auto factor : factors) {
|
||||
for (auto variable : factor.getVariables()) {
|
||||
vector<string> variables = factor->getVariables();
|
||||
for (auto idx = 0; idx < variables.size(); ++idx) {
|
||||
string variable = variables[idx];
|
||||
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
|
||||
newVariables.push_back(variable);
|
||||
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]);
|
||||
newCardinalities.push_back(factor->getCardinalities()[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
torch::Tensor newValues = factors[0].getValues();
|
||||
torch::Tensor newValues = factors[0]->getValues();
|
||||
for (int i = 1; i < factors.size(); i++) {
|
||||
newValues = newValues.matmul(factors[i].getValues());
|
||||
newValues = newValues.matmul(factors[i]->getValues());
|
||||
}
|
||||
return new Factor(newVariables, newCardinalities, newValues);
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ namespace bayesnet {
|
||||
vector<int>& getCardinalities();
|
||||
bool contains(string&);
|
||||
torch::Tensor& getValues();
|
||||
static Factor* product(vector<Factor>&);
|
||||
static Factor* product(vector<Factor*>&);
|
||||
Factor* sumOut(string&);
|
||||
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user