Fix xgboost error in predict/predict_proba

This commit is contained in:
2025-04-12 17:48:23 +02:00
parent 761f57be6c
commit 830265d91b
4 changed files with 55 additions and 30 deletions

View File

@@ -93,11 +93,19 @@ namespace pywrap {
PyErr_Print();
throw std::runtime_error("Error creating object for predict in " + module + " and class " + className);
}
int* data = reinterpret_cast<int*>(prediction.get_data());
std::vector<int> vPrediction(data, data + prediction.shape(0));
auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
Py_XDECREF(incoming);
return resultTensor;
if (xgboost) {
long* data = reinterpret_cast<long*>(prediction.get_data());
std::vector<int> vPrediction(data, data + prediction.shape(0));
auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
Py_XDECREF(incoming);
return resultTensor;
} else {
int* data = reinterpret_cast<int*>(prediction.get_data());
std::vector<int> vPrediction(data, data + prediction.shape(0));
auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
Py_XDECREF(incoming);
return resultTensor;
}
}
torch::Tensor PyClassifier::predict_proba(torch::Tensor& X)
{
@@ -118,11 +126,19 @@ namespace pywrap {
PyErr_Print();
throw std::runtime_error("Error creating object for predict_proba in " + module + " and class " + className);
}
double* data = reinterpret_cast<double*>(prediction.get_data());
std::vector<double> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
Py_XDECREF(incoming);
return resultTensor;
if (xgboost) {
float* data = reinterpret_cast<float*>(prediction.get_data());
std::vector<float> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
Py_XDECREF(incoming);
return resultTensor;
} else {
double* data = reinterpret_cast<double*>(prediction.get_data());
std::vector<double> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
Py_XDECREF(incoming);
return resultTensor;
}
}
float PyClassifier::score(torch::Tensor& X, torch::Tensor& y)
{
@@ -135,4 +151,4 @@ namespace pywrap {
{
this->hyperparameters = hyperparameters;
}
} /* namespace pywrap */
} /* namespace pywrap */