Fix XSPode

This commit is contained in:
2025-03-10 15:55:48 +01:00
parent 86cccb6c7b
commit a26522e62f
4 changed files with 23 additions and 3 deletions

View File

@@ -407,5 +407,24 @@ namespace bayesnet {
}
return result;
}
float XSpode::score(torch::Tensor& X, torch::Tensor& y)
{
torch::Tensor y_pred = predict(X);
return (y_pred == y).sum().item<float>() / y.size(0);
}
float XSpode::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
auto y_pred = this->predict(X);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {
if (y_pred[i] == y[i]) {
correct++;
}
}
return (double)correct / y_pred.size();
}
}