Complete SPODE & AODE

This commit is contained in:
2023-07-15 01:59:30 +02:00
parent db6908acd0
commit e311c27d43
7 changed files with 94 additions and 31 deletions

View File

@@ -69,6 +69,20 @@ namespace bayesnet {
auto ypred = torch::tensor(yp, torch::kInt64);
return ypred;
}
vector<int> BaseClassifier::predict(vector<vector<int>>& X)
{
if (!fitted) {
throw logic_error("Classifier has not been fitted");
}
auto m_ = X[0].size();
auto n_ = X.size();
vector<vector<int>> Xd(n_, vector<int>(m_, 0));
for (auto i = 0; i < n_; i++) {
Xd[i] = vector<int>(X[i].begin(), X[i].end());
}
auto yp = model.predict(Xd);
return yp;
}
float BaseClassifier::score(Tensor& X, Tensor& y)
{
if (!fitted) {