Line data Source code
1 : #pragma once
2 : #include <ATen/core/Tensor.h>
3 : #include <c10/core/ScalarType.h>
4 :
5 : namespace at {
6 :
7 : // These functions are defined in ATen/Utils.cpp.
8 : #define TENSOR(T, S) \
9 : TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
10 : inline Tensor tensor( \
11 : std::initializer_list<T> values, const TensorOptions& options) { \
12 : return at::tensor(ArrayRef<T>(values), options); \
13 : } \
14 : inline Tensor tensor(T value, const TensorOptions& options) { \
15 : return at::tensor(ArrayRef<T>(value), options); \
16 : } \
17 : inline Tensor tensor(ArrayRef<T> values) { \
18 : return at::tensor(std::move(values), at::dtype(k##S)); \
19 : } \
20 : inline Tensor tensor(std::initializer_list<T> values) { \
21 : return at::tensor(ArrayRef<T>(values)); \
22 : } \
23 : inline Tensor tensor(T value) { \
24 : return at::tensor(ArrayRef<T>(value)); \
25 : }
26 181257936 : AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
27 : AT_FORALL_COMPLEX_TYPES(TENSOR)
28 : #undef TENSOR
29 :
30 : } // namespace at
|