Fix input data dimensions

This commit is contained in:
2025-07-04 11:10:25 +02:00
parent 37a765e6b0
commit 57c6693842
3 changed files with 41 additions and 32 deletions

View File

@@ -21,25 +21,27 @@ namespace pywrap {
}
// Ensure tensor is contiguous and in the expected format
X = X.contiguous();
auto X_copy = X.contiguous();
if (X.dtype() != torch::kFloat32) {
if (X_copy.dtype() != torch::kFloat32) {
throw std::runtime_error("tensor2numpy: Expected float32 tensor");
}
int64_t m = X.size(0);
int64_t n = X.size(1);
// Transpose from [features, samples] to [samples, features] for Python classifiers
X_copy = X_copy.transpose(0, 1);
int64_t m = X_copy.size(0);
int64_t n = X_copy.size(1);
// Calculate correct strides in bytes
int64_t element_size = X.element_size();
int64_t stride0 = X.stride(0) * element_size;
int64_t stride1 = X.stride(1) * element_size;
int64_t element_size = X_copy.element_size();
int64_t stride0 = X_copy.stride(0) * element_size;
int64_t stride1 = X_copy.stride(1) * element_size;
auto Xn = np::from_data(X.data_ptr(), np::dtype::get_builtin<float>(),
auto Xn = np::from_data(X_copy.data_ptr(), np::dtype::get_builtin<float>(),
bp::make_tuple(m, n),
bp::make_tuple(stride0, stride1),
bp::object());
// Don't transpose - tensor is already in correct [samples, features] format
return Xn;
}
np::ndarray tensorInt2numpy(torch::Tensor& X)
@@ -50,25 +52,27 @@ namespace pywrap {
}
// Ensure tensor is contiguous and in the expected format
X = X.contiguous();
auto X_copy = X.contiguous();
if (X.dtype() != torch::kInt32) {
if (X_copy.dtype() != torch::kInt32) {
throw std::runtime_error("tensorInt2numpy: Expected int32 tensor");
}
int64_t m = X.size(0);
int64_t n = X.size(1);
// Transpose from [features, samples] to [samples, features] for Python classifiers
X_copy = X_copy.transpose(0, 1);
int64_t m = X_copy.size(0);
int64_t n = X_copy.size(1);
// Calculate correct strides in bytes
int64_t element_size = X.element_size();
int64_t stride0 = X.stride(0) * element_size;
int64_t stride1 = X.stride(1) * element_size;
int64_t element_size = X_copy.element_size();
int64_t stride0 = X_copy.stride(0) * element_size;
int64_t stride1 = X_copy.stride(1) * element_size;
auto Xn = np::from_data(X.data_ptr(), np::dtype::get_builtin<int>(),
auto Xn = np::from_data(X_copy.data_ptr(), np::dtype::get_builtin<int>(),
bp::make_tuple(m, n),
bp::make_tuple(stride0, stride1),
bp::object());
// Don't transpose - tensor is already in correct [samples, features] format
return Xn;
}
std::pair<np::ndarray, np::ndarray> tensors2numpy(torch::Tensor& X, torch::Tensor& y)
@@ -78,10 +82,11 @@ namespace pywrap {
throw std::runtime_error("tensors2numpy: Expected 1D y tensor, got " + std::to_string(y.dim()) + "D");
}
// Validate dimensions match
if (X.size(0) != y.size(0)) {
// Validate dimensions match (X is [features, samples], y is [samples])
// X.size(1) is samples, y.size(0) is samples
if (X.size(1) != y.size(0)) {
throw std::runtime_error("tensors2numpy: X and y dimension mismatch: X[" +
std::to_string(X.size(0)) + "], y[" + std::to_string(y.size(0)) + "]");
std::to_string(X.size(1)) + "], y[" + std::to_string(y.size(0)) + "]");
}
// Ensure y tensor is contiguous