#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymInt.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alias.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/core/List.h>

#include <utility>

namespace at {
namespace indexing {

const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
const int64_t INDEX_MAX = -(INDEX_MIN + 1);

enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };

constexpr c10::nullopt_t None = c10::nullopt;

struct TORCH_API EllipsisIndexType final {
  EllipsisIndexType() = default;
};
TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
 public:
  Slice(
      c10::optional<c10::SymInt> start_index = c10::nullopt,
      c10::optional<c10::SymInt> stop_index = c10::nullopt,
      c10::optional<c10::SymInt> step_index = c10::nullopt) {
    if (!step_index.has_value()) {
      step_ = c10::SymInt(1);
    } else {
      step_ = std::move(step_index).value();
    }

    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

    if (!start_index.has_value()) {
      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
    } else {
      start_ = std::move(start_index).value();
    }

    if (!stop_index.has_value()) {
      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
    } else {
      stop_ = std::move(stop_index).value();
    }
  }

  inline c10::SymInt start() const {
    return start_;
  }

  inline c10::SymInt stop() const {
    return stop_;
  }

  inline c10::SymInt step() const {
    return step_;
  }

 private:
  c10::SymInt start_;
  c10::SymInt stop_;
  c10::SymInt step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
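
// Illustrative sketch, not part of the upstream header: how the defaulted
// `Slice` arguments above map onto Python slice syntax. The function name
// `slice_defaults_example` and the local variable names are hypothetical and
// exist purely for exposition.
inline void slice_defaults_example() {
  Slice whole;                       // `:`   -> start 0, stop INDEX_MAX, step 1
  Slice from_one(1);                 // `1:`  -> start 1, open stop
  Slice up_to_three(None, 3);        // `:3`
  Slice every_other(None, None, 2);  // `::2`
  Slice reversed(None, None, -1);    // `::-1` -> start INDEX_MAX, stop INDEX_MIN
  // Slice(None, None, 0) would throw: "slice step cannot be zero".
  (void)whole; (void)from_one; (void)up_to_three; (void)every_other; (void)reversed;
}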

// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is a one-to-one correspondence between Python and C++ tensor index
// types:
// Python                  | C++
// -----------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
struct TORCH_API TensorIndex final {
  // Case 1: `at::indexing::None`
  TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}

  // Case 2: "..." / `at::indexing::Ellipsis`
  TensorIndex(at::indexing::EllipsisIndexType)
      : type_(TensorIndexType::Ellipsis) {}
  TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
    TORCH_CHECK_VALUE(
        strcmp(str, "...") == 0,
        "Expected \"...\" to represent an ellipsis index, but got \"",
        str,
        "\"");
  }

  // Case 3: (Sym) Integer value
  TensorIndex(SymInt integer)
      : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
  TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
  TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}

  // Case 4: Boolean value
  template <
      class T,
      class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

  // Case 5: Slice represented in `at::indexing::Slice` form
  TensorIndex(Slice slice)
      : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

  // Case 6: Tensor value
  TensorIndex(Tensor tensor)
      : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}

  inline bool is_none() const {
    return type_ == TensorIndexType::None;
  }

  inline bool is_ellipsis() const {
    return type_ == TensorIndexType::Ellipsis;
  }

  inline bool is_integer() const {
    return type_ == TensorIndexType::SymInt;
  }

  inline SymInt integer() const {
    return integer_;
  }

  inline bool is_boolean() const {
    return type_ == TensorIndexType::Boolean;
  }

  inline bool boolean() const {
    return boolean_;
  }

  inline bool is_slice() const {
    return type_ == TensorIndexType::Slice;
  }

  inline const Slice& slice() const {
    return slice_;
  }

  inline bool is_tensor() const {
    return type_ == TensorIndexType::Tensor;
  }

  inline const Tensor& tensor() const {
    return tensor_;
  }

 private:
  SymInt integer_ = 0;
  bool boolean_ = false;
  Slice slice_;
  Tensor tensor_;
  TensorIndexType type_;
};

TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const std::vector<TensorIndex>& tensor_indices);
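
// Illustrative sketch, not part of the upstream header: the braced-list form
// described in the comment above relies on the implicit `TensorIndex`
// constructors. The function name `tensor_index_list_example` is hypothetical;
// the resulting vector can then be fed to the indexing functions below.
inline std::vector<TensorIndex> tensor_index_list_example(const Tensor& t) {
  // Python: `[None, ..., 0, True, 1::2, t]`
  return {None, "...", 0, true, Slice(1, None, 2), t};
}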

namespace impl {
static inline Tensor applySlice(
    const Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt stop,
    c10::SymInt step,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    // Skip this optimization if we are tracing, as the trace may be polymorphic
    // over the shape of the `self` tensor, and we still want to record
    // the slice.
    SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
        ? (*self_sizes)[dim]
        : self.sym_size(dim);
    if (!disable_slice_optimization && start == 0 && length == stop &&
        step == 1) {
      return self;
    }
  }
  return self.slice_symint(dim, start, stop, std::move(step));
}

static inline Tensor applySelect(
    const Tensor& self,
    int64_t dim,
    SymInt index,
    int64_t real_dim,
    const at::Device& /*self_device*/,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    auto maybe_index = index.maybe_as_int();
    if (maybe_index.has_value()) {
      TORCH_CHECK_INDEX(
          !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
          "invalid index of a 0-dim tensor. ",
          "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
    }

    auto size = (*self_sizes)[dim];
    TORCH_CHECK_INDEX(
        size >= -index && size > index,
        "index ",
        index,
        " is out of bounds for dimension ",
        real_dim,
        " with size ",
        size);
  }

  // if the index is negative, do not normalize it because that would fix the
  // index on the current tensor size in the tracer. aten::select also works on
  // negative indices
  return self.select_symint(dim, index);
}

static inline Tensor boolToIndexingTensorCPUOrCUDA(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

static inline Tensor boolToIndexingTensorNonNativeDeviceType(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::zeros({1}, {}, self.options().dtype(kLong));
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

static inline Tensor boolToIndexingTensor(
    const Tensor& self,
    bool value,
    const at::Device& self_device) {
  if (self_device == at::kCPU || self_device == at::kCUDA) {
    return boolToIndexingTensorCPUOrCUDA(self, value);
  } else {
    return boolToIndexingTensorNonNativeDeviceType(self, value);
  }
}

static inline Tensor scalarToTensorNonNativeDeviceType(
    const Scalar& v,
    const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}

static inline void recordTensorIndex(
    const Tensor& tensor,
    std::vector<Tensor>& outIndices,
    int64_t* dim_ptr) {
  // TODO: check scalarType
  outIndices.resize(*dim_ptr + 1);
  outIndices[*dim_ptr] = tensor;
  (*dim_ptr)++;
}

static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
    const Tensor& /*self*/,
    std::vector<Tensor>&& indices) {
  c10::List<c10::optional<Tensor>> converted_inds;
  converted_inds.reserve(indices.size());
  // Take the elements by non-const reference so the `std::move` below actually
  // moves out of the consumed `indices` vector instead of copying.
  for (auto& i : indices) {
    converted_inds.push_back(std::move(i));
  }
  return converted_inds;
}

// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
// indexing (i.e. it's called by `applySlicing`, which is called by
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
// than one dimension). If we were to merge the Python/C++
// `count_specified_dimensions` functions, on the Python side we would have to
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which would add hundreds of
// nanoseconds of overhead and is undesirable.
static inline int64_t count_specified_dimensions(
    const ArrayRef<TensorIndex>& indices) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  int64_t count = 0;
  for (auto& obj : indices) {
    if (obj.is_tensor()) {
      auto& tensor = obj.tensor();
      if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
        count += tensor.dim();
      } else {
        count++;
      }
    } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
      count++;
    }
  }
  return count;
}
} // namespace impl

// NOTE: Many functions below are only for consumption by the Python indexing
// implementation; they include:
//
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
//
// The rest of the functions are in the `at::indexing::impl` namespace,
// signifying that they shouldn't be used from the Python indexing
// implementation.
static inline Tensor scalarToTensor(
    const Scalar& v,
    const TensorOptions& options,
    const at::Device& self_device) {
  if (self_device == at::kCPU) {
    return at::detail::scalar_tensor_static(
        v, options.dtype_opt()->toScalarType(), self_device);
  } else {
    return impl::scalarToTensorNonNativeDeviceType(v, options);
  }
}

// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
  size_t first_non1_src = sizes.size();
  for (const auto i : c10::irange(sizes.size())) {
    // Unbacked SymInt has different behavior, but this is sound because
    // failing to slice will only ever cause an error, not divergent
    // behavior
    if (!sizes[i].has_hint() || sizes[i] != 1) {
      first_non1_src = i;
      break;
    }
  }

  return sizes.slice(first_non1_src);
}
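
// Worked example (illustrative only; the function name and static sizes are
// hypothetical): for a source of shape [1, 1, 3, 4], the leading unit
// dimensions are stripped and the returned view is [3, 4], matching numpy's
// behavior when assigning such a source into a [3, 4] destination.
inline SymIntArrayRef slicePrefix1sSize_example() {
  static const std::vector<SymInt> sizes = {
      SymInt(1), SymInt(1), SymInt(3), SymInt(4)};
  return slicePrefix1sSize(sizes); // yields the suffix {3, 4}
}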

static inline void copy_to(const Tensor& dst, const Tensor& src) {
  if (dst.sym_sizes().equals(src.sym_sizes())) {
    // A shortcut to avoid generating hard-coded constant sizes during tracing.
    // This is not a perfect solution: when src & dst have different shapes,
    // constants will still appear. Users can workaround that case by
    // dst[index..] = src.reshape(..)
    dst.copy_(src);
    return;
  } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
    dst.fill_(src);
    return;
  }
  auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
  dst.copy_(*b_src);
}

// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
static inline Tensor handleDimInMultiDimIndexing(
    const Tensor& prev_dim_result,
    const Tensor& original_tensor,
    const TensorIndex& index,
    int64_t* dim_ptr,
    int64_t* specified_dims_ptr,
    int64_t real_dim,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& original_tensor_device,
    const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
  if (index.is_integer()) {
    return impl::applySelect(
        prev_dim_result,
        *dim_ptr,
        index.integer(),
        real_dim,
        original_tensor_device,
        prev_dim_result_sizes);
  } else if (index.is_slice()) {
    Tensor result = impl::applySlice(
        prev_dim_result,
        *dim_ptr,
        index.slice().start(),
        index.slice().stop(),
        index.slice().step(),
        /*disable_slice_optimization=*/disable_slice_optimization,
        original_tensor_device,
        prev_dim_result_sizes);
    (*dim_ptr)++;
    return result;
  } else if (index.is_ellipsis()) {
    (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
    return prev_dim_result;
  } else if (index.is_none()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    (*dim_ptr)++;
    return result;
  } else if (index.is_boolean()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    impl::recordTensorIndex(
        impl::boolToIndexingTensor(
            result, index.boolean(), original_tensor_device),
        outIndices,
        dim_ptr);
    return result;
  } else if (index.is_tensor()) {
    Tensor result = prev_dim_result;
    const Tensor& tensor = index.tensor();
    auto scalar_type = tensor.scalar_type();
    if (tensor.dim() == 0 &&
        at::isIntegralType(scalar_type, /*includeBool=*/true)) {
      if (scalar_type != at::kByte && scalar_type != at::kBool) {
        result = impl::applySelect(
            result,
            *dim_ptr,
            tensor.item<int64_t>(),
            real_dim,
            original_tensor_device,
            prev_dim_result_sizes);
      } else {
        result = result.unsqueeze(*dim_ptr);
        if (scalar_type == at::kBool) {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<bool>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        } else {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<uint8_t>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        }
      }
    } else {
      impl::recordTensorIndex(tensor, outIndices, dim_ptr);
    }
    return result;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
  }
}

namespace impl {
// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
static inline Tensor applySlicing(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  int64_t dim = 0;
  int64_t specified_dims = impl::count_specified_dimensions(indices);

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= (int64_t)self_sizes->size(),
        "too many indices for tensor of dimension ",
        (int)self_sizes->size());
  }

  Tensor result = self;
  for (const auto i : c10::irange(indices.size())) {
    auto& obj = indices[i];
    // See NOTE [nested tensor size for indexing]
    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? c10::optional<SymIntArrayRef>(c10::nullopt)
        : c10::optional<SymIntArrayRef>(result.sym_sizes());
    result = handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/obj,
        /*dim=*/&dim,
        /*specified_dims=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        /*disable_slice_optimization=*/disable_slice_optimization,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}
} // namespace impl

static inline Tensor dispatch_index(
    const Tensor& self,
    std::vector<Tensor>&& indices) {
  return self.index(impl::typeConvertIndices(self, std::move(indices)));
}

static inline Tensor dispatch_index_put_(
    Tensor& self,
    std::vector<Tensor>&& indices,
    const Tensor& value) {
  return self.index_put_(
      impl::typeConvertIndices(self, std::move(indices)), value);
}

// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: What "slice optimization" means: when we have a slicing expression
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
// would skip dispatching the actual slice call as an optimization. However,
// here are the cases where we DON'T want this optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
//    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
//    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
//    ```
//    Tensor sliced = impl::applySlicing(
//        self, indices, tensorIndices, disable_slice_optimization,
//        self_device, self_sizes);
//    if (tensorIndices.empty()) {
//      if (sliced.is_same(self)) {
//        // ensure we return a shallow copy for things like x[...]
//        sliced = at::alias(sliced);
//      }
//      return sliced;
//    }
//    ```)
// 2. When we are doing JIT tracing.
//    Reason: JIT tracing needs the `self.slice(...)` call to properly trace
//    the slice operation.

// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp. See NOTE [ Setting
// `disable_slice_optimization` when calling C++ tensor indexing functions from
// Python ]
static inline Tensor get_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  // NOTE [nested tensor size for indexing]
  // Nested tensors do not have a size (yet), so for now we represent their
  // size as null. This may need to change once we have a better solution for
  // nested tensor size.
  c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
      ? c10::optional<SymIntArrayRef>(c10::nullopt)
      : c10::optional<SymIntArrayRef>(self.sym_sizes());

  // handle simple types: integers, slices, none, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_integer()) {
      return impl::applySelect(
          self, 0, index.integer(), 0, self_device, self_sizes);
    } else if (index.is_slice()) {
      return impl::applySlice(
          self,
          0,
          index.slice().start(),
          index.slice().stop(),
          index.slice().step(),
          /*disable_slice_optimization=*/true,
          self_device,
          self_sizes);
    } else if (index.is_none()) {
      return self.unsqueeze(0);
    } else if (index.is_ellipsis()) {
      return at::alias(self);
    } else if (index.is_boolean()) {
      Tensor result = self.unsqueeze(0);
      return dispatch_index(
          result,
          std::vector<Tensor>{impl::boolToIndexingTensor(
              result, index.boolean(), self_device)});
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    if (sliced.is_same(self)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return sliced;
  }

  // indexing by tensors ("advanced" indexing)
  return dispatch_index(sliced, std::move(tensorIndices));
}
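
// Illustrative sketch (the helper name is hypothetical, not part of the
// upstream header): `get_item` together with the implicit `TensorIndex`
// conversions gives the C++ equivalent of the Python expression
// `x[0, 1:3, None, ..., mask]`.
inline Tensor get_item_example(const Tensor& x, const Tensor& mask) {
  return get_item(x, {0, Slice(1, 3), None, "...", mask});
}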

// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is
// a Tensor" case. See NOTE [ Setting `disable_slice_optimization` when calling
// C++ tensor indexing functions from Python ]
static inline void set_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  SymIntArrayRef self_sizes = self.sym_sizes();

  // handle simple types: integers, slices, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_boolean() && !index.boolean()) {
      // do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes).
      return;
    } else if (index.is_ellipsis()) {
      copy_to(self, value);
      return;
    } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
      copy_to(self.unsqueeze(0), value);
      return;
    } else if (index.is_integer()) {
      copy_to(
          impl::applySelect(
              self, 0, index.integer(), 0, self_device, self_sizes),
          value);
      return;
    } else if (index.is_slice()) {
      copy_to(
          impl::applySlice(
              self,
              0,
              index.slice().start(),
              index.slice().stop(),
              index.slice().step(),
              /*disable_slice_optimization=*/disable_slice_optimization,
              self_device,
              self_sizes),
          value);
      return;
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    copy_to(sliced, value);
    return;
  }

  SymIntArrayRef valueSizes = value.sym_sizes();
  SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  Tensor valuesSliced;
  if (!valueSizes.equals(slicedValueSizes)) {
    valuesSliced = value.view_symint(slicedValueSizes);
  } else {
    valuesSliced = value;
  }
  dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  return;
}
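
// Illustrative sketch (the helper name is hypothetical): `set_item` is the
// C++ counterpart of Python assignment through an index, e.g. the statement
// `x[..., 0] = v` becomes:
inline void set_item_example(const Tensor& x, const Tensor& v) {
  set_item(x, {"...", 0}, v);
}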

} // namespace indexing
} // namespace at