diff --git a/src/common/TensorUtils.hpp b/src/common/TensorUtils.hpp index eb69371..93b258e 100644 --- a/src/common/TensorUtils.hpp +++ b/src/common/TensorUtils.hpp @@ -62,7 +62,7 @@ namespace platform { torch::Tensor tensor = torch::empty({ static_cast(rows), static_cast(cols) }, torch::kInt64); for (size_t i = 0; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { - tensor.index_put_({ static_cast(i), static_cast(j) }, data[i][j]); + tensor.index_put_({static_cast(i), static_cast(j)}, torch::scalar_tensor(data[i][j])); } } return tensor;