82 KiB
82 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : #pragma once 2 : 3 : #include <ATen/ExpandUtils.h> 4 : #include <ATen/ScalarOps.h> 5 : #include <ATen/core/Tensor.h> 6 : #include <ATen/core/TensorBody.h> 7 : #include <c10/core/SymInt.h> 8 : #include <c10/util/Optional.h> 9 : #include <c10/util/irange.h> 10 : 11 : #ifndef AT_PER_OPERATOR_HEADERS 12 : #include <ATen/Functions.h> 13 : #include <ATen/NativeFunctions.h> 14 : #else 15 : #include <ATen/ops/alias.h> 16 : #include <ATen/ops/empty.h> 17 : #include <ATen/ops/scalar_tensor.h> 18 : #include <ATen/ops/zeros.h> 19 : #endif 20 : 21 : #include <ATen/core/List.h> 22 : 23 : #include <utility> 24 : 25 : namespace at { 26 : namespace indexing { 27 : 28 : const int64_t INDEX_MIN = c10::SymInt::min_representable_int(); 29 : const int64_t INDEX_MAX = -(INDEX_MIN + 1); 30 : 31 : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor }; 32 : 33 : constexpr c10::nullopt_t None = c10::nullopt; 34 : 35 : struct TORCH_API EllipsisIndexType final { 36 : EllipsisIndexType() = default; 37 : }; 38 : TORCH_API extern const EllipsisIndexType Ellipsis; 39 : 40 : struct TORCH_API Slice final { 41 : public: 42 36975204 : Slice( 43 : c10::optional<c10::SymInt> start_index = c10::nullopt, 44 : c10::optional<c10::SymInt> stop_index = c10::nullopt, 45 36975204 : c10::optional<c10::SymInt> step_index = c10::nullopt) { 46 36975204 : if (!step_index.has_value()) { 47 36975204 : step_ = c10::SymInt(1); 48 : } else { 49 0 : step_ = std::move(step_index).value(); 50 : } 51 : 52 36975204 : TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero"); 53 : 54 36975204 : if (!start_index.has_value()) { 55 36974694 : start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0); 56 : } else { 57 510 : start_ = std::move(start_index).value(); 58 : } 59 : 60 36975204 : if (!stop_index.has_value()) { 61 36974694 : stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX); 62 : } else { 63 510 : stop_ = std::move(stop_index).value(); 64 : } 65 36975204 : } 66 : 67 : inline c10::SymInt start() const { 68 : return start_; 69 : } 70 : 71 : inline c10::SymInt stop() const { 72 : return stop_; 73 : } 74 : 75 : inline c10::SymInt step() const { 76 : return step_; 77 : } 78 : 79 : private: 80 : c10::SymInt start_; 81 : c10::SymInt stop_; 82 : c10::SymInt step_; 83 : }; 84 : 85 : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); 86 : 87 : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as 88 : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}` 89 : // into its equivalent `std::vector<TensorIndex>`, so that further tensor 90 : // indexing operations can be performed using the supplied indices. 91 : // 92 : // There is one-to-one correspondence between Python and C++ tensor index types: 93 : // Python | C++ 94 : // ----------------------------------------------------- 95 : // `None` | `at::indexing::None` 96 : // `Ellipsis` | `at::indexing::Ellipsis` 97 : // `...` | `"..."` 98 : // `123` | `123` 99 : // `True` / `False` | `true` / `false` 100 : // `:` | `Slice()` / `Slice(None, None)` 101 : // `::` | `Slice()` / `Slice(None, None, None)` 102 : // `1:` | `Slice(1, None)` 103 : // `1::` | `Slice(1, None, None)` 104 : // `:3` | `Slice(None, 3)` 105 : // `:3:` | `Slice(None, 3, None)` 106 : // `::2` | `Slice(None, None, 2)` 107 : // `1:3` | `Slice(1, 3)` 108 : // `1::2` | `Slice(1, None, 2)` 109 : // `:3:2` | `Slice(None, 3, 2)` 110 : // `1:3:2` | `Slice(1, 3, 2)` 111 : // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})` 112 : struct TORCH_API TensorIndex final { 113 : // Case 1: `at::indexing::None` 114 : TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {} 115 : 116 : // Case 2: "..." / `at::indexing::Ellipsis` 117 776042 : TensorIndex(at::indexing::EllipsisIndexType) 118 776042 : : type_(TensorIndexType::Ellipsis) {} 119 776042 : TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) { 120 776042 : TORCH_CHECK_VALUE( 121 : strcmp(str, "...") == 0, 122 : "Expected \"...\" to represent an ellipsis index, but got \"", 123 : str, 124 : "\""); 125 776042 : } 126 : 127 : // Case 3: (Sym) Integer value 128 36171516 : TensorIndex(SymInt integer) 129 36171516 : : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {} 130 : TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {} 131 36171516 : TensorIndex(int integer) : TensorIndex(SymInt(integer)) {} 132 : 133 : // Case 4: Boolean value 134 : template < 135 : class T, 136 : class = typename std::enable_if<std::is_same<bool, T>::value>::type> 137 : TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {} 138 : 139 : // Case 5: Slice represented in `at::indexing::Slice` form 140 510 : TensorIndex(Slice slice) 141 510 : : slice_(std::move(slice)), type_(TensorIndexType::Slice) {} 142 : 143 : // Case 6: Tensor value 144 27136 : TensorIndex(Tensor tensor) 145 27136 : : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {} 146 : 147 : inline bool is_none() const { 148 : return type_ == TensorIndexType::None; 149 : } 150 : 151 : inline bool is_ellipsis() const { 152 : return type_ == TensorIndexType::Ellipsis; 153 : } 154 : 155 : inline bool is_integer() const { 156 : return type_ == TensorIndexType::SymInt; 157 : } 158 : 159 : inline SymInt integer() const { 160 : return integer_; 161 : } 162 : 163 : inline bool is_boolean() const { 164 : return type_ == TensorIndexType::Boolean; 165 : } 166 : 167 : inline bool boolean() const { 168 : return boolean_; 169 : } 170 : 171 : inline bool is_slice() const { 172 : return type_ == TensorIndexType::Slice; 173 : } 174 : 175 : inline const Slice& slice() const { 176 : return slice_; 177 : } 178 : 179 : inline bool is_tensor() const { 180 : return type_ == TensorIndexType::Tensor; 181 : } 182 : 183 : inline const Tensor& tensor() const { 184 : return tensor_; 185 : } 186 : 187 : private: 188 : SymInt integer_ = 0; 189 : bool boolean_ = false; 190 : Slice slice_; 191 : Tensor tensor_; 192 : TensorIndexType type_; 193 : }; 194 : 195 : TORCH_API std::ostream& operator<<( 196 : std::ostream& stream, 197 : const TensorIndex& tensor_index); 198 : TORCH_API std::ostream& operator<<( 199 : std::ostream& stream, 200 : const std::vector<TensorIndex>& tensor_indices); 201 : 202 : namespace impl { 203 : static inline Tensor applySlice( 204 : const Tensor& self, 205 : int64_t dim, 206 : c10::SymInt start, 207 : c10::SymInt stop, 208 : c10::SymInt step, 209 : bool disable_slice_optimization, 210 : const at::Device& self_device, 211 : const c10::optional<SymIntArrayRef>& self_sizes) { 212 : // TODO: implement negative step 213 : TORCH_CHECK_VALUE(step > 0, "step must be greater than zero"); 214 : 215 : // See NOTE [nested tensor size for indexing] 216 : if (self_sizes.has_value()) { 217 : // Skip this optimization if we are tracing, as the trace may be polymorphic 218 : // over the shape of the `self` tensor, and we still want to record 219 : // the slice. 220 : SymInt length = (self_device == at::kCPU || self_device == at::kCUDA) 221 : ? (*self_sizes)[dim] 222 : : self.sym_size(dim); 223 : if (!disable_slice_optimization && start == 0 && length == stop && 224 : step == 1) { 225 : return self; 226 : } 227 : } 228 : return self.slice_symint(dim, start, stop, std::move(step)); 229 : } 230 : 231 : static inline Tensor applySelect( 232 : const Tensor& self, 233 : int64_t dim, 234 : SymInt index, 235 : int64_t real_dim, 236 : const at::Device& /*self_device*/, 237 : const c10::optional<SymIntArrayRef>& self_sizes) { 238 : // See NOTE [nested tensor size for indexing] 239 : if (self_sizes.has_value()) { 240 : auto maybe_index = index.maybe_as_int(); 241 : if (maybe_index.has_value()) { 242 : TORCH_CHECK_INDEX( 243 : !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()), 244 : "invalid index of a 0-dim tensor. ", 245 : "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number"); 246 : } 247 : 248 : auto size = (*self_sizes)[dim]; 249 : TORCH_CHECK_INDEX( 250 : size >= -index && size > index, 251 : "index ", 252 : index, 253 : " is out of bounds for dimension ", 254 : real_dim, 255 : " with size ", 256 : size); 257 : } 258 : 259 : // if the index is negative, do not normalize it because that would fix the 260 : // index on the current tensor size in the tracer. aten::select also works on 261 : // negative indices 262 : return self.select_symint(dim, index); 263 : } 264 : 265 : static inline Tensor boolToIndexingTensorCPUOrCUDA( 266 : const Tensor& self, 267 : bool value) { 268 : // booleans add a dimension of size 1. true indexes this dimension as if 0:, 269 : // false as empty. 270 : if (value) { 271 : return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.); 272 : } else { 273 : return at::empty({0}, {}, self.options().dtype(kLong)); 274 : } 275 : } 276 : 277 : static inline Tensor boolToIndexingTensorNonNativeDeviceType( 278 : const Tensor& self, 279 : bool value) { 280 : // booleans add a dimension of size 1. true indexes this dimension as if 0:, 281 : // false as empty. 282 : if (value) { 283 : return at::zeros({1}, {}, self.options().dtype(kLong)); 284 : } else { 285 : return at::empty({0}, {}, self.options().dtype(kLong)); 286 : } 287 : } 288 : 289 : static inline Tensor boolToIndexingTensor( 290 : const Tensor& self, 291 : bool value, 292 : const at::Device& self_device) { 293 : if (self_device == at::kCPU || self_device == at::kCUDA) { 294 : return boolToIndexingTensorCPUOrCUDA(self, value); 295 : } else { 296 : return boolToIndexingTensorNonNativeDeviceType(self, value); 297 : } 298 : } 299 : 300 : static inline Tensor scalarToTensorNonNativeDeviceType( 301 : const Scalar& v, 302 : const TensorOptions& options) { 303 : return at::scalar_tensor(v, options); 304 : } 305 : 306 : static inline void recordTensorIndex( 307 : const Tensor& tensor, 308 : std::vector<Tensor>& outIndices, 309 : int64_t* dim_ptr) { 310 : // TODO: check scalarType 311 : outIndices.resize(*dim_ptr + 1); 312 : outIndices[*dim_ptr] = tensor; 313 : (*dim_ptr)++; 314 : }; 315 : 316 : static inline c10::List<c10::optional<Tensor>> typeConvertIndices( 317 : const Tensor& /*self*/, 318 : std::vector<Tensor>&& indices) { 319 : c10::List<c10::optional<Tensor>> converted_inds; 320 : converted_inds.reserve(indices.size()); 321 : for (const auto& i : indices) { 322 : converted_inds.push_back(std::move(i)); 323 : } 324 : return converted_inds; 325 : } 326 : 327 : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions` 328 : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because 329 : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim 330 : // indexing (i.e. it's called by `applySlicing` which is called by 331 : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more 332 : // than one dimension). If we were to merge the Python/C++ 333 : // `count_specified_dimensions` function, on the Python side we would have to 334 : // construct a `std::vector` container to be consumed by the C++ 335 : // `count_specified_dimensions` function, which adds 100s of nanoseconds 336 : // overhead and is undesirable. 337 : static inline int64_t count_specified_dimensions( 338 : const ArrayRef<TensorIndex>& indices) { 339 : // Count the number of indexed dimensions (everything but ellipsis and None) 340 : int64_t count = 0; 341 : for (auto& obj : indices) { 342 : if (obj.is_tensor()) { 343 : auto& tensor = obj.tensor(); 344 : if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { 345 : count += tensor.dim(); 346 : } else { 347 : count++; 348 : } 349 : } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) { 350 : count++; 351 : } 352 : } 353 : return count; 354 : } 355 : } // namespace impl 356 : 357 : // NOTE: Many functions below are only for consumption from Python indexing 358 : // implementation, they include: 359 : // 360 : // - `Tensor scalarToTensor(...)` 361 : // - `IntArrayRef slicePrefix1sSize(...)` 362 : // - `void copy_to(...)` 363 : // - `Tensor handleDimInMultiDimIndexing(...)` 364 : // - `Tensor dispatch_index(...)` 365 : // - `Tensor dispatch_index_put_(...)` 366 : // - `Tensor get_item(...)` 367 : // - `void set_item(...)` 368 : // 369 : // The rest of the functions are in `at::indexing::impl` namespace, signifying 370 : // that they shouldn't be used from Python indexing implementation. 371 : static inline Tensor scalarToTensor( 372 : const Scalar& v, 373 : const TensorOptions& options, 374 : const at::Device& self_device) { 375 : if (self_device == at::kCPU) { 376 : return at::detail::scalar_tensor_static( 377 : v, options.dtype_opt()->toScalarType(), self_device); 378 : } else { 379 : return impl::scalarToTensorNonNativeDeviceType(v, options); 380 : } 381 : } 382 : 383 : // To match numpy semantics: 384 : // As a special case for backwards compatibility, 385 : // strip away unit dimensions from the left of 'src' 386 : static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { 387 : size_t first_non1_src = sizes.size(); 388 : for (const auto i : c10::irange(sizes.size())) { 389 : // Unbacked SymInt has different behavior, but this is sound because 390 : // failing to slice will only ever cause an error, not divergent 391 : // behavior 392 : if (!sizes[i].has_hint() || sizes[i] != 1) { 393 : first_non1_src = i; 394 : break; 395 : } 396 : } 397 : 398 : return sizes.slice(first_non1_src); 399 : } 400 : 401 : static inline void copy_to(const Tensor& dst, const Tensor& src) { 402 : if (dst.sym_sizes().equals(src.sym_sizes())) { 403 : // A shortcut to avoid generating hard-coded constant sizes during tracing. 404 : // This is not a perfect solution: when src & dst have different shapes, 405 : // constants will still appear. Users can workaround that case by 406 : // dst[index..] = src.reshape(..) 407 : dst.copy_(src); 408 : return; 409 : } else if (src.dim() == 0 && src.device().type() == at::kCPU) { 410 : dst.fill_(src); 411 : return; 412 : } 413 : auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes())); 414 : c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem"); 415 : dst.copy_(*b_src); 416 : } 417 : 418 : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor 419 : // indexing functions from Python ] 420 : static inline Tensor handleDimInMultiDimIndexing( 421 : const Tensor& prev_dim_result, 422 : const Tensor& original_tensor, 423 : const TensorIndex& index, 424 : int64_t* dim_ptr, 425 : int64_t* specified_dims_ptr, 426 : int64_t real_dim, 427 : std::vector<Tensor>& outIndices, 428 : bool disable_slice_optimization, 429 : const at::Device& original_tensor_device, 430 : const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) { 431 : if (index.is_integer()) { 432 : return impl::applySelect( 433 : prev_dim_result, 434 : *dim_ptr, 435 : index.integer(), 436 : real_dim, 437 : original_tensor_device, 438 : prev_dim_result_sizes); 439 : } else if (index.is_slice()) { 440 : Tensor result = impl::applySlice( 441 : prev_dim_result, 442 : *dim_ptr, 443 : index.slice().start(), 444 : index.slice().stop(), 445 : index.slice().step(), 446 : /*disable_slice_optimization=*/disable_slice_optimization, 447 : original_tensor_device, 448 : prev_dim_result_sizes); 449 : (*dim_ptr)++; 450 : return result; 451 : } else if (index.is_ellipsis()) { 452 : (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); 453 : return prev_dim_result; 454 : } else if (index.is_none()) { 455 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr); 456 : (*dim_ptr)++; 457 : return result; 458 : } else if (index.is_boolean()) { 459 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr); 460 : impl::recordTensorIndex( 461 : impl::boolToIndexingTensor( 462 : result, index.boolean(), original_tensor_device), 463 : outIndices, 464 : dim_ptr); 465 : return result; 466 : } else if (index.is_tensor()) { 467 : Tensor result = prev_dim_result; 468 : const Tensor& tensor = index.tensor(); 469 : auto scalar_type = tensor.scalar_type(); 470 : if (tensor.dim() == 0 && 471 : at::isIntegralType(scalar_type, /*includeBool=*/true)) { 472 : if (scalar_type != at::kByte && scalar_type != at::kBool) { 473 : result = impl::applySelect( 474 : result, 475 : *dim_ptr, 476 : tensor.item<int64_t>(), 477 : real_dim, 478 : original_tensor_device, 479 : prev_dim_result_sizes); 480 : } else { 481 : result = result.unsqueeze(*dim_ptr); 482 : if (scalar_type == at::kBool) { 483 : impl::recordTensorIndex( 484 : impl::boolToIndexingTensor( 485 : result, tensor.item<bool>() != 0, original_tensor_device), 486 : outIndices, 487 : dim_ptr); 488 : } else { 489 : impl::recordTensorIndex( 490 : impl::boolToIndexingTensor( 491 : result, tensor.item<uint8_t>() != 0, original_tensor_device), 492 : outIndices, 493 : dim_ptr); 494 : } 495 : } 496 : } else { 497 : impl::recordTensorIndex(tensor, outIndices, dim_ptr); 498 : } 499 : return result; 500 : } else { 501 : TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type"); 502 : } 503 : } 504 : 505 : namespace impl { 506 : // This mirrors `applySlicing` in 507 : // torch/csrc/autograd/python_variable_indexing.cpp 508 : static inline Tensor applySlicing( 509 : const Tensor& self, 510 : const ArrayRef<TensorIndex>& indices, 511 : std::vector<Tensor>& outIndices, 512 : bool disable_slice_optimization, 513 : const at::Device& self_device, 514 : const c10::optional<SymIntArrayRef>& self_sizes) { 515 : int64_t dim = 0; 516 : int64_t specified_dims = impl::count_specified_dimensions(indices); 517 : 518 : // See NOTE [nested tensor size for indexing] 519 : if (self_sizes.has_value()) { 520 : TORCH_CHECK_INDEX( 521 : specified_dims <= (int64_t)self_sizes->size(), 522 : "too many indices for tensor of dimension ", 523 : (int)self_sizes->size()); 524 : } 525 : 526 : Tensor result = self; 527 : for (const auto i : c10::irange(indices.size())) { 528 : auto& obj = indices[i]; 529 : // See NOTE [nested tensor size for indexing] 530 : c10::optional<SymIntArrayRef> result_sizes = result.is_nested() 531 : ? c10::optional<SymIntArrayRef>(c10::nullopt) 532 : : c10::optional<SymIntArrayRef>(result.sym_sizes()); 533 : result = handleDimInMultiDimIndexing( 534 : /*prev_dim_result=*/result, 535 : /*original_tensor=*/self, 536 : /*index=*/obj, 537 : /*dim=*/&dim, 538 : /*specified_dims=*/&specified_dims, 539 : /*real_dim=*/i, 540 : /*outIndices=*/outIndices, 541 : /*disable_slice_optimization=*/disable_slice_optimization, 542 : /*original_tensor_device=*/self_device, 543 : /*prev_dim_result_sizes=*/result_sizes); 544 : } 545 : return result; 546 : } 547 : } // namespace impl 548 : 549 : static inline Tensor dispatch_index( 550 : const Tensor& self, 551 : std::vector<Tensor>&& indices) { 552 : return self.index(impl::typeConvertIndices(self, std::move(indices))); 553 : } 554 : 555 : static inline Tensor dispatch_index_put_( 556 : Tensor& self, 557 : std::vector<Tensor>&& indices, 558 : const Tensor& value) { 559 : return self.index_put_( 560 : impl::typeConvertIndices(self, std::move(indices)), value); 561 : } 562 : 563 : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing 564 : // functions from Python ] 565 : // 566 : // Question: When should we set `disable_slice_optimization` to `true` when 567 : // calling C++ tensor indexing functions from Python indexing code? 568 : // 569 : // Answer: What "slice optimization" means: when we have a slicing expression 570 : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we 571 : // would skip dispatching the actual slice call as an optimization. However, 572 : // here are the cases where we DON'T want this optimization: 573 : // 574 : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`). 575 : // Reason: we always return a shallow copy for expressions such as 576 : // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:, 577 : // :]`, we return an alias of `tensor` by doing the following: 578 : // ``` 579 : // Tensor sliced = impl::applySlicing(self, indices, tensorIndices, 580 : // disable_slice_optimization, self_device, self_sizes); if 581 : // (tensorIndices.empty()) { 582 : // if (sliced.is_same(self)) { 583 : // // ensure we return a shallow copy for things like x[...] 584 : // sliced = at::alias(sliced); 585 : // } 586 : // return sliced; 587 : // } 588 : // ```) 589 : // 2. When we are doing JIT tracing. 590 : // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the 591 : // slice operation. 592 : 593 : // This mirrors `THPVariable_getitem` in 594 : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting 595 : // `disable_slice_optimization` when calling C++ tensor indexing functions from 596 : // Python ] 597 : static inline Tensor get_item( 598 : const Tensor& self, 599 : const ArrayRef<TensorIndex>& indices, 600 : bool disable_slice_optimization = false) { 601 : at::Device self_device = self.device(); 602 : // NOTE [nested tensor size for indexing] 603 : // nested tensor does not have a size (yet) so for now we represent its size 604 : // as null may need to be changed after we reach a better solution for nested 605 : // tensor size 606 : c10::optional<SymIntArrayRef> self_sizes = self.is_nested() 607 : ? c10::optional<SymIntArrayRef>(c10::nullopt) 608 : : c10::optional<SymIntArrayRef>(self.sym_sizes()); 609 : 610 : // handle simple types: integers, slices, none, ellipsis, bool 611 : if (indices.size() == 1) { 612 : const TensorIndex& index = indices[0]; 613 : if (index.is_integer()) { 614 : return impl::applySelect( 615 : self, 0, index.integer(), 0, self_device, self_sizes); 616 : } else if (index.is_slice()) { 617 : return impl::applySlice( 618 : self, 619 : 0, 620 : index.slice().start(), 621 : index.slice().stop(), 622 : index.slice().step(), 623 : /*disable_slice_optimization=*/true, 624 : self_device, 625 : self_sizes); 626 : } else if (index.is_none()) { 627 : return self.unsqueeze(0); 628 : } else if (index.is_ellipsis()) { 629 : return at::alias(self); 630 : } else if (index.is_boolean()) { 631 : Tensor result = self.unsqueeze(0); 632 : return dispatch_index( 633 : result, 634 : std::vector<Tensor>{impl::boolToIndexingTensor( 635 : result, index.boolean(), self_device)}); 636 : } 637 : } 638 : 639 : std::vector<Tensor> tensorIndices; 640 : Tensor sliced = impl::applySlicing( 641 : self, 642 : indices, 643 : tensorIndices, 644 : disable_slice_optimization, 645 : self_device, 646 : self_sizes); 647 : if (tensorIndices.empty()) { 648 : if (sliced.is_same(self)) { 649 : // ensure we return a shallow copy for things like x[...] 650 : sliced = at::alias(sliced); 651 : } 652 : return sliced; 653 : } 654 : 655 : // indexing by tensors ("advanced" indexing) 656 : return dispatch_index(sliced, std::move(tensorIndices)); 657 : } 658 : 659 : // This mirrors `THPVariable_setitem` in 660 : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a 661 : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++ 662 : // tensor indexing functions from Python ] 663 : static inline void set_item( 664 : const Tensor& self, 665 : const ArrayRef<TensorIndex>& indices, 666 : const Tensor& value, 667 : bool disable_slice_optimization = false) { 668 : at::Device self_device = self.device(); 669 : SymIntArrayRef self_sizes = self.sym_sizes(); 670 : 671 : // handle simple types: integers, slices, ellipsis, bool 672 : if (indices.size() == 1) { 673 : const TensorIndex& index = indices[0]; 674 : if (index.is_boolean() && !index.boolean()) { 675 : // do nothing for false (technically we should check the size, but we 676 : // don't have real 0-sized shapes. 677 : return; 678 : } else if (index.is_ellipsis()) { 679 : copy_to(self, value); 680 : return; 681 : } else if (index.is_none() || (index.is_boolean() && index.boolean())) { 682 : copy_to(self.unsqueeze(0), value); 683 : return; 684 : } else if (index.is_integer()) { 685 : copy_to( 686 : impl::applySelect( 687 : self, 0, index.integer(), 0, self_device, self_sizes), 688 : value); 689 : return; 690 : } else if (index.is_slice()) { 691 : copy_to( 692 : impl::applySlice( 693 : self, 694 : 0, 695 : index.slice().start(), 696 : index.slice().stop(), 697 : index.slice().step(), 698 : /*disable_slice_optimization=*/disable_slice_optimization, 699 : self_device, 700 : self_sizes), 701 : value); 702 : return; 703 : } 704 : } 705 : 706 : std::vector<Tensor> tensorIndices; 707 : Tensor sliced = impl::applySlicing( 708 : self, 709 : indices, 710 : tensorIndices, 711 : disable_slice_optimization, 712 : self_device, 713 : self_sizes); 714 : if (tensorIndices.empty()) { 715 : copy_to(sliced, value); 716 : return; 717 : } 718 : 719 : SymIntArrayRef valueSizes = value.sym_sizes(); 720 : SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes); 721 : Tensor valuesSliced; 722 : if (!valueSizes.equals(slicedValueSizes)) { 723 : valuesSliced = value.view_symint(slicedValueSizes); 724 : } else { 725 : valuesSliced = value; 726 : } 727 : dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced); 728 : return; 729 : } 730 : 731 : } // namespace indexing 732 : } // namespace at |
![]() |
Generated by: LCOV version 2.0-1 |
</html>