LCOV - code coverage report
#pragma once

#include <torch/csrc/utils/python_stub.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>

#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/VariableHooksInterface.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {

/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
/// at::Tensor`). This means you can perform all the usual mathematical and
/// other operations you can perform on `Tensor`s also on `Variable`s.
///
/// The only reason we are keeping the `Variable` class is backward
/// compatibility with external users' legacy C++ frontend code. Our intention
/// is to eliminate the `Variable` class in the near future.
using Variable = at::Tensor;

} // namespace autograd
} // namespace torch

// The following are all internal APIs and should not be shown in libtorch docs.
// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
// ... #endif`

#ifndef DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace autograd {

/// Check if this type is supported by the autograd engine.
/// If you change this, update the doc at the top of the
/// torch/autograd/__init__.py file and
/// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py
static inline bool isDifferentiableType(at::ScalarType t) {
  return isFloatingType(t) || isComplexType(t);
}

struct Node;

///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
/// that don't require gradients as Tensors (since none of the autograd
/// machinery for Variables applies). Historically, Variables and Tensors
/// were separate concepts, but now they are exactly the same (i.e. we have
/// `using Variable = at::Tensor`).
///
/// Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass.
/// More precisely, this gradient function can be one of two things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
///    gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
///    scalar gradient value into its `grad` variable.
///
/// Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s is *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
///
/// Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether a `Variable` is in fact a view by
/// probing its `is_view()` method. Note that the *view* semantics are only
/// meaningful for `Variable` relations that are relevant to autograd.
/// See NOTE [ Autograd View Variables ] for more details.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
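// A minimal standalone sketch (not part of the original header) of how
// leaves, interior variables, gradient edges and version counters surface
// through the public at::Tensor API backed by the machinery declared below.
// Treat it as an approximation, not reference usage.
#include <torch/torch.h>
#include <cassert>

inline void gradient_edge_and_version_demo() {
  // A leaf: created by the user with requires_grad = true. Its gradient
  // "function" is really a grad accumulator, so grad_fn() is null.
  auto leaf = torch::ones({2, 2}, torch::requires_grad());
  assert(leaf.is_leaf() && leaf.grad_fn() == nullptr);

  // An interior variable: produced by an op on the leaf. Its gradient edge
  // points at the grad_fn of the op that created it.
  auto interior = leaf * 3;
  assert(!interior.is_leaf() && interior.grad_fn() != nullptr);

  // Versions: every in-place mutation bumps the version counter, which is
  // how SavedVariable snapshots detect staleness during backward.
  auto plain = torch::zeros({2});
  const auto v0 = plain._version();
  plain.add_(1);
  assert(plain._version() == v0 + 1);
}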
struct AutogradMeta;
struct DifferentiableViewMeta;

// Private-ish functions for manipulating variables; we don't want to put them
// on Tensor proper
namespace impl {

// WARNING: This may return a nullptr. If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);

// WARNING: This will return a nullptr if the Tensor is not a view.
TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);

// Returns the current autograd meta, materializing it if it was previously
// none. This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);

/// Set the gradient accumulator of the `Variable`. This is only applicable to
/// leaf variables. Interior variables should call `set_gradient_edge()`.
TORCH_API void set_grad_accumulator(
    const Variable&,
    std::weak_ptr<Node> grad_accumulator);

/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);

/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// creates one on the fly and returns it.
TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);

/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Node` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
TORCH_API Edge gradient_edge(const Variable&);

/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
/// and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
TORCH_API void set_gradient_edge(const Variable&, Edge edge);

// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
///
/// For View Variables:
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
TORCH_API void rebase_history(const Variable&, Edge gradient_edge);

/// Gets the raw gradient function pointer, whatever it currently is.
TORCH_API Node* grad_fn_unsafe(const Variable&);

/// Increments the version count of this `Variable`.
TORCH_API void bump_version(const Variable&);
TORCH_API void set_version_counter(
    const Variable&,
    const c10::VariableVersion& version_counter);

/// Retrieves this `Variable`'s version counter.
TORCH_API const c10::VariableVersion& version_counter(const Variable&);

TORCH_API void set_name(const Variable&, const std::string& name);

TORCH_API void add_hook(
    const at::TensorBase&,
    std::unique_ptr<FunctionPreHook> hook);
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);

TORCH_API void set_post_acc_grad_hooks(
    const at::TensorBase&,
    std::unique_ptr<PostAccumulateGradHook> dict);
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable&);

TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
} // namespace impl
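// A minimal standalone sketch of the impl:: helpers above (assuming a build
// that links against these internal symbols): `gradient_edge` returns the
// grad_fn edge for interior variables and the grad accumulator edge for
// leaves that require grad, with `input_nr` telling which input of the Node
// the variable feeds during backward.
#include <torch/csrc/autograd/function.h> // for Node::name()
#include <cstdio>

inline void inspect_gradient_edge(const torch::autograd::Variable& var) {
  torch::autograd::Edge edge = torch::autograd::impl::gradient_edge(var);
  if (edge.function) {
    // Prints e.g. "MulBackward0" for an interior variable or the
    // AccumulateGrad node for a leaf that requires grad.
    std::printf(
        "node: %s, input_nr: %u\n",
        edge.function->name().c_str(),
        edge.input_nr);
  }
}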

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr, in lieu of a
/// default constructed AutogradMeta.

struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  std::string name_;

  Variable grad_;
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;

  // This field is used to store all the forward AD gradients
  // associated with this AutogradMeta (and the Tensor it corresponds to).
  // There is a semantic 1:1 correspondence between AutogradMeta and
  // ForwardGrad but:
  //   - This field is lazily populated.
  //   - This field is a shared_ptr but it must never be
  //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
  // Any transition from not_initialized to initialized
  // must be protected by mutex_
  std::shared_ptr<ForwardGrad> fw_grad_;

  // The hooks_ field is actually reused by both python and cpp logic.
  // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
  // or dict (python), which is the canonical copy.
  // Then, for both cases, we always register a single hook to
  // hooks_ which wraps all the hooks in the list/dict.
  // And, again in both cases, if the grad_fn exists on that tensor
  // we will additionally register a single hook to the grad_fn.
  //
  // Note that the cpp and python use cases aren't actually aware of
  // each other, so using both is undefined behavior.
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // The post_acc_grad_hooks_ field stores only Python hooks
  // (PyFunctionTensorPostAccGradHooks) that are called after the
  // .grad field has been accumulated into. This is less complicated
  // than the hooks_ field, which encapsulates a lot more.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};

  // Only meaningful on non-leaf variables (must be false otherwise)
  bool retains_grad_{false};

  bool is_view_{false};

  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can set up the backward trace
  // correctly when this variable is passed to another function.
  uint32_t output_nr_;

  // Mutex to ensure that concurrent read operations that modify internal
  // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
  // fw_grad() and set_fw_grad().
  // This is mutable because we need to be able to acquire it from const
  // versions of this class for the functions above.
  mutable std::mutex mutex_;

  /// Sets the `requires_grad` property of `Variable`. This should be true for
  /// leaf variables that want to accumulate gradients, and false for all other
  /// variables.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl)
      override {
    TORCH_CHECK(
        !requires_grad ||
            isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point and complex dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  bool requires_grad() const override {
    return requires_grad_ || grad_fn_;
  }

  /// Accesses the gradient `Variable` of this `Variable`.
  Variable& mutable_grad() override {
    return grad_;
  }

  const Variable& grad() const override {
    return grad_;
  }

  const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
      const override;

  void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) override;

  AutogradMeta(
      at::TensorImpl* self_impl = nullptr,
      bool requires_grad = false,
      Edge gradient_edge = Edge())
      : grad_fn_(std::move(gradient_edge.function)),
        output_nr_(gradient_edge.input_nr) {
    // set_requires_grad also checks error conditions.
    if (requires_grad) {
      TORCH_INTERNAL_ASSERT(self_impl);
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      set_requires_grad(requires_grad, self_impl);
    }
    TORCH_CHECK(
        !grad_fn_ || !requires_grad_,
        "requires_grad should be false if grad_fn is set");
  }

  ~AutogradMeta() override {
    // If AutogradMeta is being destroyed, it means that there is no other
    // reference to its corresponding Tensor. It implies that no other thread
    // can be using this object and so there is no need to lock mutex_ here to
    // guard the check if fw_grad_ is populated.
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }
};

struct TORCH_API ViewInfo {
  /// The base `Variable`.
  /// If this ViewInfo represents a forward (respectively backward) AD gradient,
  /// then this Tensor cannot be a forward (respectively backward) view.
  Variable base_;

  /// By default we use as_strided to recover views, which is more efficient.
  /// view_fn is only saved when as_strided is not supported.
  /// If view_fn has a value, we use it to recover views in backward.
  std::function<Variable(const Variable&)> view_fn_;

  /// Accessors for the view function
  bool has_view_fn() const {
    return view_fn_ != nullptr;
  }

  std::function<Variable(const Variable&)> view_fn() const {
    TORCH_CHECK(
        has_view_fn(), "Can only access the view function if it exists.");
    return view_fn_;
  }

  /// The chain function can be used to build a new ViewInfo for a
  /// differentiable view function. It will return a new view info that
  /// accurately represents how "tensor" is a view of this instance's "base_".
  /// The "base" and "tensor" are respectively the input and output of the
  /// differentiable view function that happened. They are required to properly
  /// set the optional view_fn_ when it is not provided. The "view_func", if
  /// provided, should be a function that allows one to re-do the view between
  /// "base" and "tensor".
  ViewInfo chain(
      const Variable& base,
      const Variable& tensor,
      std::function<Variable(const Variable&)> view_func = nullptr) const;

  ViewInfo(Variable base, std::function<Variable(const Variable&)> view_fn)
      : base_(std::move(base)), view_fn_(std::move(view_fn)) {
    TORCH_CHECK(base_.defined(), "base is undefined");
  }
};
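// A minimal standalone sketch (not reference usage) of what a view_fn is: a
// callable that can replay the view operation on a new base. Autograd only
// stores one when as_strided cannot reproduce the view; the narrow() call
// below is just an arbitrary example of such an op.
inline torch::autograd::ViewInfo make_example_view_info(
    torch::autograd::Variable base) {
  auto view_fn = [](const torch::autograd::Variable& b) {
    // Re-applies the view op to whatever base it is given.
    return b.narrow(/*dim=*/0, /*start=*/0, /*length=*/1);
  };
  return torch::autograd::ViewInfo(std::move(base), std::move(view_fn));
}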

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// NOTE [ Autograd View Variables ]
///
/// Many operations return a Variable that shares storage with an input
/// Variable. The returned Variable is called a **view** Variable on the input
/// **base** Variable.
///
/// In PyTorch, we have two types of views: differentiable views, and
/// non-differentiable views. In either type, to support proper version
/// checking, the base and view Variables must always share the same
/// version_counter.
///
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// This class allows tracking both forward and backward AD differentiable
/// views. These views can have different bases, because a view that is
/// non-differentiable for backward mode AD may still be differentiable for
/// forward mode AD (and vice versa).
///
/// Most functions are either both forward and backward differentiable views
/// (for example: view, select, narrow, transpose, etc.) or neither forward nor
/// backward differentiable views (for example: indices, values, eq, lt, etc.).
/// But there are also functions that are forward but not backward
/// differentiable views (only detach for now) and functions that are backward
/// but not forward differentiable views (only make_dual and unpack_dual for
/// now).
///
/// A concrete example of two views with different bases is as follows:
///
///     # Have:
///     #   dual is a dual Tensor that is neither a forward nor a backward view
///     detached_dual = dual.detach()
///     view = detached_dual.view_as(dual)
///     # The forward base of view is dual
///     # The backward base of view is detached_dual
///
/// - Backward Mode View
/// Differentiable views are the view variables where you want gradients to flow
/// back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
/// variable may not require grad when we create the view, we still need to
/// track the view relation because future in-place ops may require
/// backpropagating through it. For example, we need to support
///
///   (1) in-place operation on view, e.g.,
///
///     # Have:
///     #   base.requires_grad = False
///     #   var.requires_grad = True
///     base[1] = var  # i.e., base[1].copy_(var)
///     torch.autograd.grad(base.sum(), var)  <- should return an all ones
///                                              tensor
///
///   (2) in-place operation on base after view is created, e.g.,
///
///     # Have:
///     #   base.requires_grad = False
///     #   var.requires_grad = True
///     view = base[1]
///     base.copy_(var)
///     torch.autograd.grad(view.sum(), var)  <- should return a tensor with
///                                              var[1] filled with all ones and
///                                              zeros everywhere else
///
/// - Forward Mode View
/// Forward differentiable views follow the same semantics as backward ones but
/// show up differently as they are computed along with the forward evaluation.
/// The hard examples above are thus very similar:
///
///   (1) in-place operation on view, e.g.,
///
///     # Have:
///     #   base is a regular Tensor
///     #   var is a dual Tensor whose tangent is all ones
///     base[1] = var  # i.e., base[1].copy_(var)
///     # Now, base is a dual Tensor
///     _, fw_grad = fwAD.unpack_dual(base)  <- fw_grad should be a tensor with
///                                             fw_grad[1] filled with all ones
///                                             and zeros everywhere else
///
///   (2) in-place operation on base after view is created, e.g.,
///
///     # Have:
///     #   base is a regular Tensor
///     #   var is a dual Tensor whose tangent is all ones
///     view = base[1]
///     base.copy_(var)
///     _, fw_grad = fwAD.unpack_dual(view)  <- fw_grad should be an all ones
///                                             tensor
///
/// See Note [Forward Grad View/inplace] for more details on how we handle these
/// hard cases.
///
///
/// DifferentiableViewMeta is created to support gradient tracking of
/// such **in-place** operations. In particular,
///   + if an in-place op is done on base, the grad_fn field of the view may
///     become stale. So accesses should always go through grad_fn(), which
///     reconstructs an updated grad_fn if the version_counter has incremented.
///     All other fields are always valid.
///   + if an in-place op is done on view, in rebase_history() of view, which is
///     called after every in-place op in VariableType.cpp, the grad_fn of base
///     is updated.
///   + if a single autograd Node returns multiple differentiable views, and any
///     output is modified by an in-place operation, the autograd engine will
///     rebuild an equivalent graph of plain view operations, where each output
///     is treated as if it were produced by a distinct view operation. This
///     discards the original (e.g., user provided) grad_fn. If the provided
///     grad_fn does more than the backward of the view, then the
///     DifferentiableViewMeta must be created with
///     creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the engine from
///     ignoring the provided grad_fn.
///
/// Interaction with GradMode:
/// The particular case that we consider here is:
///
///     # Have:
///     #   base.requires_grad = True or False
///     with torch.no_grad():
///         view = base[1]
///     base.requires_grad_()
///     view.copy_(var)
///     torch.autograd.grad(base.sum(), var)  <- what should it return?
///
/// Given that this particular code example is ambiguous and can easily be
/// replaced by either moving both statements inside the no_grad block or both
/// outside, we explicitly forbid it. For now, it is only deprecated with a
/// warning. This is achieved by setting
/// creation_meta=CreationMeta::NO_GRAD_MODE for all differentiable views
/// created in no_grad mode.
///
/// See Note [View + Inplace update for base tensor]
/// and Note [View + Inplace update for view tensor] for details on how
/// autograd handles in-place updates with view ops.
///
/// Non-Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// In certain cases, although function outputs share storage with inputs, they
/// will **never** require gradient history tracking. Instead of registering the
/// view relation via DifferentiableViewMeta in autograd, the views will use the
/// usual AutogradMeta and just share the version counters with the base
/// Variables.
/// Such views include:
///   1. Views created from .detach()
///   2. Views that are non-differentiable by their nature.
///      E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
///      floating point tensor.
///      See the top of `derivatives.yaml` for how to specify that outputs of a
///      function are non-differentiable.
/// These are called non-differentiable views as the gradients do not flow
/// through the view relation.
///
/// Relevant logic for both differentiable and non-differentiable views is
/// implemented in make_variable_(non_)differentiable_view below, and
/// wrap_output of gen_variable_type.py.

/// NOTE [ View + Inplace detection ]
///
/// We want to detect views followed by in-place operations, as they are often
/// forbidden to ensure correctness of the computed gradients. Since we only
/// want to notify the user when both happen, we tag the DifferentiableViewMeta
/// when the view is created via the `make_variable_*_view()` functions. This
/// tag is then checked by the `check_inplace()` function from
/// `VariableTypeUtils.h`, which should be called before every in-place
/// operation. To detect cases where other views are modified and this one is
/// rebased by side effect, we also check in `VariableHooks::grad_fn()`.
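// A minimal standalone sketch of why the shared version counter matters for
// the detection described above: every alias of a tensor (differentiable view
// or not) bumps the same counter on in-place mutation, so autograd can tell
// when a saved tensor or a view's grad_fn has gone stale. Illustrative only;
// `_version()` is the public accessor over the counter managed here.
#include <torch/torch.h>
#include <cassert>

inline void shared_version_counter_demo() {
  auto base = torch::zeros({4});
  auto view = base.slice(/*dim=*/0, /*start=*/0, /*end=*/2); // view of base
  auto detached = base.detach(); // shares storage and version counter

  const auto v0 = base._version();
  detached.add_(1); // in-place through an alias...
  assert(base._version() == v0 + 1); // ...bumps the shared counter
  assert(view._version() == base._version());
}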
/// Flag that gives more information about when this view was created:
/// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom
///   autograd Function and returned from it.
/// - NO_GRAD_MODE should be set when a view is created while GradMode is
///   disabled.
/// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code
///   returns multiple differentiable views.
/// - INFERENCE_MODE should be set when a view of a normal tensor is created in
///   InferenceMode.
/// - DEFAULT is for all other cases.
enum class CreationMeta : uint8_t {
  DEFAULT,
  IN_CUSTOM_FUNCTION,
  MULTI_OUTPUT_NODE,
  NO_GRAD_MODE,
  INFERENCE_MODE
};

/// Handles correctly propagating CreationMeta when a new view is created from a
/// previous view. In general, we don't want the new view to be _less_
/// restrictive than the previous view (it's okay to be _more_ restrictive). A
/// CreationMeta value of DEFAULT is currently the least restrictive, as the
/// behavior for all other CreationMeta values is to error out for in-place ops.
/// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so
/// it takes precedence in propagation. If this changes, the logic here will
/// need to be updated to properly handle the new semantics.
inline CreationMeta propagate_creation_meta(
    CreationMeta prev_view_creation_meta,
    CreationMeta new_view_creation_meta) {
  return (new_view_creation_meta == CreationMeta::DEFAULT)
      ? prev_view_creation_meta
      : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
             ? prev_view_creation_meta
             : new_view_creation_meta);
}
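// A quick standalone sketch of the propagation rule above: DEFAULT on the new
// view defers to the previous view's tag, INFERENCE_MODE on the previous view
// always wins, and otherwise the new (more restrictive) tag is kept.
#include <cassert>

inline void propagate_creation_meta_demo() {
  using torch::autograd::CreationMeta;
  using torch::autograd::propagate_creation_meta;
  assert(
      propagate_creation_meta(
          CreationMeta::NO_GRAD_MODE, CreationMeta::DEFAULT) ==
      CreationMeta::NO_GRAD_MODE);
  assert(
      propagate_creation_meta(
          CreationMeta::INFERENCE_MODE, CreationMeta::MULTI_OUTPUT_NODE) ==
      CreationMeta::INFERENCE_MODE);
  assert(
      propagate_creation_meta(
          CreationMeta::DEFAULT, CreationMeta::NO_GRAD_MODE) ==
      CreationMeta::NO_GRAD_MODE);
}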

/// Unified function to handle error checking when a rebase happens.
/// indirect=true means that the caller is not doing the in-place operation
/// itself; the in-place op happened somewhere else.
TORCH_API void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect = false);

struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
 private:
  /// Information about the views
  c10::optional<ViewInfo> backward_info_;
  c10::optional<ViewInfo> forward_info_;

  // Optimization to reduce the number of ViewInfo we create.
  // In the (very common) case where backward_info_ == forward_info_, we only
  // populate backward_info_ (that should be used as both the forward and
  // backward view information) and set shared_view_info_ = true. Invariants:
  //   - If shared_view_info_ is false, there are no special constraints on
  //     backward_info_ and forward_info_.
  //   - If shared_view_info_ is true, we must have:
  //     - backward_info_.has_value() == true
  //     - forward_info_.has_value() == false
  bool shared_view_info_;

  /// The following two fields are extra information that we track to ensure
  /// that any operation on this backward view is valid.

  /// The value of the version_counter at the time grad_fn was created. The
  /// grad_fn field is stale if attr_version_ !=
  /// version_counter.current_version().
  uint32_t attr_version_;
  CreationMeta creation_meta_;

 public:
  /// requires_grad is a backward AD field, so we only use the view-specific
  /// logic for backward differentiable views.
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_ ||
        (has_bw_view() && get_backward_view().base_.requires_grad());
  }

  bool shared_view_info() const {
    return shared_view_info_;
  }

  bool has_bw_view() const {
    return backward_info_.has_value();
  }

  const ViewInfo& get_backward_view() const {
    TORCH_CHECK(
        has_bw_view(), "backward view info can only exist for backward views.");
    return backward_info_.value();
  }

  uint32_t get_attr_version() const {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    return attr_version_;
  }

  void set_attr_version(uint32_t new_attr_version) {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    attr_version_ = new_attr_version;
  }

  CreationMeta get_creation_meta() const {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    return creation_meta_;
  }

  void set_creation_meta(CreationMeta new_creation_meta) {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    creation_meta_ = new_creation_meta;
  }

  bool has_fw_view() const {
    return shared_view_info_ || forward_info_.has_value();
  }

  const ViewInfo& get_forward_view() const {
    TORCH_CHECK(
        has_fw_view(), "forward view info can only exist for forward views.");
    TORCH_CHECK(
        !shared_view_info_ || has_bw_view(),
        "forward view info can only exist for forward views.");
    return shared_view_info_ ? backward_info_.value() : forward_info_.value();
  }

  DifferentiableViewMeta(
      at::TensorImpl* self_impl,
      c10::optional<ViewInfo> backward_info,
      c10::optional<ViewInfo> forward_info,
      bool shared_view_info,
      CreationMeta creation_meta = CreationMeta::DEFAULT);
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
/// `is_differentiable` is a bool that specifies whether this view is
/// differentiable, i.e., whether the relation should be tracked by autograd.
/// See NOTE [ Autograd View Variables ] for details.

/// NOTE: `allow_tensor_metadata_change` is set to true by default, because
/// there are a lot of call sites to these factory functions that need to change
/// the variable's size or storage afterwards, and they don't expect the
/// original tensor (where the variable is created from) to be updated. Setting
/// `allow_tensor_metadata_change_` to false by default would unnecessarily
/// prevent those changes from happening and is undesirable.

// See NOTE [ Autograd View Variables ] for details.
// Differentiable view. Track history with DifferentiableViewMeta.
inline Variable make_variable_differentiable_view(
    const at::Tensor& data,
    c10::optional<ViewInfo> backward_info,
    c10::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    TORCH_CHECK(
        data.getIntrusivePtr()->autograd_meta() == nullptr,
        "Attempted to make a tensor into a differentiable view, but the "
        "tensor already had autograd metadata associated with it. If you are "
        "using a __torch_dispatch__ mode, the most common cause for this "
        "problem is that you used torch.overrides.enable_reentrant_dispatch() "
        "improperly; tensors created within the extent of reentrant dispatch "
        "MUST NOT be directly returned from __torch_dispatch__; instead, they "
        "must be wrapped into fresh tensors that serve as the output. If you "
        "are not using wrappers, you probably don't need reentrant dispatch. "
        "If this doesn't seem applicable, please file a bug to PyTorch.");
    at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
        data_impl,
        std::move(backward_info),
        std::move(forward_info),
        shared_view_info,
        creation_meta));
    return data;
  }
  return Variable();
}

// See NOTE [ Autograd View Variables ] for details.
// Non-differentiable view. Just share the version counter.
inline Variable make_variable_non_differentiable_view(
    Variable base,
    const at::Tensor& data,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    // Currently all non-differentiable view ops (detach/_indices/_values)
    // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
    // allocation here is required.
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
        /*version_counter=*/impl::version_counter(base),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    data_impl_copy->set_autograd_meta(nullptr);
    return Variable(data_impl_copy);
  }
  return Variable();
}
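// A minimal standalone sketch (not reference usage) of what the differentiable
// view factory above produces, observed through the public at::Tensor API:
// `is_view()` and `_base()` are backed by the DifferentiableViewMeta that
// make_variable_differentiable_view attaches.
#include <torch/torch.h>
#include <cassert>

inline void differentiable_view_demo() {
  auto base = torch::rand({4, 4}, torch::requires_grad());
  auto v = base.view({16}); // dispatches through the differentiable view path
  assert(v.is_view()); // carries a (backward) DifferentiableViewMeta
  // The recorded backward base is the original variable.
  assert(v._base().unsafeGetTensorImpl() == base.unsafeGetTensorImpl());
}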
/// Creates a `Variable` from the given `Tensor`, copying its underlying
/// `TensorImpl`. `requires_grad` should be set only for leaves, and determines
/// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be
/// a `Variable` already. Its dynamic type *must* be `Tensor`.
///
/// TODO: Eliminate this function as much as possible, as it can be expressed
/// more clearly as detach() or a no-op in most call sites (especially when
/// there is only one use of the variable).
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    if (data.getIntrusivePtr().use_count() == 1 &&
        data.getIntrusivePtr()->unique_version()) {
      auto data_impl = data.unsafeReleaseIntrusivePtr();
      data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (requires_grad) {
        data_impl->set_autograd_meta(
            std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
      } else {
        data_impl->set_autograd_meta(nullptr);
      }
      return Variable(std::move(data_impl));
    } else {
      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
          /*version_counter=*/0,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (requires_grad) {
        data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
            data_impl_copy.get(), requires_grad));
      } else {
        data_impl_copy->set_autograd_meta(nullptr);
      }
      return Variable(data_impl_copy);
    }
  }
  return Variable();
}

/// Creates a `Variable` from the given `Tensor`, copying its underlying
/// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
/// specifying the function in the autograd graph, and which particular input
/// of that function this variable is connected to.
inline Variable make_variable(
    at::Tensor data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
        /*version_counter=*/0,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
        data_impl_copy.get(), false, std::move(gradient_edge)));
    return Variable(data_impl_copy);
  }
  return Variable();
}
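// A minimal standalone sketch (hypothetical call site) of the factory above:
// wrap a freshly created ATen tensor so that it carries AutogradMeta only
// when gradients are actually requested.
#include <ATen/ATen.h>

inline torch::autograd::Variable wrap_as_leaf(
    at::Tensor data,
    bool requires_grad) {
  // Steals `data` (or shallow-copies its TensorImpl if the handle is shared)
  // and attaches a fresh AutogradMeta when requires_grad is true.
  return torch::autograd::make_variable(std::move(data), requires_grad);
}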

struct VariableHooks final : at::impl::VariableHooksInterface {
  at::TensorBase tensor_data(const at::TensorBase&) const override;
  at::TensorBase variable_data(const at::TensorBase&) const override;
  const std::shared_ptr<torch::autograd::Node>& grad_fn(
      const at::TensorBase&) const override;
  unsigned _register_hook(
      const at::TensorBase&,
      std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
  void remove_hook(const at::TensorBase&, unsigned pos) const override;
  bool is_view(const at::TensorBase&) const override;
  const at::TensorBase& base(const at::TensorBase&) const override;
  const std::string& name(const at::TensorBase&) const override;
  bool is_leaf(const at::TensorBase&) const override;
  int64_t output_nr(const at::TensorBase&) const override;
  void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
      const override;
  at::TensorBase data(const at::TensorBase& self) const override;
  int64_t _version(const at::TensorBase& self) const override;
  void retain_grad(const at::TensorBase& self) const override;
  bool retains_grad(const at::TensorBase& self) const override;
  void _backward(
      const at::Tensor& self,
      at::TensorList inputs,
      const c10::optional<at::Tensor>& gradient,
      c10::optional<bool> keep_graph,
      bool create_graph) const override;
  void requires_grad_(const at::TensorBase& self, bool _requires_grad)
      const override;
  void basic_autograd_not_implemented_fallback(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet dispatch_keys,
      torch::jit::Stack* stack) const override;
};

namespace utils {

TORCH_API bool has_same_meta(const Variable& base, const Variable& other);

} // namespace utils
} // namespace autograd
} // namespace torch

#endif /* DOXYGEN_SHOULD_SKIP_THIS */
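// A minimal standalone sketch of the user-facing behavior that VariableHooks
// backs: Tensor::register_hook, backward, grad_fn, retain_grad, etc. all
// route through the interface declared above. Illustrative only.
#include <torch/torch.h>
#include <cassert>

inline void variable_hooks_demo() {
  auto t = torch::ones({2}, torch::requires_grad());
  auto y = (t * 2).sum();
  // Registered through VariableHooks::_register_hook; scales the incoming
  // gradient before it is accumulated into t.grad().
  t.register_hook([](torch::Tensor grad) { return grad * 0.5; });
  y.backward();
  // d(sum(2t))/dt == 2 everywhere, halved by the hook.
  assert(t.grad().allclose(torch::ones({2})));
}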
Generated by: LCOV version 2.0-1