Line data Source code
1 : #pragma once
2 :
3 : #include <torch/csrc/utils/python_stub.h>
4 :
5 : #include <torch/csrc/Export.h>
6 : #include <torch/csrc/autograd/cpp_hook.h>
7 : #include <torch/csrc/autograd/edge.h>
8 : #include <torch/csrc/autograd/forward_grad.h>
9 : #include <torch/csrc/autograd/function_hook.h>
10 :
11 : #include <ATen/NamedTensorUtils.h>
12 : #include <ATen/core/Tensor.h>
13 : #include <ATen/core/VariableHooksInterface.h>
14 : #include <c10/util/Exception.h>
15 :
16 : #include <cstdint>
17 : #include <memory>
18 : #include <mutex>
19 : #include <stdexcept>
20 : #include <string>
21 : #include <utility>
22 : #include <vector>
23 :
24 : namespace torch {
25 : namespace autograd {
26 :
27 : /// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
28 : /// at::Tensor`). This means you can perform all the usual mathematical and
29 : /// other operations you can perform on `Tensor`s also on `Variable`s.
30 : ///
31 : /// The only reason we are keeping the `Variable` class is backward
32 : /// compatibility with external users' legacy C++ frontend code. Our intention
33 : /// is to eliminate the `Variable` class in the near future.
34 : using Variable = at::Tensor;
35 :
36 : } // namespace autograd
37 : } // namespace torch
38 :
39 : // The following are all internal APIs and should not be shown in libtorch docs.
40 : // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
41 : // ... #endif`
42 :
43 : #ifndef DOXYGEN_SHOULD_SKIP_THIS
44 :
45 : namespace torch {
46 : namespace autograd {
47 :
48 : /// Check if this type is supported by the autograd engine.
49 : /// If you change this, update the doc at the top of the
50 : /// torch/autograd/__init__.py file and
51 : /// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py
52 0 : static inline bool isDifferentiableType(at::ScalarType t) {
53 0 : return isFloatingType(t) || isComplexType(t);
54 : }
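// For illustration only (a minimal sketch, assuming the usual ScalarType
// constants such as at::kFloat / at::kLong): only floating point and complex
// dtypes are differentiable, so only they may require gradients.
//
//   TORCH_INTERNAL_ASSERT(isDifferentiableType(at::kFloat));         // floating point
//   TORCH_INTERNAL_ASSERT(isDifferentiableType(at::kComplexDouble)); // complex
//   TORCH_INTERNAL_ASSERT(!isDifferentiableType(at::kLong));         // integral
//   TORCH_INTERNAL_ASSERT(!isDifferentiableType(at::kBool));         // boolean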
55 :
56 : struct Node;
57 :
58 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
59 : /// Variable
60 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61 : /// A `Variable` augments a `Tensor` with the ability to interact with our
62 : /// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
63 : /// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
64 : /// weight in a neural network, or an interior variable, when it is the result
65 : /// of an operation between variables. Every `Variable` also stores another
66 : /// `Variable` called its `grad` (gradient). If the variable is a leaf, its
67 : /// gradient will be accumulated into this variable.
68 : ///
69 : /// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
70 : /// that don't require gradients as Tensors (since none of the autograd
71 : /// machinery for Variables applies). Historically, Variables and Tensors
72 : /// were separate concepts, but now they are exactly the same (i.e. we have
73 : /// `using Variable = at::Tensor`).
74 : ///
75 : /// Gradient Edges
76 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77 : /// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
78 : /// edge in the autograd graph that connects the variable to a particular input
79 : /// of the gradient function that will be invoked with the variable during the
80 : /// backward pass. More precisely, this gradient function can be one of two
81 : /// things:
82 : /// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
83 : /// gradient of the function that produced the variable.
84 : /// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates
85 : ///    incoming gradient values into its `grad` variable.
86 : ///
87 : /// Versioning
88 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
89 : /// Another major feature of `Variable`s is *versioning*: versions are
90 : /// incremented when an in-place mutation of a variable occurs. Versions are
91 : /// useful when constructing `SavedVariable`s, which take a snapshot of a
92 : /// `Variable` at a certain version. You can retrieve a `Variable`'s version
93 : /// through its `current_version()` method.
94 : ///
95 : /// Views
96 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
97 : /// It is possible for a `Variable` to be a *view* of another `Variable`, in
98 : /// which case it tracks that `Variable`'s data and autograd history. Beyond
99 : /// construction, the interface of a view is identical to that of a regular
100 : /// `Variable`. You can determine whether a `Variable` is in fact a view by
101 : /// probing its `is_view()` method. Note that the *view* semantics are only
102 : /// meaningful for `Variable` relations that are relevant to autograd.
103 : /// See NOTE [ Autograd View Variables ] for more details.
104 : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
105 :
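// The commented sketch below is illustrative only; it assumes the ATen factory
// functions (at::ones) and the Tensor-level accessors surfaced further down in
// this file via VariableHooks (is_leaf(), grad_fn(), _version(), is_view()).
//
//   at::Tensor leaf = at::ones({2, 2});   // a leaf Variable
//   leaf.requires_grad_(true);            // gradients will accumulate into leaf.grad()
//   at::Tensor interior = leaf * 2;       // an interior Variable with a grad_fn
//   TORCH_INTERNAL_ASSERT(leaf.is_leaf() && !leaf.grad_fn());
//   TORCH_INTERNAL_ASSERT(!interior.is_leaf() && interior.grad_fn());
//
//   at::Tensor t = at::ones({2, 2});      // versions: bumped by in-place mutation
//   const auto v0 = t._version();
//   t.add_(1);
//   TORCH_INTERNAL_ASSERT(t._version() == v0 + 1);
//
//   TORCH_INTERNAL_ASSERT(interior.narrow(0, 0, 1).is_view()); // a view Variable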
106 : struct AutogradMeta;
107 : struct DifferentiableViewMeta;
108 :
109 : // Private-ish functions for manipulating variables; we don't want to put them
110 : // on Tensor proper
111 : namespace impl {
112 :
113 : // WARNING: This may return a nullptr. If you require a materialized
114 : // AutogradMeta, use materialize_autograd_meta instead.
115 : TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);
116 :
117 : // WARNING: This will return a nullptr if the Tensor is not a view.
118 : TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);
119 :
120 : // Returns the current autograd meta, materializing it if it was previously
121 : // none. This counts as a *mutating* operation, so do not call it on
122 : // "read-only" operators; in particular, this is NOT thread safe
123 : TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);
124 :
125 : /// Set the gradient accumulator of the `Variable`. This is only applicable to
126 : /// leaf variables. Interior variables should call `set_gradient_edge()`.
127 : TORCH_API void set_grad_accumulator(
128 : const Variable&,
129 : std::weak_ptr<Node> grad_accumulator);
130 :
131 : /// Attempts to get a pointer to the gradient accumulator of the `Variable`,
132 : /// if it still exists. If the gradient accumulator function has been
133 : /// destroyed, returns a `nullptr`.
134 : TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
135 :
136 : /// Gets the gradient accumulator of the `Variable` if it has one, or else
137 : /// creates one on the fly and returns it.
138 : TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);
139 :
140 : /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
141 : /// gradient function if this is an interior `Variable`, or the gradient
142 : /// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
143 : /// will store the input index of the `Node` to which this variable is
144 : /// connected in its `input_nr` field. For leaves, the `input_nr` is always
145 : /// zero. Note that `set_gradient_edge` and `gradient_edge` are not
146 : /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
147 : /// `set_grad_accumulator` to set the accumulator.
148 : TORCH_API Edge gradient_edge(const Variable&);
149 :
150 : /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
151 : /// `Variable`.
152 : /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
153 : /// and never the `grad_accumulator`. For the latter, use
154 : /// `set_grad_accumulator`. This allows late construction of an interior
155 : /// `Variable`.
156 : TORCH_API void set_gradient_edge(const Variable&, Edge edge);
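// A commented sketch of how the canonical gradient edge behaves (illustrative
// only; `Edge::function` and `Edge::input_nr` come from edge.h, and the tensor
// construction uses the usual ATen factories):
//
//   at::Tensor w = at::ones({3});
//   w.requires_grad_(true);
//   at::Tensor y = w * 2;
//
//   Edge leaf_edge = impl::gradient_edge(w);     // points at w's grad accumulator
//   Edge interior_edge = impl::gradient_edge(y); // points at y's grad_fn
//   TORCH_INTERNAL_ASSERT(leaf_edge.input_nr == 0);
//   TORCH_INTERNAL_ASSERT(interior_edge.function == y.grad_fn());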
157 :
158 : // Autograd Graph Interaction
159 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160 :
161 : /// Update the `grad_fn` of an existing Variable. Called after in-place
162 : /// modifications.
163 : ///
164 : /// For View Variables:
165 : /// Modifies the grad_fn of the base
166 : /// Variable.
167 : TORCH_API void rebase_history(const Variable&, Edge gradient_edge);
168 :
169 : /// Gets the raw gradient function pointer, whatever it currently is.
170 : TORCH_API Node* grad_fn_unsafe(const Variable&);
171 :
172 : /// Increments the version count of this `Variable`.
173 : TORCH_API void bump_version(const Variable&);
174 : TORCH_API void set_version_counter(
175 : const Variable&,
176 : const c10::VariableVersion& version_counter);
177 :
178 : /// Retrieves this `Variable`'s version counter.
179 : TORCH_API const c10::VariableVersion& version_counter(const Variable&);
180 :
181 : TORCH_API void set_name(const Variable&, const std::string& name);
182 :
183 : TORCH_API void add_hook(
184 : const at::TensorBase&,
185 : std::unique_ptr<FunctionPreHook> hook);
186 : TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
187 : TORCH_API void clear_hooks(const at::TensorBase&);
188 :
189 : TORCH_API void set_post_acc_grad_hooks(
190 : const at::TensorBase&,
191 : std::unique_ptr<PostAccumulateGradHook> dict);
192 : TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
193 : const Variable&);
194 :
195 : TORCH_API void create_cpp_hook(
196 : const at::TensorBase&,
197 : bool is_retains_grad_hooks = false);
198 : } // namespace impl
199 :
200 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201 : // AutogradMeta
202 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203 :
204 : /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
205 : /// metadata fields that are necessary for tracking the Variable's autograd
206 : /// history. As an optimization, a Variable may store a nullptr in lieu of a
207 : /// default-constructed AutogradMeta.
208 :
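// A commented sketch of the nullptr optimization, using the impl:: helpers
// declared above (illustrative only; assumes the at::ones factory):
//
//   at::Tensor t = at::ones({2});  // plain tensor, no autograd metadata yet
//   TORCH_INTERNAL_ASSERT(impl::get_autograd_meta(t) == nullptr);
//
//   // materialize_autograd_meta() creates the struct on demand; per the
//   // warning above this is a mutating, non-thread-safe operation.
//   AutogradMeta* meta = impl::materialize_autograd_meta(t);
//   TORCH_INTERNAL_ASSERT(meta != nullptr && impl::get_autograd_meta(t) == meta);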
209 : struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
210 : std::string name_;
211 :
212 : Variable grad_;
213 : std::shared_ptr<Node> grad_fn_;
214 : std::weak_ptr<Node> grad_accumulator_;
215 :
216 : // This field is used to store all the forward AD gradients
217 : // associated with this AutogradMeta (and the Tensor it corresponds to)
218 : // There is a semantic 1:1 correspondence between AutogradMeta and
219 : // ForwardGrad but:
220 : // - This field is lazily populated.
221 : // - This field is a shared_ptr but it must never be
222 : // shared by multiple Tensors. See Note [ Using ForwardGrad ]
223 : // Any transition from not_initialized to initialized
224 : // must be protected by mutex_
225 : std::shared_ptr<ForwardGrad> fw_grad_;
226 :
227 : // The hooks_ field is actually reused by both python and cpp logic
228 : // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
229 : // or dict (python) which is the canonical copy.
230 : // Then, for both cases, we always register a single hook to
231 : // hooks_ which wraps all the hooks in the list/dict.
232 : // And, again in both cases, if the grad_fn exists on that tensor
233 : // we will additionally register a single hook to the grad_fn.
234 : //
235 : // Note that the cpp and python use cases aren't actually aware of
236 : // each other, so using both together is undefined behavior.
237 : std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
238 : std::shared_ptr<hooks_list> cpp_hooks_list_;
239 :
240 : // The post_acc_grad_hooks_ field stores only Python hooks
241 : // (PyFunctionTensorPostAccGradHooks) that are called after the
242 : // .grad field has been accumulated into. This is less complicated
243 : // than the hooks_ field, which encapsulates a lot more.
244 : std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;
245 :
246 : // Only meaningful on leaf variables (must be false otherwise)
247 : bool requires_grad_{false};
248 :
249 : // Only meaningful on non-leaf variables (must be false otherwise)
250 : bool retains_grad_{false};
251 :
252 : bool is_view_{false};
253 :
254 : // The "output number" of this variable; e.g., if this variable
255 : // was the second output of a function, then output_nr == 1.
256 : // We use this to make sure we can set up the backwards trace
257 : // correctly when this variable is passed to another function.
258 : uint32_t output_nr_;
259 :
260 : // Mutex to ensure that concurrent read operations that modify internal
261 : // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
262 : // fw_grad() and set_fw_grad().
263 : // This is mutable because we need to be able to acquire this from a const
264 : // version of this class for the functions above.
265 : mutable std::mutex mutex_;
266 :
267 : /// Sets the `requires_grad` property of `Variable`. This should be true for
268 : /// leaf variables that want to accumulate gradients, and false for all other
269 : /// variables.
270 0 : void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl)
271 : override {
272 0 : TORCH_CHECK(
273 : !requires_grad ||
274 : isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
275 : "Only Tensors of floating point and complex dtype can require gradients");
276 0 : requires_grad_ = requires_grad;
277 0 : }
278 :
279 : bool requires_grad() const override {
280 : return requires_grad_ || grad_fn_;
281 : }
282 :
283 : /// Accesses the gradient `Variable` of this `Variable`.
284 : Variable& mutable_grad() override {
285 : return grad_;
286 : }
287 :
288 : const Variable& grad() const override {
289 : return grad_;
290 : }
291 :
292 : const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
293 : const override;
294 :
295 : void set_fw_grad(
296 : const at::TensorBase& new_grad,
297 : const at::TensorBase& self,
298 : uint64_t level,
299 : bool is_inplace_op) override;
300 :
301 0 : AutogradMeta(
302 : at::TensorImpl* self_impl = nullptr,
303 : bool requires_grad = false,
304 : Edge gradient_edge = Edge())
305 0 : : grad_fn_(std::move(gradient_edge.function)),
306 :
307 0 : output_nr_(gradient_edge.input_nr) {
308 : // set_requires_grad also checks error conditions.
309 0 : if (requires_grad) {
310 0 : TORCH_INTERNAL_ASSERT(self_impl);
311 : // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
312 0 : set_requires_grad(requires_grad, self_impl);
313 : }
314 0 : TORCH_CHECK(
315 : !grad_fn_ || !requires_grad_,
316 : "requires_grad should be false if grad_fn is set");
317 0 : }
318 :
319 : ~AutogradMeta() override {
320 : // If AutogradMeta is being destroyed, it means that there is no other
321 : // reference to its corresponding Tensor. It implies that no other thread
322 : // can be using this object and so there is no need to lock mutex_ here to
323 : // guard the check if fw_grad_ is populated.
324 : if (fw_grad_) {
325 : // See note [ Using ForwardGrad ]
326 : fw_grad_->clear();
327 : }
328 : }
329 : };
330 :
331 : struct TORCH_API ViewInfo {
332 : /// The base `Variable`
333 : /// If this ViewInfo represents a forward (respectively backward) AD gradient,
334 : /// then this Tensor cannot be a forward (respectively backward) view.
335 : Variable base_;
336 :
337 : /// By default we use as_strided to recover views, which is more efficient;
338 : /// view_fn_ is only saved when as_strided is not supported.
339 : /// If view_fn_ has a value, we use it to recover views in backward.
340 : std::function<Variable(const Variable&)> view_fn_;
341 :
342 : /// Accessors for the view function
343 : bool has_view_fn() const {
344 : return view_fn_ != nullptr;
345 : }
346 :
347 : std::function<Variable(const Variable&)> view_fn() const {
348 : TORCH_CHECK(
349 : has_view_fn(), "Can only access the view function if it exists.");
350 : return view_fn_;
351 : }
352 :
353 : /// The chain function can be used to build a new ViewInfo for a
354 : /// differentiable view function. It will return a new view info that
355 : /// accurately represents how "tensor" is a view of this instance's "base_".
356 : /// The "base" and "tensor" are respectively the input and output of the
357 : /// differentiable view function that happened. They are required to properly
358 : /// set the optional view_fn_ when it is not provided. The "view_func", if
359 : /// provided, should be a function that allows to re-do the view between
360 : /// "base" and "tensor".
361 : ViewInfo chain(
362 : const Variable& base,
363 : const Variable& tensor,
364 : std::function<Variable(const Variable&)> view_func = nullptr) const;
365 :
366 : ViewInfo(Variable base, std::function<Variable(const Variable&)> view_fn)
367 : : base_(std::move(base)), view_fn_(std::move(view_fn)) {
368 : TORCH_CHECK(base_.defined(), "base is undefined");
369 : }
370 : };
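// A commented sketch of building a ViewInfo by hand with an explicit view_fn
// (illustrative only; in practice this is done by generated autograd code, and
// the narrow() lambda is just an arbitrary view function for illustration):
//
//   at::Tensor base = at::ones({4, 4});
//   ViewInfo info(base, [](const Variable& b) { return b.narrow(0, 0, 2); });
//   TORCH_INTERNAL_ASSERT(info.has_view_fn());
//   at::Tensor recovered = info.view_fn()(base); // re-plays the view on the base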
371 :
372 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
373 : // DifferentiableViewMeta
374 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
375 :
376 : /// NOTE [ Autograd View Variables ]
377 : ///
378 : /// Many operations return a Variable that shares storage with an input Variable.
379 : /// The returned Variable is called a **view** Variable on the input **base**
380 : /// Variable.
381 : ///
382 : /// In PyTorch, we have two types of views: differentiable views, and
383 : /// non-differentiable views. In either type, to support proper version
384 : /// checking, the base and view Variables must always share the same
385 : /// version_counter.
386 : ///
387 : ///
388 : /// Differentiable Views
389 : /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
390 : /// This class allows tracking both forward and backward AD differentiable
391 : /// views. These views can have different bases, as the set of ops that are
392 : /// non-differentiable views is not the same for forward and backward mode AD.
393 : ///
394 : /// Most functions are either both forward and backward differentiable views
395 : /// (for example: view, select, narrow, transpose, etc.) or neither forward
396 : /// nor backward differentiable views (for example: indices, values, eq, lt,
397 : /// etc.). But there are also functions that are forward but not backward
398 : /// differentiable views (only detach for now) and functions that are
399 : /// backward but not forward differentiable views (only make_dual and
400 : /// unpack_dual for now).
401 : ///
402 : /// A concrete example of two views with different bases is as follows:
403 : ///
404 : /// # Have:
405 : /// # dual is a dual Tensor that is neither a forward nor a backward view
406 : /// detached_dual = dual.detach()
407 : /// view = detached_dual.view_as(dual)
408 : /// # The forward base of view is dual
409 : /// # The backward base of view is detached_dual
410 : ///
411 : /// - Backward Mode View
412 : /// Differentiable views are the view variables where you want gradients to flow
413 : /// back to the base variables. Out-of-place operations on views are quite
414 : /// straightforward, but in-place ones are very tricky. Even if the base
415 : /// variable may not require grad when we create the view, we still need to
416 : /// track the view relation because future in-place ops may require
417 : /// back-propagating through it. For example, we need to support
418 : ///
419 : /// (1) in-place operation on view, e.g.,
420 : ///
421 : /// # Have:
422 : /// # base.requires_grad = False
423 : /// # var.requires_grad = True
424 : /// base[1] = var # i.e., base[1].copy_(var)
425 : /// torch.autograd.grad(base.sum(), var) <- should return an all ones
426 : /// tensor
427 : ///
428 : /// (2) in-place operation on base after view is created, e.g.,
429 : ///
430 : /// # Have:
431 : /// # base.requires_grad = False
432 : /// # var.requires_grad = True
433 : /// view = base[1]
434 : /// base.copy_(var)
435 : /// torch.autograd.grad(view.sum(), var) <- should return a tensor with
436 : /// var[1] filled with all ones and
437 : /// zeros everywhere else
438 : ///
439 : /// - Forward Mode View
440 : /// Forward differentiable views follow the same semantics as backward ones but
441 : /// show up differently as they are computed along with the forward evaluation.
442 : /// The hard examples above are thus very similar:
443 : ///
444 : /// (1) in-place operation on view, e.g.,
445 : ///
446 : /// # Have:
447 : /// # base is a regular Tensor
448 : /// # var is a dual Tensor whose tangent is all ones
449 : /// base[1] = var # i.e., base[1].copy_(var)
450 : /// # Now, base is a dual Tensor
451 : /// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with
452 : /// fw_grad[1] filled with all ones
453 : /// and zeros everywhere else
454 : ///
455 : /// (2) in-place operation on base after view is created, e.g.,
456 : ///
457 : /// # Have:
458 : /// # base is a regular Tensor
459 : /// # var is a dual Tensor whose tangent is all ones
460 : /// view = base[1]
461 : /// base.copy_(var)
462 : /// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones
463 : /// tensor
464 : ///
465 : /// See Note [Forward Grad View/inplace] for more details on how we handle these
466 : /// hard cases.
467 : ///
468 : ///
469 : /// DifferentiableViewMeta is created to support gradient tracking of
470 : /// such **in-place** operations. In particular,
471 : /// + if an in-place op is done on base, the grad_fn field of the view may
472 : /// become stale. So accesses should always go through grad_fn(), which
473 : /// reconstructs an updated grad_fn if the version_counter has incremented.
474 : /// All other fields are always valid.
475 : /// + if an in-place op is done on view, in rebase_history() of view, which is
476 : /// called after every in-place op in VariableType.cpp, the grad_fn of base
477 : /// is updated.
478 : /// + if a single autograd Node returns multiple differentiable views, if any
479 : ///   output is modified by an inplace operation, the autograd engine will
480 : ///   replace the original grad_fn with an equivalent graph (corresponding to
481 : ///   the view operations), where each output is treated as if it were
482 : ///   produced by a distinct view operation. This discards the original (e.g.,
483 : ///   user provided) grad_fn. If the provided grad_fn does more than the
484 : ///   backward of the view, then the DifferentiableViewMeta must be created
485 : ///   with creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the
486 : ///   engine from ignoring the provided grad_fn.
487 : ///
488 : /// Interaction with GradMode:
489 : /// The particular case that we consider here is:
490 : ///
491 : /// # Have:
492 : /// # base.requires_grad = True or False
493 : /// with torch.no_grad():
494 : /// view = base[1]
495 : /// base.requires_grad_()
496 : /// view.copy_(var)
497 : /// torch.autograd.grad(base.sum(), var) <- what should it return?
498 : ///
499 : /// Given that this particular code example is ambiguous and can easily be
500 : /// replaced by either moving both inside the no_grad block or both outside, we
501 : /// explicitly forbid it. For now, it only raises a deprecation warning. This is
502 : /// achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE for all
503 : /// differentiable views created in no_grad mode.
504 : ///
505 : /// See Note [View + Inplace update for base tensor]
506 : /// and Note [View + Inplace update for view tensor] for the details on how
507 : /// autograd handles inplace updates with view ops.
508 : ///
509 : /// Non-Differentiable Views
510 : /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
511 : /// In certain cases, although function outputs share storage with inputs, they
512 : /// will **never** require gradient history tracking. Instead of registering the
513 : /// view relation via DifferentiableViewMeta in autograd, the views will use
514 : /// the usual AutogradMeta and just share the version counters with the base
515 : /// Variables.
516 : /// Such views include:
517 : /// 1. Views created from .detach()
518 : /// 2. Views that are non-differentiable by their nature.
519 : ///    E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
520 : /// floating point tensor.
521 : /// See top of `derivatives.yaml` on how to specify that outputs of a
522 : /// function are non-differentiable.
523 : /// These are called non-differentiable views as the gradients do not flow
524 : /// through the view relation.
525 : ///
526 : /// Relevant logic for both differentiable and non-differentiable views is
527 : /// implemented in make_variable_(non_)differentiable_view below, and
528 : /// wrap_output of gen_variable_type.py.
529 :
530 : /// NOTE [ View + Inplace detection ]
531 : ///
532 : /// We want to detect a view followed by an inplace op, as this is often
533 : /// forbidden to ensure correctness of the computed gradients. But since we
534 : /// want to notify the user only when both happen, we tag the
535 : /// DifferentiableViewMeta when the view is created via the
536 : /// `make_variable_*_view()` functions. This tag is then checked by the
537 : /// `check_inplace()` function from `VariableTypeUtils.h`, which should be
538 : /// called before every inplace operation. To detect cases where other views
539 : /// are modified and this one is rebased by side effect, we also check in `VariableHooks::grad_fn()`.
540 :
541 : /// Flag that gives more information about when this view was created:
542 : /// - IN_CUSTOM_FUNCTION should be set when the view is created inside a
543 : ///   custom autograd Function and is returned from it.
544 : /// - NO_GRAD_MODE should be set when a view is created while GradMode is
545 : ///   disabled.
546 : /// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code
547 : ///   returns multiple differentiable views.
548 : /// - INFERENCE_MODE should be set when a view of a normal tensor is created
549 : ///   in InferenceMode.
550 : /// - DEFAULT is for all other cases.
552 : enum class CreationMeta : uint8_t {
553 : DEFAULT,
554 : IN_CUSTOM_FUNCTION,
555 : MULTI_OUTPUT_NODE,
556 : NO_GRAD_MODE,
557 : INFERENCE_MODE
558 : };
559 :
560 : /// Handles correctly propagating CreationMeta when a new view is created from a
561 : /// previous view. In general, we don't want the new view to be _less_
562 : /// restrictive than the previous view (it's okay to be _more_ restrictive). A
563 : /// CreationMeta value of DEFAULT is currently the least restrictive, as the
564 : /// behavior for all other CreationMeta values is to error out for in-place ops.
565 : /// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so
566 : /// it takes precedence in propagation. If this changes, the logic here will
567 : /// need to be updated to properly handle the new semantics.
568 : inline CreationMeta propagate_creation_meta(
569 : CreationMeta prev_view_creation_meta,
570 : CreationMeta new_view_creation_meta) {
571 : return (new_view_creation_meta == CreationMeta::DEFAULT)
572 : ? prev_view_creation_meta
573 : : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
574 : ? prev_view_creation_meta
575 : : new_view_creation_meta);
576 : }
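// For example (values chosen purely for illustration):
//
//   // Taking a DEFAULT view of a NO_GRAD_MODE view keeps the stricter tag.
//   TORCH_INTERNAL_ASSERT(
//       propagate_creation_meta(
//           CreationMeta::NO_GRAD_MODE, CreationMeta::DEFAULT) ==
//       CreationMeta::NO_GRAD_MODE);
//
//   // INFERENCE_MODE is the most restrictive and always takes precedence.
//   TORCH_INTERNAL_ASSERT(
//       propagate_creation_meta(
//           CreationMeta::INFERENCE_MODE, CreationMeta::MULTI_OUTPUT_NODE) ==
//       CreationMeta::INFERENCE_MODE);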
577 :
578 : /// Unified function to handle error checking when rebase happens
579 : /// indirect=true means that the caller is not doing the inplace, but the
580 : /// inplace happened somewhere else.
581 : TORCH_API void handle_view_on_rebase(
582 : DifferentiableViewMeta* diff_view_meta,
583 : bool indirect = false);
584 :
585 : struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
586 : private:
587 : /// Information about the views
588 : c10::optional<ViewInfo> backward_info_;
589 : c10::optional<ViewInfo> forward_info_;
590 :
591 : // Optimization to reduce the number of ViewInfo we create.
592 : // In the (very common) case where backward_info_ == forward_info_, we only
593 : // populate backward_info_ (that should be used as both the forward and
594 : // backward view information) and set shared_view_info_ = true. Invariants:
595 : // - If shared_view_info_ is false, there are no special constraints on
596 : // backward_info_ and forward_info_
597 : // - If shared_view_info_ is true, we must have:
598 : // - backward_info_.has_value() == true
599 : // - forward_info_.has_value() == false
600 : bool shared_view_info_;
601 :
602 : /// The following two fields are extra information that we track to ensure
603 : /// that any operation on this backward view is valid.
604 :
605 : /// The value of the version_counter at the time grad_fn was created. The
606 : /// grad_fn field is stale if attr_version_ !=
607 : /// version_counter.current_version().
608 : uint32_t attr_version_;
609 : CreationMeta creation_meta_;
610 :
611 : public:
612 : /// requires_grad is a backward AD field, so we only use the view-specific
613 : /// logic for backward differentiable views
614 : bool requires_grad() const override {
615 : return requires_grad_ || grad_fn_ ||
616 : (has_bw_view() && get_backward_view().base_.requires_grad());
617 : }
618 :
619 : bool shared_view_info() const {
620 : return shared_view_info_;
621 : }
622 :
623 : bool has_bw_view() const {
624 : return backward_info_.has_value();
625 : }
626 :
627 : const ViewInfo& get_backward_view() const {
628 : TORCH_CHECK(
629 : has_bw_view(), "backward view info can only exist for backward views.");
630 : return backward_info_.value();
631 : }
632 :
633 : uint32_t get_attr_version() const {
634 : TORCH_CHECK(
635 : has_bw_view(), "attr_version can only exist for backward views.");
636 : return attr_version_;
637 : }
638 :
639 : void set_attr_version(uint32_t new_attr_version) {
640 : TORCH_CHECK(
641 : has_bw_view(), "attr_version can only exist for backward views.");
642 : attr_version_ = new_attr_version;
643 : }
644 :
645 : CreationMeta get_creation_meta() const {
646 : TORCH_CHECK(
647 : has_bw_view(), "creation_meta can only exist for backward views.");
648 : return creation_meta_;
649 : }
650 :
651 : void set_creation_meta(CreationMeta new_creation_meta) {
652 : TORCH_CHECK(
653 : has_bw_view(), "creation_meta can only exist for backward views.");
654 : creation_meta_ = new_creation_meta;
655 : }
656 :
657 : bool has_fw_view() const {
658 : return shared_view_info_ || forward_info_.has_value();
659 : }
660 :
661 : const ViewInfo& get_forward_view() const {
662 : TORCH_CHECK(
663 : has_fw_view(), "forward view info can only exist for forward views.");
664 : TORCH_CHECK(
665 : !shared_view_info_ || has_bw_view(),
666 : "forward view info can only exist for forward views.");
667 : return shared_view_info_ ? backward_info_.value() : forward_info_.value();
668 : }
669 :
670 : DifferentiableViewMeta(
671 : at::TensorImpl* self_impl,
672 : c10::optional<ViewInfo> backward_info,
673 : c10::optional<ViewInfo> forward_info,
674 : bool shared_view_info,
675 : CreationMeta creation_meta = CreationMeta::DEFAULT);
676 : };
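// A commented sketch of inspecting this metadata through the impl:: helpers
// declared near the top of this file (illustrative only; assumes the usual
// ATen ops):
//
//   at::Tensor base = at::ones({4});
//   base.requires_grad_(true);
//   at::Tensor view = base.narrow(0, 0, 2); // backward differentiable view
//
//   TORCH_INTERNAL_ASSERT(impl::get_view_autograd_meta(base) == nullptr);
//   DifferentiableViewMeta* meta = impl::get_view_autograd_meta(view);
//   TORCH_INTERNAL_ASSERT(meta && meta->has_bw_view());
//   const ViewInfo& info = meta->get_backward_view(); // info.base_ is `base`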
677 :
678 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
679 : // Variable Implementation
680 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
681 :
682 : // Factory Functions
683 : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
684 :
685 : /// Creates a `Variable` that is a *view* of another (*base*) variable.
686 : /// The differentiable variant below tracks the view relation with a
687 : /// DifferentiableViewMeta (so gradients can flow back to the base), while the
688 : /// non-differentiable variant only shares the version counter with the base.
689 : /// See NOTE [ Autograd View Variables ] for details.
690 :
691 : /// NOTE: `allow_tensor_metadata_change` is set to true by default, because
692 : /// there are a lot of call sites to these factory functions that need to change
693 : /// the variable's size or storage afterwards, and they don't expect the
694 : /// original tensor (where the variable is created from) to be updated. Setting
695 : /// `allow_tensor_metadata_change_` to false by default would unnecessarily
696 : /// prevent those changes from happening and is undesirable.
697 :
698 : // See NOTE [ Autograd View Variables ] for details.
699 : // Differentiable view. Track history with DifferentiableViewMeta.
700 : inline Variable make_variable_differentiable_view(
701 : const at::Tensor& data,
702 : c10::optional<ViewInfo> backward_info,
703 : c10::optional<ViewInfo> forward_info,
704 : bool shared_view_info,
705 : CreationMeta creation_meta,
706 : bool allow_tensor_metadata_change = true) {
707 : if (data.defined()) {
708 : TORCH_CHECK(
709 : data.getIntrusivePtr()->autograd_meta() == nullptr,
710 : "Attempted to make a tensor into a differentiable view, but the "
711 : "tensor already had autograd metadata associated with it. If you are "
712 : "using a __torch_dispatch__ mode, the most common cause for this "
713 : "problem is that you used torch.overrides.enable_reentrant_dispatch() "
714 : "improperly; tensors created within the extent of reentrant dispatch "
715 : "MUST NOT be directly returned from __torch_dispatch__; instead, they "
716 : "must be wrapped into fresh tensors that serve as the output. If you "
717 : "are not using wrappers, you probably don't need reentrant dispatch. "
718 : "If this doesn't seem applicable, please file a bug to PyTorch.");
719 : at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
720 : data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
721 : data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
722 : data_impl,
723 : std::move(backward_info),
724 : std::move(forward_info),
725 : shared_view_info,
726 : creation_meta));
727 : return data;
728 : }
729 : return Variable();
730 : }
731 :
732 : // See NOTE [ Autograd View Variables ] for details.
733 : // Non-differentiable view. Just share version counter.
734 : inline Variable make_variable_non_differentiable_view(
735 : Variable base,
736 : const at::Tensor& data,
737 : bool allow_tensor_metadata_change = true) {
738 : if (data.defined()) {
739 : // Currently all non-differentiable view ops (detach/_indices/_values)
740 : // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
741 : // allocation is required here.
742 : auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
743 : /*version_counter=*/impl::version_counter(base),
744 : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
745 : data_impl_copy->set_autograd_meta(nullptr);
746 : return Variable(data_impl_copy);
747 : }
748 : return Variable();
749 : }
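// A commented sketch of the "just share the version counter" behavior, using
// detach(), which the note above lists among the non-differentiable views, and
// the public _version() accessor (illustrative only):
//
//   at::Tensor base = at::ones({2, 2});
//   at::Tensor detached = base.detach(); // non-differentiable (backward) view
//
//   const auto v0 = detached._version();
//   base.add_(1);                        // in-place op on the base...
//   TORCH_INTERNAL_ASSERT(detached._version() == v0 + 1); // ...is visible here
//   TORCH_INTERNAL_ASSERT(!detached.requires_grad());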
750 :
751 : /// Creates a `Variable` from the given `Tensor`, copying its underlying
752 : /// `TensorImpl`. `requires_grad` should be set only for leaves, and determines
753 : /// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be
754 : /// a `Variable` already. Its dynamic type *must* be `Tensor`.
755 : ///
756 : /// TODO: Eliminate this function as much as possible, as it can be expressed
757 : /// more clearly as detach() or a no-op in most call sites (especially when
758 : /// there is only one use of the variable).
759 418844 : inline Variable make_variable(
760 : at::Tensor data,
761 : bool requires_grad = false,
762 : bool allow_tensor_metadata_change = true) {
763 418844 : if (data.defined()) {
764 458920 : if (data.getIntrusivePtr().use_count() == 1 &&
765 40076 : data.getIntrusivePtr()->unique_version()) {
766 40076 : auto data_impl = data.unsafeReleaseIntrusivePtr();
767 40076 : data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
768 : // NOLINTNEXTLINE(bugprone-branch-clone)
769 40076 : if (requires_grad) {
770 0 : data_impl->set_autograd_meta(
771 0 : std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
772 : } else {
773 40076 : data_impl->set_autograd_meta(nullptr);
774 : }
775 40076 : return Variable(std::move(data_impl));
776 40076 : } else {
777 378768 : auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
778 : /*version_counter=*/0,
779 378768 : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
780 : // NOLINTNEXTLINE(bugprone-branch-clone)
781 378768 : if (requires_grad) {
782 0 : data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
783 0 : data_impl_copy.get(), requires_grad));
784 : } else {
785 378768 : data_impl_copy->set_autograd_meta(nullptr);
786 : }
787 378768 : return Variable(data_impl_copy);
788 378768 : }
789 : }
790 0 : return Variable();
791 : }
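// A commented usage sketch (internal API, shown for illustration only; the
// tensors come from the usual ATen factories):
//
//   Variable v = make_variable(at::ones({2, 3}), /*requires_grad=*/true);
//   TORCH_INTERNAL_ASSERT(v.requires_grad() && v.is_leaf());
//
//   Variable w = make_variable(at::zeros({2, 3})); // requires_grad defaults to false
//   TORCH_INTERNAL_ASSERT(!w.requires_grad() && impl::get_autograd_meta(w) == nullptr);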
792 :
793 : /// Creates a `Variable` from the given `Tensor`, copying its underlying
794 : /// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
795 : /// specifying the function in the autograd graph, and what particular input of
796 : /// that function, this variable is connected to.
797 : inline Variable make_variable(
798 : at::Tensor data,
799 : Edge gradient_edge,
800 : bool allow_tensor_metadata_change = true) {
801 : if (data.defined()) {
802 : auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
803 : /*version_counter=*/0,
804 : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
805 : data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
806 : data_impl_copy.get(), false, std::move(gradient_edge)));
807 : return Variable(data_impl_copy);
808 : }
809 : return Variable();
810 : }
811 :
812 : struct VariableHooks final : at::impl::VariableHooksInterface {
813 : at::TensorBase tensor_data(const at::TensorBase&) const override;
814 : at::TensorBase variable_data(const at::TensorBase&) const override;
815 : const std::shared_ptr<torch::autograd::Node>& grad_fn(
816 : const at::TensorBase&) const override;
817 : unsigned _register_hook(
818 : const at::TensorBase&,
819 : std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
820 : void remove_hook(const at::TensorBase&, unsigned pos) const override;
821 : bool is_view(const at::TensorBase&) const override;
822 : const at::TensorBase& base(const at::TensorBase&) const override;
823 : const std::string& name(const at::TensorBase&) const override;
824 : bool is_leaf(const at::TensorBase&) const override;
825 : int64_t output_nr(const at::TensorBase&) const override;
826 : void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
827 : const override;
828 : at::TensorBase data(const at::TensorBase& self) const override;
829 : int64_t _version(const at::TensorBase& self) const override;
830 : void retain_grad(const at::TensorBase& self) const override;
831 : bool retains_grad(const at::TensorBase& self) const override;
832 : void _backward(
833 : const at::Tensor& self,
834 : at::TensorList inputs,
835 : const c10::optional<at::Tensor>& gradient,
836 : c10::optional<bool> keep_graph,
837 : bool create_graph) const override;
838 : void requires_grad_(const at::TensorBase& self, bool _requires_grad)
839 : const override;
840 : void basic_autograd_not_implemented_fallback(
841 : const c10::OperatorHandle& op,
842 : c10::DispatchKeySet dispatch_keys,
843 : torch::jit::Stack* stack) const override;
844 : };
845 :
846 : namespace utils {
847 :
848 : TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
849 :
850 : } // namespace utils
851 : } // namespace autograd
852 : } // namespace torch
853 :
854 : #endif /* DOXYGEN_SHOULD_SKIP_THIS */