Fix XSpode

This commit is contained in:
2025-03-10 22:18:50 +01:00
parent 619276a5ea
commit 3d8be79b37
2 changed files with 8 additions and 18 deletions

View File

@@ -390,11 +390,7 @@ namespace bayesnet {
{
auto X_ = TensorUtils::to_matrix(X);
auto result_v = predict(X_);
torch::Tensor result;
for (int i = 0; i < result_v.size(); ++i) {
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32));
}
return result;
return torch::tensor(result_v, torch::kInt32);
}
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
{