LCOV - code coverage report
Current view: top level - libtorch/include/ATen - TensorIndexing.h (source / functions) Coverage Total Hit
Test: coverage.info Lines: 96.0 % 25 24
Test Date: 2024-04-30 13:17:26 Functions: 100.0 % 7 7

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include <ATen/ExpandUtils.h>
       4              : #include <ATen/ScalarOps.h>
       5              : #include <ATen/core/Tensor.h>
       6              : #include <ATen/core/TensorBody.h>
       7              : #include <c10/core/SymInt.h>
       8              : #include <c10/util/Optional.h>
       9              : #include <c10/util/irange.h>
      10              : 
      11              : #ifndef AT_PER_OPERATOR_HEADERS
      12              : #include <ATen/Functions.h>
      13              : #include <ATen/NativeFunctions.h>
      14              : #else
      15              : #include <ATen/ops/alias.h>
      16              : #include <ATen/ops/empty.h>
      17              : #include <ATen/ops/scalar_tensor.h>
      18              : #include <ATen/ops/zeros.h>
      19              : #endif
      20              : 
      21              : #include <ATen/core/List.h>
      22              : 
      23              : #include <utility>
      24              : 
      25              : namespace at {
      26              : namespace indexing {
      27              : 
      28              : const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
      29              : const int64_t INDEX_MAX = -(INDEX_MIN + 1);
      30              : 
      31              : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
      32              : 
      33              : constexpr c10::nullopt_t None = c10::nullopt;
      34              : 
      35              : struct TORCH_API EllipsisIndexType final {
      36              :   EllipsisIndexType() = default;
      37              : };
      38              : TORCH_API extern const EllipsisIndexType Ellipsis;
      39              : 
      40              : struct TORCH_API Slice final {
      41              :  public:
      42     36975204 :   Slice(
      43              :       c10::optional<c10::SymInt> start_index = c10::nullopt,
      44              :       c10::optional<c10::SymInt> stop_index = c10::nullopt,
      45     36975204 :       c10::optional<c10::SymInt> step_index = c10::nullopt) {
      46     36975204 :     if (!step_index.has_value()) {
      47     36975204 :       step_ = c10::SymInt(1);
      48              :     } else {
      49            0 :       step_ = std::move(step_index).value();
      50              :     }
      51              : 
      52     36975204 :     TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
      53              : 
      54     36975204 :     if (!start_index.has_value()) {
      55     36974694 :       start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
      56              :     } else {
      57          510 :       start_ = std::move(start_index).value();
      58              :     }
      59              : 
      60     36975204 :     if (!stop_index.has_value()) {
      61     36974694 :       stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
      62              :     } else {
      63          510 :       stop_ = std::move(stop_index).value();
      64              :     }
      65     36975204 :   }
      66              : 
      67              :   inline c10::SymInt start() const {
      68              :     return start_;
      69              :   }
      70              : 
      71              :   inline c10::SymInt stop() const {
      72              :     return stop_;
      73              :   }
      74              : 
      75              :   inline c10::SymInt step() const {
      76              :     return step_;
      77              :   }
      78              : 
      79              :  private:
      80              :   c10::SymInt start_;
      81              :   c10::SymInt stop_;
      82              :   c10::SymInt step_;
      83              : };
      84              : 
      85              : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
      86              : 
      87              : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
      88              : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
      89              : // into its equivalent `std::vector<TensorIndex>`, so that further tensor
      90              : // indexing operations can be performed using the supplied indices.
      91              : //
      92              : // There is one-to-one correspondence between Python and C++ tensor index types:
      93              : // Python                  | C++
      94              : // -----------------------------------------------------
      95              : // `None`                  | `at::indexing::None`
      96              : // `Ellipsis`              | `at::indexing::Ellipsis`
      97              : // `...`                   | `"..."`
      98              : // `123`                   | `123`
      99              : // `True` / `False`        | `true` / `false`
     100              : // `:`                     | `Slice()` / `Slice(None, None)`
     101              : // `::`                    | `Slice()` / `Slice(None, None, None)`
     102              : // `1:`                    | `Slice(1, None)`
     103              : // `1::`                   | `Slice(1, None, None)`
     104              : // `:3`                    | `Slice(None, 3)`
     105              : // `:3:`                   | `Slice(None, 3, None)`
     106              : // `::2`                   | `Slice(None, None, 2)`
     107              : // `1:3`                   | `Slice(1, 3)`
     108              : // `1::2`                  | `Slice(1, None, 2)`
     109              : // `:3:2`                  | `Slice(None, 3, 2)`
     110              : // `1:3:2`                 | `Slice(1, 3, 2)`
     111              : // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
     112              : struct TORCH_API TensorIndex final {
     113              :   // Case 1: `at::indexing::None`
     114              :   TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
     115              : 
     116              :   // Case 2: "..." / `at::indexing::Ellipsis`
     117       776042 :   TensorIndex(at::indexing::EllipsisIndexType)
     118       776042 :       : type_(TensorIndexType::Ellipsis) {}
     119       776042 :   TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
     120       776042 :     TORCH_CHECK_VALUE(
     121              :         strcmp(str, "...") == 0,
     122              :         "Expected \"...\" to represent an ellipsis index, but got \"",
     123              :         str,
     124              :         "\"");
     125       776042 :   }
     126              : 
     127              :   // Case 3: (Sym) Integer value
     128     36171516 :   TensorIndex(SymInt integer)
     129     36171516 :       : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
     130              :   TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
     131     36171516 :   TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
     132              : 
     133              :   // Case 4: Boolean value
     134              :   template <
     135              :       class T,
     136              :       class = typename std::enable_if<std::is_same<bool, T>::value>::type>
     137              :   TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
     138              : 
     139              :   // Case 5: Slice represented in `at::indexing::Slice` form
     140          510 :   TensorIndex(Slice slice)
     141          510 :       : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
     142              : 
     143              :   // Case 6: Tensor value
     144        27136 :   TensorIndex(Tensor tensor)
     145        27136 :       : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
     146              : 
     147              :   inline bool is_none() const {
     148              :     return type_ == TensorIndexType::None;
     149              :   }
     150              : 
     151              :   inline bool is_ellipsis() const {
     152              :     return type_ == TensorIndexType::Ellipsis;
     153              :   }
     154              : 
     155              :   inline bool is_integer() const {
     156              :     return type_ == TensorIndexType::SymInt;
     157              :   }
     158              : 
     159              :   inline SymInt integer() const {
     160              :     return integer_;
     161              :   }
     162              : 
     163              :   inline bool is_boolean() const {
     164              :     return type_ == TensorIndexType::Boolean;
     165              :   }
     166              : 
     167              :   inline bool boolean() const {
     168              :     return boolean_;
     169              :   }
     170              : 
     171              :   inline bool is_slice() const {
     172              :     return type_ == TensorIndexType::Slice;
     173              :   }
     174              : 
     175              :   inline const Slice& slice() const {
     176              :     return slice_;
     177              :   }
     178              : 
     179              :   inline bool is_tensor() const {
     180              :     return type_ == TensorIndexType::Tensor;
     181              :   }
     182              : 
     183              :   inline const Tensor& tensor() const {
     184              :     return tensor_;
     185              :   }
     186              : 
     187              :  private:
     188              :   SymInt integer_ = 0;
     189              :   bool boolean_ = false;
     190              :   Slice slice_;
     191              :   Tensor tensor_;
     192              :   TensorIndexType type_;
     193              : };
     194              : 
     195              : TORCH_API std::ostream& operator<<(
     196              :     std::ostream& stream,
     197              :     const TensorIndex& tensor_index);
     198              : TORCH_API std::ostream& operator<<(
     199              :     std::ostream& stream,
     200              :     const std::vector<TensorIndex>& tensor_indices);
     201              : 
     202              : namespace impl {
     203              : static inline Tensor applySlice(
     204              :     const Tensor& self,
     205              :     int64_t dim,
     206              :     c10::SymInt start,
     207              :     c10::SymInt stop,
     208              :     c10::SymInt step,
     209              :     bool disable_slice_optimization,
     210              :     const at::Device& self_device,
     211              :     const c10::optional<SymIntArrayRef>& self_sizes) {
     212              :   // TODO: implement negative step
     213              :   TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
     214              : 
     215              :   // See NOTE [nested tensor size for indexing]
     216              :   if (self_sizes.has_value()) {
     217              :     // Skip this optimization if we are tracing, as the trace may be polymorphic
     218              :     // over the shape of the `self` tensor, and we still want to record
     219              :     // the slice.
     220              :     SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
     221              :         ? (*self_sizes)[dim]
     222              :         : self.sym_size(dim);
     223              :     if (!disable_slice_optimization && start == 0 && length == stop &&
     224              :         step == 1) {
     225              :       return self;
     226              :     }
     227              :   }
     228              :   return self.slice_symint(dim, start, stop, std::move(step));
     229              : }
     230              : 
     231              : static inline Tensor applySelect(
     232              :     const Tensor& self,
     233              :     int64_t dim,
     234              :     SymInt index,
     235              :     int64_t real_dim,
     236              :     const at::Device& /*self_device*/,
     237              :     const c10::optional<SymIntArrayRef>& self_sizes) {
     238              :   // See NOTE [nested tensor size for indexing]
     239              :   if (self_sizes.has_value()) {
     240              :     auto maybe_index = index.maybe_as_int();
     241              :     if (maybe_index.has_value()) {
     242              :       TORCH_CHECK_INDEX(
     243              :           !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
     244              :           "invalid index of a 0-dim tensor. ",
     245              :           "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
     246              :     }
     247              : 
     248              :     auto size = (*self_sizes)[dim];
     249              :     TORCH_CHECK_INDEX(
     250              :         size >= -index && size > index,
     251              :         "index ",
     252              :         index,
     253              :         " is out of bounds for dimension ",
     254              :         real_dim,
     255              :         " with size ",
     256              :         size);
     257              :   }
     258              : 
     259              :   // if the index is negative, do not normalize it because that would fix the
     260              :   // index on the current tensor size in the tracer. aten::select also works on
     261              :   // negative indices
     262              :   return self.select_symint(dim, index);
     263              : }
     264              : 
     265              : static inline Tensor boolToIndexingTensorCPUOrCUDA(
     266              :     const Tensor& self,
     267              :     bool value) {
     268              :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     269              :   // false as empty.
     270              :   if (value) {
     271              :     return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
     272              :   } else {
     273              :     return at::empty({0}, {}, self.options().dtype(kLong));
     274              :   }
     275              : }
     276              : 
     277              : static inline Tensor boolToIndexingTensorNonNativeDeviceType(
     278              :     const Tensor& self,
     279              :     bool value) {
     280              :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     281              :   // false as empty.
     282              :   if (value) {
     283              :     return at::zeros({1}, {}, self.options().dtype(kLong));
     284              :   } else {
     285              :     return at::empty({0}, {}, self.options().dtype(kLong));
     286              :   }
     287              : }
     288              : 
     289              : static inline Tensor boolToIndexingTensor(
     290              :     const Tensor& self,
     291              :     bool value,
     292              :     const at::Device& self_device) {
     293              :   if (self_device == at::kCPU || self_device == at::kCUDA) {
     294              :     return boolToIndexingTensorCPUOrCUDA(self, value);
     295              :   } else {
     296              :     return boolToIndexingTensorNonNativeDeviceType(self, value);
     297              :   }
     298              : }
     299              : 
     300              : static inline Tensor scalarToTensorNonNativeDeviceType(
     301              :     const Scalar& v,
     302              :     const TensorOptions& options) {
     303              :   return at::scalar_tensor(v, options);
     304              : }
     305              : 
     306              : static inline void recordTensorIndex(
     307              :     const Tensor& tensor,
     308              :     std::vector<Tensor>& outIndices,
     309              :     int64_t* dim_ptr) {
     310              :   // TODO: check scalarType
     311              :   outIndices.resize(*dim_ptr + 1);
     312              :   outIndices[*dim_ptr] = tensor;
     313              :   (*dim_ptr)++;
     314              : };
     315              : 
     316              : static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
     317              :     const Tensor& /*self*/,
     318              :     std::vector<Tensor>&& indices) {
     319              :   c10::List<c10::optional<Tensor>> converted_inds;
     320              :   converted_inds.reserve(indices.size());
     321              :   for (const auto& i : indices) {
     322              :     converted_inds.push_back(std::move(i));
     323              :   }
     324              :   return converted_inds;
     325              : }
     326              : 
     327              : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
     328              : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
     329              : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
     330              : // indexing (i.e. it's called by `applySlicing` which is called by
     331              : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
     332              : // than one dimension). If we were to merge the Python/C++
     333              : // `count_specified_dimensions` function, on the Python side we would have to
     334              : // construct a `std::vector` container to be consumed by the C++
     335              : // `count_specified_dimensions` function, which adds 100s of nanoseconds
     336              : // overhead and is undesirable.
     337              : static inline int64_t count_specified_dimensions(
     338              :     const ArrayRef<TensorIndex>& indices) {
     339              :   // Count the number of indexed dimensions (everything but ellipsis and None)
     340              :   int64_t count = 0;
     341              :   for (auto& obj : indices) {
     342              :     if (obj.is_tensor()) {
     343              :       auto& tensor = obj.tensor();
     344              :       if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
     345              :         count += tensor.dim();
     346              :       } else {
     347              :         count++;
     348              :       }
     349              :     } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
     350              :       count++;
     351              :     }
     352              :   }
     353              :   return count;
     354              : }
     355              : } // namespace impl
     356              : 
     357              : // NOTE: Many functions below are only for consumption from Python indexing
     358              : // implementation, they include:
     359              : //
     360              : // - `Tensor scalarToTensor(...)`
     361              : // - `IntArrayRef slicePrefix1sSize(...)`
     362              : // - `void copy_to(...)`
     363              : // - `Tensor handleDimInMultiDimIndexing(...)`
     364              : // - `Tensor dispatch_index(...)`
     365              : // - `Tensor dispatch_index_put_(...)`
     366              : // - `Tensor get_item(...)`
     367              : // - `void set_item(...)`
     368              : //
     369              : // The rest of the functions are in `at::indexing::impl` namespace, signifying
     370              : // that they shouldn't be used from Python indexing implementation.
     371              : static inline Tensor scalarToTensor(
     372              :     const Scalar& v,
     373              :     const TensorOptions& options,
     374              :     const at::Device& self_device) {
     375              :   if (self_device == at::kCPU) {
     376              :     return at::detail::scalar_tensor_static(
     377              :         v, options.dtype_opt()->toScalarType(), self_device);
     378              :   } else {
     379              :     return impl::scalarToTensorNonNativeDeviceType(v, options);
     380              :   }
     381              : }
     382              : 
     383              : // To match numpy semantics:
     384              : // As a special case for backwards compatibility,
     385              : // strip away unit dimensions from the left of 'src'
     386              : static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
     387              :   size_t first_non1_src = sizes.size();
     388              :   for (const auto i : c10::irange(sizes.size())) {
     389              :     // Unbacked SymInt has different behavior, but this is sound because
     390              :     // failing to slice will only ever cause an error, not divergent
     391              :     // behavior
     392              :     if (!sizes[i].has_hint() || sizes[i] != 1) {
     393              :       first_non1_src = i;
     394              :       break;
     395              :     }
     396              :   }
     397              : 
     398              :   return sizes.slice(first_non1_src);
     399              : }
     400              : 
     401              : static inline void copy_to(const Tensor& dst, const Tensor& src) {
     402              :   if (dst.sym_sizes().equals(src.sym_sizes())) {
     403              :     // A shortcut to avoid generating hard-coded constant sizes during tracing.
     404              :     // This is not a perfect solution: when src & dst have different shapes,
     405              :     // constants will still appear. Users can workaround that case by
     406              :     // dst[index..] = src.reshape(..)
     407              :     dst.copy_(src);
     408              :     return;
     409              :   } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
     410              :     dst.fill_(src);
     411              :     return;
     412              :   }
     413              :   auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
     414              :   c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
     415              :   dst.copy_(*b_src);
     416              : }
     417              : 
     418              : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
     419              : // indexing functions from Python ]
     420              : static inline Tensor handleDimInMultiDimIndexing(
     421              :     const Tensor& prev_dim_result,
     422              :     const Tensor& original_tensor,
     423              :     const TensorIndex& index,
     424              :     int64_t* dim_ptr,
     425              :     int64_t* specified_dims_ptr,
     426              :     int64_t real_dim,
     427              :     std::vector<Tensor>& outIndices,
     428              :     bool disable_slice_optimization,
     429              :     const at::Device& original_tensor_device,
     430              :     const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
     431              :   if (index.is_integer()) {
     432              :     return impl::applySelect(
     433              :         prev_dim_result,
     434              :         *dim_ptr,
     435              :         index.integer(),
     436              :         real_dim,
     437              :         original_tensor_device,
     438              :         prev_dim_result_sizes);
     439              :   } else if (index.is_slice()) {
     440              :     Tensor result = impl::applySlice(
     441              :         prev_dim_result,
     442              :         *dim_ptr,
     443              :         index.slice().start(),
     444              :         index.slice().stop(),
     445              :         index.slice().step(),
     446              :         /*disable_slice_optimization=*/disable_slice_optimization,
     447              :         original_tensor_device,
     448              :         prev_dim_result_sizes);
     449              :     (*dim_ptr)++;
     450              :     return result;
     451              :   } else if (index.is_ellipsis()) {
     452              :     (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
     453              :     return prev_dim_result;
     454              :   } else if (index.is_none()) {
     455              :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     456              :     (*dim_ptr)++;
     457              :     return result;
     458              :   } else if (index.is_boolean()) {
     459              :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     460              :     impl::recordTensorIndex(
     461              :         impl::boolToIndexingTensor(
     462              :             result, index.boolean(), original_tensor_device),
     463              :         outIndices,
     464              :         dim_ptr);
     465              :     return result;
     466              :   } else if (index.is_tensor()) {
     467              :     Tensor result = prev_dim_result;
     468              :     const Tensor& tensor = index.tensor();
     469              :     auto scalar_type = tensor.scalar_type();
     470              :     if (tensor.dim() == 0 &&
     471              :         at::isIntegralType(scalar_type, /*includeBool=*/true)) {
     472              :       if (scalar_type != at::kByte && scalar_type != at::kBool) {
     473              :         result = impl::applySelect(
     474              :             result,
     475              :             *dim_ptr,
     476              :             tensor.item<int64_t>(),
     477              :             real_dim,
     478              :             original_tensor_device,
     479              :             prev_dim_result_sizes);
     480              :       } else {
     481              :         result = result.unsqueeze(*dim_ptr);
     482              :         if (scalar_type == at::kBool) {
     483              :           impl::recordTensorIndex(
     484              :               impl::boolToIndexingTensor(
     485              :                   result, tensor.item<bool>() != 0, original_tensor_device),
     486              :               outIndices,
     487              :               dim_ptr);
     488              :         } else {
     489              :           impl::recordTensorIndex(
     490              :               impl::boolToIndexingTensor(
     491              :                   result, tensor.item<uint8_t>() != 0, original_tensor_device),
     492              :               outIndices,
     493              :               dim_ptr);
     494              :         }
     495              :       }
     496              :     } else {
     497              :       impl::recordTensorIndex(tensor, outIndices, dim_ptr);
     498              :     }
     499              :     return result;
     500              :   } else {
     501              :     TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
     502              :   }
     503              : }
     504              : 
     505              : namespace impl {
     506              : // This mirrors `applySlicing` in
     507              : // torch/csrc/autograd/python_variable_indexing.cpp
     508              : static inline Tensor applySlicing(
     509              :     const Tensor& self,
     510              :     const ArrayRef<TensorIndex>& indices,
     511              :     std::vector<Tensor>& outIndices,
     512              :     bool disable_slice_optimization,
     513              :     const at::Device& self_device,
     514              :     const c10::optional<SymIntArrayRef>& self_sizes) {
     515              :   int64_t dim = 0;
     516              :   int64_t specified_dims = impl::count_specified_dimensions(indices);
     517              : 
     518              :   // See NOTE [nested tensor size for indexing]
     519              :   if (self_sizes.has_value()) {
     520              :     TORCH_CHECK_INDEX(
     521              :         specified_dims <= (int64_t)self_sizes->size(),
     522              :         "too many indices for tensor of dimension ",
     523              :         (int)self_sizes->size());
     524              :   }
     525              : 
     526              :   Tensor result = self;
     527              :   for (const auto i : c10::irange(indices.size())) {
     528              :     auto& obj = indices[i];
     529              :     // See NOTE [nested tensor size for indexing]
     530              :     c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
     531              :         ? c10::optional<SymIntArrayRef>(c10::nullopt)
     532              :         : c10::optional<SymIntArrayRef>(result.sym_sizes());
     533              :     result = handleDimInMultiDimIndexing(
     534              :         /*prev_dim_result=*/result,
     535              :         /*original_tensor=*/self,
     536              :         /*index=*/obj,
     537              :         /*dim=*/&dim,
     538              :         /*specified_dims=*/&specified_dims,
     539              :         /*real_dim=*/i,
     540              :         /*outIndices=*/outIndices,
     541              :         /*disable_slice_optimization=*/disable_slice_optimization,
     542              :         /*original_tensor_device=*/self_device,
     543              :         /*prev_dim_result_sizes=*/result_sizes);
     544              :   }
     545              :   return result;
     546              : }
     547              : } // namespace impl
     548              : 
     549              : static inline Tensor dispatch_index(
     550              :     const Tensor& self,
     551              :     std::vector<Tensor>&& indices) {
     552              :   return self.index(impl::typeConvertIndices(self, std::move(indices)));
     553              : }
     554              : 
     555              : static inline Tensor dispatch_index_put_(
     556              :     Tensor& self,
     557              :     std::vector<Tensor>&& indices,
     558              :     const Tensor& value) {
     559              :   return self.index_put_(
     560              :       impl::typeConvertIndices(self, std::move(indices)), value);
     561              : }
     562              : 
     563              : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
     564              : // functions from Python ]
     565              : //
     566              : // Question: When should we set `disable_slice_optimization` to `true` when
     567              : // calling C++ tensor indexing functions from Python indexing code?
     568              : //
     569              : // Answer: What "slice optimization" means: when we have a slicing expression
     570              : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
     571              : // would skip dispatching the actual slice call as an optimization. However,
     572              : // here are the cases where we DON'T want this optimization:
     573              : //
     574              : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
     575              : //    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
//    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
     578              : //    ```
//    Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
//        disable_slice_optimization, self_device, self_sizes);
//    if (tensorIndices.empty()) {
     582              : //      if (sliced.is_same(self)) {
     583              : //        // ensure we return a shallow copy for things like x[...]
     584              : //        sliced = at::alias(sliced);
     585              : //      }
     586              : //      return sliced;
     587              : //    }
     588              : //    ```)
     589              : // 2. When we are doing JIT tracing.
     590              : //    Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
     591              : //    slice operation.
     592              : 
     593              : // This mirrors `THPVariable_getitem` in
     594              : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
     595              : // `disable_slice_optimization` when calling C++ tensor indexing functions from
     596              : // Python ]
     597              : static inline Tensor get_item(
     598              :     const Tensor& self,
     599              :     const ArrayRef<TensorIndex>& indices,
     600              :     bool disable_slice_optimization = false) {
     601              :   at::Device self_device = self.device();
     602              :   // NOTE [nested tensor size for indexing]
     603              :   // nested tensor does not have a size (yet) so for now we represent its size
     604              :   // as null may need to be changed after we reach a better solution for nested
     605              :   // tensor size
     606              :   c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
     607              :       ? c10::optional<SymIntArrayRef>(c10::nullopt)
     608              :       : c10::optional<SymIntArrayRef>(self.sym_sizes());
     609              : 
     610              :   // handle simple types: integers, slices, none, ellipsis, bool
     611              :   if (indices.size() == 1) {
     612              :     const TensorIndex& index = indices[0];
     613              :     if (index.is_integer()) {
     614              :       return impl::applySelect(
     615              :           self, 0, index.integer(), 0, self_device, self_sizes);
     616              :     } else if (index.is_slice()) {
     617              :       return impl::applySlice(
     618              :           self,
     619              :           0,
     620              :           index.slice().start(),
     621              :           index.slice().stop(),
     622              :           index.slice().step(),
     623              :           /*disable_slice_optimization=*/true,
     624              :           self_device,
     625              :           self_sizes);
     626              :     } else if (index.is_none()) {
     627              :       return self.unsqueeze(0);
     628              :     } else if (index.is_ellipsis()) {
     629              :       return at::alias(self);
     630              :     } else if (index.is_boolean()) {
     631              :       Tensor result = self.unsqueeze(0);
     632              :       return dispatch_index(
     633              :           result,
     634              :           std::vector<Tensor>{impl::boolToIndexingTensor(
     635              :               result, index.boolean(), self_device)});
     636              :     }
     637              :   }
     638              : 
     639              :   std::vector<Tensor> tensorIndices;
     640              :   Tensor sliced = impl::applySlicing(
     641              :       self,
     642              :       indices,
     643              :       tensorIndices,
     644              :       disable_slice_optimization,
     645              :       self_device,
     646              :       self_sizes);
     647              :   if (tensorIndices.empty()) {
     648              :     if (sliced.is_same(self)) {
     649              :       // ensure we return a shallow copy for things like x[...]
     650              :       sliced = at::alias(sliced);
     651              :     }
     652              :     return sliced;
     653              :   }
     654              : 
     655              :   // indexing by tensors ("advanced" indexing)
     656              :   return dispatch_index(sliced, std::move(tensorIndices));
     657              : }
     658              : 
     659              : // This mirrors `THPVariable_setitem` in
     660              : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
     661              : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
     662              : // tensor indexing functions from Python ]
     663              : static inline void set_item(
     664              :     const Tensor& self,
     665              :     const ArrayRef<TensorIndex>& indices,
     666              :     const Tensor& value,
     667              :     bool disable_slice_optimization = false) {
     668              :   at::Device self_device = self.device();
     669              :   SymIntArrayRef self_sizes = self.sym_sizes();
     670              : 
     671              :   // handle simple types: integers, slices, ellipsis, bool
     672              :   if (indices.size() == 1) {
     673              :     const TensorIndex& index = indices[0];
     674              :     if (index.is_boolean() && !index.boolean()) {
     675              :       // do nothing for false (technically we should check the size, but we
     676              :       // don't have real 0-sized shapes.
     677              :       return;
     678              :     } else if (index.is_ellipsis()) {
     679              :       copy_to(self, value);
     680              :       return;
     681              :     } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
     682              :       copy_to(self.unsqueeze(0), value);
     683              :       return;
     684              :     } else if (index.is_integer()) {
     685              :       copy_to(
     686              :           impl::applySelect(
     687              :               self, 0, index.integer(), 0, self_device, self_sizes),
     688              :           value);
     689              :       return;
     690              :     } else if (index.is_slice()) {
     691              :       copy_to(
     692              :           impl::applySlice(
     693              :               self,
     694              :               0,
     695              :               index.slice().start(),
     696              :               index.slice().stop(),
     697              :               index.slice().step(),
     698              :               /*disable_slice_optimization=*/disable_slice_optimization,
     699              :               self_device,
     700              :               self_sizes),
     701              :           value);
     702              :       return;
     703              :     }
     704              :   }
     705              : 
     706              :   std::vector<Tensor> tensorIndices;
     707              :   Tensor sliced = impl::applySlicing(
     708              :       self,
     709              :       indices,
     710              :       tensorIndices,
     711              :       disable_slice_optimization,
     712              :       self_device,
     713              :       self_sizes);
     714              :   if (tensorIndices.empty()) {
     715              :     copy_to(sliced, value);
     716              :     return;
     717              :   }
     718              : 
     719              :   SymIntArrayRef valueSizes = value.sym_sizes();
     720              :   SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
     721              :   Tensor valuesSliced;
     722              :   if (!valueSizes.equals(slicedValueSizes)) {
     723              :     valuesSliced = value.view_symint(slicedValueSizes);
     724              :   } else {
     725              :     valuesSliced = value;
     726              :   }
     727              :   dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
     728              :   return;
     729              : }
     730              : 
     731              : } // namespace indexing
     732              : } // namespace at
        

Generated by: LCOV version 2.0-1