Line data Source code
1 : #pragma once
2 :
3 : #include <ATen/core/Tensor.h>
4 : #include <c10/core/Scalar.h>
5 :
6 : #ifndef AT_PER_OPERATOR_HEADERS
7 : #include <ATen/Functions.h>
8 : #else
9 : #include <ATen/ops/empty_like.h>
10 : #endif
11 :
12 : #include <stdexcept>
13 : #include <string>
14 :
15 : namespace at {
16 :
17 : #define AT_FORALL_BINARY_OPS(_) \
18 : _(+, x.add(y), y.add(x)) \
19 : _(*, x.mul(y), y.mul(x)) \
20 : _(-, \
21 : x.sub(y), \
22 : ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \
23 : _(/, \
24 : x.div(y), \
25 : ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \
26 : _(%, \
27 : x.remainder(y), \
28 : ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
29 : _(&, x.bitwise_and(y), y.bitwise_and(x)) \
30 : _(|, x.bitwise_or(y), y.bitwise_or(x)) \
31 : _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \
32 : _(<, x.lt(y), y.gt(x)) \
33 : _(<=, x.le(y), y.ge(x)) \
34 : _(>, x.gt(y), y.lt(x)) \
35 : _(>=, x.ge(y), y.le(x)) \
36 : _(==, x.eq(y), y.eq(x)) \
37 : _(!=, x.ne(y), y.ne(x))
38 :
39 : #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
40 : static inline Tensor operator op(const Tensor& x, const Tensor& y) { \
41 : return body; \
42 : } \
43 : static inline Tensor operator op(const Tensor& x, const Scalar& y) { \
44 : return body; \
45 : } \
46 : static inline Tensor operator op(const Scalar& x, const Tensor& y) { \
47 : return reverse_scalar_body; \
48 : }
49 :
50 5196418 : AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
51 : #undef DEFINE_OPERATOR
52 : #undef AT_FORALL_BINARY_OPS
53 :
54 : } // namespace at
|