Fix compile errors
This commit is contained in:
parent
ec76057e64
commit
5db4d1189a
@ -218,7 +218,7 @@ int main(int argc, char** argv)
|
|||||||
cout << "Hello, Bayesian Networks!" << endl;
|
cout << "Hello, Bayesian Networks!" << endl;
|
||||||
showNodesInfo(network, className);
|
showNodesInfo(network, className);
|
||||||
showCPDS(network);
|
showCPDS(network);
|
||||||
//cout << "Score: " << network.score(Xd, y) << endl;
|
cout << "Score: " << network.score(Xd, y) << endl;
|
||||||
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
@ -63,21 +63,23 @@ namespace bayesnet {
|
|||||||
torch::Tensor newValues = values.sum(0);
|
torch::Tensor newValues = values.sum(0);
|
||||||
return new Factor(newVariables, newCardinalities, newValues);
|
return new Factor(newVariables, newCardinalities, newValues);
|
||||||
}
|
}
|
||||||
Factor* Factor::product(vector<Factor>& factors)
|
Factor* Factor::product(vector<Factor*>& factors)
|
||||||
{
|
{
|
||||||
vector<string> newVariables;
|
vector<string> newVariables;
|
||||||
vector<int> newCardinalities;
|
vector<int> newCardinalities;
|
||||||
for (auto factor : factors) {
|
for (auto factor : factors) {
|
||||||
for (auto variable : factor.getVariables()) {
|
vector<string> variables = factor->getVariables();
|
||||||
|
for (auto idx = 0; idx < variables.size(); ++idx) {
|
||||||
|
string variable = variables[idx];
|
||||||
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
|
if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) {
|
||||||
newVariables.push_back(variable);
|
newVariables.push_back(variable);
|
||||||
newCardinalities.push_back(factor.getCardinalities()[factor.getVariables().index(variable)]);
|
newCardinalities.push_back(factor->getCardinalities()[idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
torch::Tensor newValues = factors[0].getValues();
|
torch::Tensor newValues = factors[0]->getValues();
|
||||||
for (int i = 1; i < factors.size(); i++) {
|
for (int i = 1; i < factors.size(); i++) {
|
||||||
newValues = newValues.matmul(factors[i].getValues());
|
newValues = newValues.matmul(factors[i]->getValues());
|
||||||
}
|
}
|
||||||
return new Factor(newVariables, newCardinalities, newValues);
|
return new Factor(newVariables, newCardinalities, newValues);
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ namespace bayesnet {
|
|||||||
vector<int>& getCardinalities();
|
vector<int>& getCardinalities();
|
||||||
bool contains(string&);
|
bool contains(string&);
|
||||||
torch::Tensor& getValues();
|
torch::Tensor& getValues();
|
||||||
static Factor* product(vector<Factor>&);
|
static Factor* product(vector<Factor*>&);
|
||||||
Factor* sumOut(string&);
|
Factor* sumOut(string&);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user