Enhance test coverage and report output

2024-04-30 14:00:24 +02:00
parent b4a222b100
commit 3c7382a93a
947 changed files with 376596 additions and 3921 deletions


@@ -0,0 +1,930 @@
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
<html lang="en">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<title>LCOV - coverage.info - libtorch/include/torch/csrc/autograd/variable.h</title>
<link rel="stylesheet" type="text/css" href="../../../../../gcov.css">
</head>
<body>
<table width="100%" border=0 cellspacing=0 cellpadding=0>
<tr><td class="title">LCOV - code coverage report</td></tr>
<tr><td class="ruler"><img src="../../../../../glass.png" width=3 height=3 alt=""></td></tr>
<tr>
<td width="100%">
<table cellpadding=1 border=0 width="100%">
<tr>
<td width="10%" class="headerItem">Current view:</td>
<td width="10%" class="headerValue"><a href="../../../../../index.html">top level</a> - <a href="index.html">libtorch/include/torch/csrc/autograd</a> - variable.h<span style="font-size: 80%;"> (source / <a href="variable.h.func-c.html">functions</a>)</span></td>
<td width="5%"></td>
<td width="5%"></td>
<td width="5%" class="headerCovTableHead">Coverage</td>
<td width="5%" class="headerCovTableHead" title="Covered + Uncovered code">Total</td>
<td width="5%" class="headerCovTableHead" title="Exercised code only">Hit</td>
</tr>
<tr>
<td class="headerItem">Test:</td>
<td class="headerValue">coverage.info</td>
<td></td>
<td class="headerItem">Lines:</td>
<td class="headerCovTableEntryLo">45.7&nbsp;%</td>
<td class="headerCovTableEntry">35</td>
<td class="headerCovTableEntry">16</td>
</tr>
<tr>
<td class="headerItem">Test Date:</td>
<td class="headerValue">2024-04-30 13:17:26</td>
<td></td>
<td class="headerItem">Functions:</td>
<td class="headerCovTableEntryLo">25.0&nbsp;%</td>
<td class="headerCovTableEntry">4</td>
<td class="headerCovTableEntry">1</td>
</tr>
<tr><td><img src="../../../../../glass.png" width=3 height=3 alt=""></td></tr>
</table>
</td>
</tr>
<tr><td class="ruler"><img src="../../../../../glass.png" width=3 height=3 alt=""></td></tr>
</table>
<table cellpadding=0 cellspacing=0 border=0>
<tr>
<td><br></td>
</tr>
<tr>
<td>
<pre class="sourceHeading"> Line data Source code</pre>
<pre class="source">
<span id="L1"><span class="lineNum"> 1</span> : #pragma once</span>
<span id="L2"><span class="lineNum"> 2</span> : </span>
<span id="L3"><span class="lineNum"> 3</span> : #include &lt;torch/csrc/utils/python_stub.h&gt;</span>
<span id="L4"><span class="lineNum"> 4</span> : </span>
<span id="L5"><span class="lineNum"> 5</span> : #include &lt;torch/csrc/Export.h&gt;</span>
<span id="L6"><span class="lineNum"> 6</span> : #include &lt;torch/csrc/autograd/cpp_hook.h&gt;</span>
<span id="L7"><span class="lineNum"> 7</span> : #include &lt;torch/csrc/autograd/edge.h&gt;</span>
<span id="L8"><span class="lineNum"> 8</span> : #include &lt;torch/csrc/autograd/forward_grad.h&gt;</span>
<span id="L9"><span class="lineNum"> 9</span> : #include &lt;torch/csrc/autograd/function_hook.h&gt;</span>
<span id="L10"><span class="lineNum"> 10</span> : </span>
<span id="L11"><span class="lineNum"> 11</span> : #include &lt;ATen/NamedTensorUtils.h&gt;</span>
<span id="L12"><span class="lineNum"> 12</span> : #include &lt;ATen/core/Tensor.h&gt;</span>
<span id="L13"><span class="lineNum"> 13</span> : #include &lt;ATen/core/VariableHooksInterface.h&gt;</span>
<span id="L14"><span class="lineNum"> 14</span> : #include &lt;c10/util/Exception.h&gt;</span>
<span id="L15"><span class="lineNum"> 15</span> : </span>
<span id="L16"><span class="lineNum"> 16</span> : #include &lt;cstdint&gt;</span>
<span id="L17"><span class="lineNum"> 17</span> : #include &lt;memory&gt;</span>
<span id="L18"><span class="lineNum"> 18</span> : #include &lt;mutex&gt;</span>
<span id="L19"><span class="lineNum"> 19</span> : #include &lt;stdexcept&gt;</span>
<span id="L20"><span class="lineNum"> 20</span> : #include &lt;string&gt;</span>
<span id="L21"><span class="lineNum"> 21</span> : #include &lt;utility&gt;</span>
<span id="L22"><span class="lineNum"> 22</span> : #include &lt;vector&gt;</span>
<span id="L23"><span class="lineNum"> 23</span> : </span>
<span id="L24"><span class="lineNum"> 24</span> : namespace torch {</span>
<span id="L25"><span class="lineNum"> 25</span> : namespace autograd {</span>
<span id="L26"><span class="lineNum"> 26</span> : </span>
<span id="L27"><span class="lineNum"> 27</span> : /// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =</span>
<span id="L28"><span class="lineNum"> 28</span> : /// at::Tensor`). This means you can perform all the usual mathematical and</span>
<span id="L29"><span class="lineNum"> 29</span> : /// other operations you can perform on `Tensor`s also on `Variable`s.</span>
<span id="L30"><span class="lineNum"> 30</span> : ///</span>
<span id="L31"><span class="lineNum"> 31</span> : /// The only reason we are keeping the `Variable` class is backward</span>
<span id="L32"><span class="lineNum"> 32</span> : /// compatibility with external user's legacy C++ frontend code. Our intention</span>
<span id="L33"><span class="lineNum"> 33</span> : /// is to eliminate the `Variable` class in the near future.</span>
<span id="L34"><span class="lineNum"> 34</span> : using Variable = at::Tensor;</span>
<span id="L35"><span class="lineNum"> 35</span> : </span>
<span id="L36"><span class="lineNum"> 36</span> : } // namespace autograd</span>
<span id="L37"><span class="lineNum"> 37</span> : } // namespace torch</span>
<span id="L38"><span class="lineNum"> 38</span> : </span>
<span id="L39"><span class="lineNum"> 39</span> : // The following are all internal APIs and should not be shown in libtorch docs.</span>
<span id="L40"><span class="lineNum"> 40</span> : // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS</span>
<span id="L41"><span class="lineNum"> 41</span> : // ... #endif`</span>
<span id="L42"><span class="lineNum"> 42</span> : </span>
<span id="L43"><span class="lineNum"> 43</span> : #ifndef DOXYGEN_SHOULD_SKIP_THIS</span>
<span id="L44"><span class="lineNum"> 44</span> : </span>
<span id="L45"><span class="lineNum"> 45</span> : namespace torch {</span>
<span id="L46"><span class="lineNum"> 46</span> : namespace autograd {</span>
<span id="L47"><span class="lineNum"> 47</span> : </span>
<span id="L48"><span class="lineNum"> 48</span> : /// Check if this type is supported by the autograd engine.</span>
<span id="L49"><span class="lineNum"> 49</span> : /// If you change this, update the doc at the top of the</span>
<span id="L50"><span class="lineNum"> 50</span> : /// torch/autograd/__init__.py file and</span>
<span id="L51"><span class="lineNum"> 51</span> : /// &quot;test_set_requires_grad_only_for_continuous_types&quot; in test/test_autograd.py</span>
<span id="L52"><span class="lineNum"> 52</span> <span class="tlaUNC tlaBgUNC"> 0 : static inline bool isDifferentiableType(at::ScalarType t) {</span></span>
<span id="L53"><span class="lineNum"> 53</span> <span class="tlaUNC"> 0 : return isFloatingType(t) || isComplexType(t);</span></span>
<span id="L54"><span class="lineNum"> 54</span> : }</span>
<span id="L55"><span class="lineNum"> 55</span> : </span>
<span id="L56"><span class="lineNum"> 56</span> : struct Node;</span>
<span id="L57"><span class="lineNum"> 57</span> : </span>
<span id="L58"><span class="lineNum"> 58</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L59"><span class="lineNum"> 59</span> : /// Variable</span>
<span id="L60"><span class="lineNum"> 60</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L61"><span class="lineNum"> 61</span> : /// A `Variable` augments a `Tensor` with the ability to interact in our</span>
<span id="L62"><span class="lineNum"> 62</span> : /// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between</span>
<span id="L63"><span class="lineNum"> 63</span> : /// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a</span>
<span id="L64"><span class="lineNum"> 64</span> : /// weight in a neural network, or an interior variable, when it is the result</span>
<span id="L65"><span class="lineNum"> 65</span> : /// of an operation between variables. Every `Variable` also stores another</span>
<span id="L66"><span class="lineNum"> 66</span> : /// `Variable` called its `grad` (gradient). If the variable is a leaf, its</span>
<span id="L67"><span class="lineNum"> 67</span> : /// gradient will be accumulated into this variable.</span>
<span id="L68"><span class="lineNum"> 68</span> : ///</span>
<span id="L69"><span class="lineNum"> 69</span> : /// Every Tensor is a Variable, but sometimes we colloquially refer to Variables</span>
<span id="L70"><span class="lineNum"> 70</span> : /// that don't require gradients as Tensors (since none of the autograd</span>
<span id="L71"><span class="lineNum"> 71</span> : /// machinery for Variables applies). Historically, Variables and Tensors</span>
<span id="L72"><span class="lineNum"> 72</span> : /// were separate concepts, but now they are exactly the same (i.e. we have</span>
<span id="L73"><span class="lineNum"> 73</span> : /// `using Variable = at::Tensor`).</span>
<span id="L74"><span class="lineNum"> 74</span> : ///</span>
<span id="L75"><span class="lineNum"> 75</span> : /// Gradient Edges</span>
<span id="L76"><span class="lineNum"> 76</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L77"><span class="lineNum"> 77</span> : /// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the</span>
<span id="L78"><span class="lineNum"> 78</span> : /// edge in the autograd graph that connects the variable to a particular input</span>
<span id="L79"><span class="lineNum"> 79</span> : /// of the gradient function that will be invoked with the variable during the</span>
<span id="L80"><span class="lineNum"> 80</span> : /// backward pass. More precisely, this gradient function can be one of two</span>
<span id="L81"><span class="lineNum"> 81</span> : /// things:</span>
<span id="L82"><span class="lineNum"> 82</span> : /// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the</span>
<span id="L83"><span class="lineNum"> 83</span> : /// gradient of the function that produced the variable.</span>
<span id="L84"><span class="lineNum"> 84</span> : /// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a</span>
<span id="L85"><span class="lineNum"> 85</span> : /// scalar gradient value into its `grad` variable.</span>
<span id="L86"><span class="lineNum"> 86</span> : ///</span>
<span id="L87"><span class="lineNum"> 87</span> : /// Versioning</span>
<span id="L88"><span class="lineNum"> 88</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L89"><span class="lineNum"> 89</span> : /// Another major feature of `Variable`s are *versions*. Versions are</span>
<span id="L90"><span class="lineNum"> 90</span> : /// incremented when an in-place mutation of a variable occurs. Versions are</span>
<span id="L91"><span class="lineNum"> 91</span> : /// useful when constructing `SavedVariable`s, which take a snapshot of a</span>
<span id="L92"><span class="lineNum"> 92</span> : /// `Variable` at a certain version. You can retrieve a `Variable`'s version</span>
<span id="L93"><span class="lineNum"> 93</span> : /// through its `current_version()` method.</span>
<span id="L94"><span class="lineNum"> 94</span> : ///</span>
<span id="L95"><span class="lineNum"> 95</span> : /// Views</span>
<span id="L96"><span class="lineNum"> 96</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L97"><span class="lineNum"> 97</span> : /// It is possible for a `Variable` to be a *view* of another `Variable`, in</span>
<span id="L98"><span class="lineNum"> 98</span> : /// which case it tracks that `Variable`'s data and autograd history. Beyond</span>
<span id="L99"><span class="lineNum"> 99</span> : /// construction, the interface of a view is identical to that of a regular</span>
<span id="L100"><span class="lineNum"> 100</span> : /// `Variable`. You can determine whether `Variable` is in fact a view by</span>
<span id="L101"><span class="lineNum"> 101</span> : /// probing its `is_view()` method. Note that the *view* semantics are only</span>
<span id="L102"><span class="lineNum"> 102</span> : /// meaningful for `Variable` relations that are relevant to autograd.</span>
<span id="L103"><span class="lineNum"> 103</span> : /// See NOTE [ Autograd View Variables ] for more details.</span>
<span id="L104"><span class="lineNum"> 104</span> : ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L105"><span class="lineNum"> 105</span> : </span>
<span id="L106"><span class="lineNum"> 106</span> : struct AutogradMeta;</span>
<span id="L107"><span class="lineNum"> 107</span> : struct DifferentiableViewMeta;</span>
<span id="L108"><span class="lineNum"> 108</span> : </span>
<span id="L109"><span class="lineNum"> 109</span> : // Private-ish functions for manipulating variables; we don't want to put them</span>
<span id="L110"><span class="lineNum"> 110</span> : // on Tensor proper</span>
<span id="L111"><span class="lineNum"> 111</span> : namespace impl {</span>
<span id="L112"><span class="lineNum"> 112</span> : </span>
<span id="L113"><span class="lineNum"> 113</span> : // WARNING: This may return a nullptr. If you require AutogradMeta to return</span>
<span id="L114"><span class="lineNum"> 114</span> : // a materialized structure, use materialize_autograd_meta instead.</span>
<span id="L115"><span class="lineNum"> 115</span> : TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&amp;);</span>
<span id="L116"><span class="lineNum"> 116</span> : </span>
<span id="L117"><span class="lineNum"> 117</span> : // WARNING: This will return a nullptr if the Tensor is not a view.</span>
<span id="L118"><span class="lineNum"> 118</span> : TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&amp;);</span>
<span id="L119"><span class="lineNum"> 119</span> : </span>
<span id="L120"><span class="lineNum"> 120</span> : // Returns the current autograd meta, materializing it if it was previously</span>
<span id="L121"><span class="lineNum"> 121</span> : // none. This counts as a *mutating* operation, so do not call it on</span>
<span id="L122"><span class="lineNum"> 122</span> : // &quot;read-only&quot; operators; in particular, this is NOT thread safe</span>
<span id="L123"><span class="lineNum"> 123</span> : TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&amp;);</span>
<span id="L124"><span class="lineNum"> 124</span> : </span>
<span id="L125"><span class="lineNum"> 125</span> : /// Set the gradient accumulator of the `Variable`. This is only applicable to</span>
<span id="L126"><span class="lineNum"> 126</span> : /// leaf variables. Interior variables should call `set_gradient_edge()`.</span>
<span id="L127"><span class="lineNum"> 127</span> : TORCH_API void set_grad_accumulator(</span>
<span id="L128"><span class="lineNum"> 128</span> : const Variable&amp;,</span>
<span id="L129"><span class="lineNum"> 129</span> : std::weak_ptr&lt;Node&gt; grad_accumulator);</span>
<span id="L130"><span class="lineNum"> 130</span> : </span>
<span id="L131"><span class="lineNum"> 131</span> : /// Attempts to get a pointer to the gradient accumulator of the `Variable`,</span>
<span id="L132"><span class="lineNum"> 132</span> : /// if it still exists. If the gradient accumulator function has been</span>
<span id="L133"><span class="lineNum"> 133</span> : /// destroyed, returns a `nullptr`.</span>
<span id="L134"><span class="lineNum"> 134</span> : TORCH_API std::shared_ptr&lt;Node&gt; try_get_grad_accumulator(const Variable&amp;);</span>
<span id="L135"><span class="lineNum"> 135</span> : </span>
<span id="L136"><span class="lineNum"> 136</span> : /// Gets the gradient accumulator of the `Variable` if it has one, or else</span>
<span id="L137"><span class="lineNum"> 137</span> : /// create one on the fly and return it.</span>
<span id="L138"><span class="lineNum"> 138</span> : TORCH_API std::shared_ptr&lt;Node&gt; grad_accumulator(const Variable&amp;);</span>
<span id="L139"><span class="lineNum"> 139</span> : </span>
<span id="L140"><span class="lineNum"> 140</span> : /// Returns the &quot;canonical&quot; gradient edge of this `Variable`, i.e. either the</span>
<span id="L141"><span class="lineNum"> 141</span> : /// gradient function if this is an interior `Variable`, or the gradient</span>
<span id="L142"><span class="lineNum"> 142</span> : /// accumulator otherwise. If the `Variable` is interior, the returned `Edge`</span>
<span id="L143"><span class="lineNum"> 143</span> : /// will store the input index of the `Node` to which this variable is</span>
<span id="L144"><span class="lineNum"> 144</span> : /// connected in its `input_nr` field. For leaves, the `input_nr` is always</span>
<span id="L145"><span class="lineNum"> 145</span> : /// zero. Note that `set_gradient_edge` and `gradient_edge` are not</span>
<span id="L146"><span class="lineNum"> 146</span> : /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and</span>
<span id="L147"><span class="lineNum"> 147</span> : /// `set_grad_accumulator` to set the accumulator.</span>
<span id="L148"><span class="lineNum"> 148</span> : TORCH_API Edge gradient_edge(const Variable&amp;);</span>
<span id="L149"><span class="lineNum"> 149</span> : </span>
<span id="L150"><span class="lineNum"> 150</span> : /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the</span>
<span id="L151"><span class="lineNum"> 151</span> : /// `Variable`.</span>
<span id="L152"><span class="lineNum"> 152</span> : /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,</span>
<span id="L153"><span class="lineNum"> 153</span> : /// and never the `grad_accumulator`. For the latter, use</span>
<span id="L154"><span class="lineNum"> 154</span> : /// `set_grad_accumulator`. This allows late construction of an interior</span>
<span id="L155"><span class="lineNum"> 155</span> : /// `Variable`.</span>
<span id="L156"><span class="lineNum"> 156</span> : TORCH_API void set_gradient_edge(const Variable&amp;, Edge edge);</span>
<span id="L157"><span class="lineNum"> 157</span> : </span>
<span id="L158"><span class="lineNum"> 158</span> : // Autograd Graph Interaction</span>
<span id="L159"><span class="lineNum"> 159</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L160"><span class="lineNum"> 160</span> : </span>
<span id="L161"><span class="lineNum"> 161</span> : /// Update the `grad_fn` of an existing Variable. Called after in-place</span>
<span id="L162"><span class="lineNum"> 162</span> : /// modifications.</span>
<span id="L163"><span class="lineNum"> 163</span> : ///</span>
<span id="L164"><span class="lineNum"> 164</span> : /// For View Variables:</span>
<span id="L165"><span class="lineNum"> 165</span> : /// Called after in-place modifications. Modifies the grad_fn of the base</span>
<span id="L166"><span class="lineNum"> 166</span> : /// Variable.</span>
<span id="L167"><span class="lineNum"> 167</span> : TORCH_API void rebase_history(const Variable&amp;, Edge gradient_edge);</span>
<span id="L168"><span class="lineNum"> 168</span> : </span>
<span id="L169"><span class="lineNum"> 169</span> : /// Gets the raw gradient function pointer, whatever it currently is.</span>
<span id="L170"><span class="lineNum"> 170</span> : TORCH_API Node* grad_fn_unsafe(const Variable&amp;);</span>
<span id="L171"><span class="lineNum"> 171</span> : </span>
<span id="L172"><span class="lineNum"> 172</span> : /// Increments the version count of this `Variable`.</span>
<span id="L173"><span class="lineNum"> 173</span> : TORCH_API void bump_version(const Variable&amp;);</span>
<span id="L174"><span class="lineNum"> 174</span> : TORCH_API void set_version_counter(</span>
<span id="L175"><span class="lineNum"> 175</span> : const Variable&amp;,</span>
<span id="L176"><span class="lineNum"> 176</span> : const c10::VariableVersion&amp; version_counter);</span>
<span id="L177"><span class="lineNum"> 177</span> : </span>
<span id="L178"><span class="lineNum"> 178</span> : /// Retrieves this `Variable`s version counter.</span>
<span id="L179"><span class="lineNum"> 179</span> : TORCH_API const c10::VariableVersion&amp; version_counter(const Variable&amp;);</span>
<span id="L180"><span class="lineNum"> 180</span> : </span>
<span id="L181"><span class="lineNum"> 181</span> : TORCH_API void set_name(const Variable&amp;, const std::string&amp; name);</span>
<span id="L182"><span class="lineNum"> 182</span> : </span>
<span id="L183"><span class="lineNum"> 183</span> : TORCH_API void add_hook(</span>
<span id="L184"><span class="lineNum"> 184</span> : const at::TensorBase&amp;,</span>
<span id="L185"><span class="lineNum"> 185</span> : std::unique_ptr&lt;FunctionPreHook&gt; hook);</span>
<span id="L186"><span class="lineNum"> 186</span> : TORCH_API std::vector&lt;std::unique_ptr&lt;FunctionPreHook&gt;&gt;&amp; hooks(const Variable&amp;);</span>
<span id="L187"><span class="lineNum"> 187</span> : TORCH_API void clear_hooks(const at::TensorBase&amp;);</span>
<span id="L188"><span class="lineNum"> 188</span> : </span>
<span id="L189"><span class="lineNum"> 189</span> : TORCH_API void set_post_acc_grad_hooks(</span>
<span id="L190"><span class="lineNum"> 190</span> : const at::TensorBase&amp;,</span>
<span id="L191"><span class="lineNum"> 191</span> : std::unique_ptr&lt;PostAccumulateGradHook&gt; dict);</span>
<span id="L192"><span class="lineNum"> 192</span> : TORCH_API std::unique_ptr&lt;PostAccumulateGradHook&gt;&amp; post_acc_grad_hooks(</span>
<span id="L193"><span class="lineNum"> 193</span> : const Variable&amp;);</span>
<span id="L194"><span class="lineNum"> 194</span> : </span>
<span id="L195"><span class="lineNum"> 195</span> : TORCH_API void create_cpp_hook(</span>
<span id="L196"><span class="lineNum"> 196</span> : const at::TensorBase&amp;,</span>
<span id="L197"><span class="lineNum"> 197</span> : bool is_retains_grad_hooks = false);</span>
<span id="L198"><span class="lineNum"> 198</span> : } // namespace impl</span>
<span id="L199"><span class="lineNum"> 199</span> : </span>
<span id="L200"><span class="lineNum"> 200</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L201"><span class="lineNum"> 201</span> : // AutogradMeta</span>
<span id="L202"><span class="lineNum"> 202</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L203"><span class="lineNum"> 203</span> : </span>
<span id="L204"><span class="lineNum"> 204</span> : /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd</span>
<span id="L205"><span class="lineNum"> 205</span> : /// metadata fields that are necessary for tracking the Variable's autograd</span>
<span id="L206"><span class="lineNum"> 206</span> : /// history. As an optimization, a Variable may store a nullptr, in lieu of a</span>
<span id="L207"><span class="lineNum"> 207</span> : /// default constructed AutogradMeta.</span>
<span id="L208"><span class="lineNum"> 208</span> : </span>
<span id="L209"><span class="lineNum"> 209</span> : struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {</span>
<span id="L210"><span class="lineNum"> 210</span> : std::string name_;</span>
<span id="L211"><span class="lineNum"> 211</span> : </span>
<span id="L212"><span class="lineNum"> 212</span> : Variable grad_;</span>
<span id="L213"><span class="lineNum"> 213</span> : std::shared_ptr&lt;Node&gt; grad_fn_;</span>
<span id="L214"><span class="lineNum"> 214</span> : std::weak_ptr&lt;Node&gt; grad_accumulator_;</span>
<span id="L215"><span class="lineNum"> 215</span> : </span>
<span id="L216"><span class="lineNum"> 216</span> : // This field is used to store all the forward AD gradients</span>
<span id="L217"><span class="lineNum"> 217</span> : // associated with this AutogradMeta (and the Tensor it corresponds to)</span>
<span id="L218"><span class="lineNum"> 218</span> : // There is a semantic 1:1 correspondence between AutogradMeta and</span>
<span id="L219"><span class="lineNum"> 219</span> : // ForwardGrad but:</span>
<span id="L220"><span class="lineNum"> 220</span> : // - This field is lazily populated.</span>
<span id="L221"><span class="lineNum"> 221</span> : // - This field is a shared_ptr but it must never be</span>
<span id="L222"><span class="lineNum"> 222</span> : // shared by multiple Tensors. See Note [ Using ForwardGrad ]</span>
<span id="L223"><span class="lineNum"> 223</span> : // Any transition from not_initialized to initialized</span>
<span id="L224"><span class="lineNum"> 224</span> : // must be protected by mutex_</span>
<span id="L225"><span class="lineNum"> 225</span> : std::shared_ptr&lt;ForwardGrad&gt; fw_grad_;</span>
<span id="L226"><span class="lineNum"> 226</span> : </span>
<span id="L227"><span class="lineNum"> 227</span> : // The hooks_ field is actually reused by both python and cpp logic</span>
<span id="L228"><span class="lineNum"> 228</span> : // For both cases, we have a data structure, cpp_hooks_list_ (cpp)</span>
<span id="L229"><span class="lineNum"> 229</span> : // or dict (python) which is the canonical copy.</span>
<span id="L230"><span class="lineNum"> 230</span> : // Then, for both cases, we always register a single hook to</span>
<span id="L231"><span class="lineNum"> 231</span> : // hooks_ which wraps all the hooks in the list/dict.</span>
<span id="L232"><span class="lineNum"> 232</span> : // And, again in both cases, if the grad_fn exists on that tensor</span>
<span id="L233"><span class="lineNum"> 233</span> : // we will additionally register a single hook to the grad_fn.</span>
<span id="L234"><span class="lineNum"> 234</span> : //</span>
<span id="L235"><span class="lineNum"> 235</span> : // Note that the cpp and python use cases aren't actually aware of</span>
<span id="L236"><span class="lineNum"> 236</span> : // each other, so using both is not defined behavior.</span>
<span id="L237"><span class="lineNum"> 237</span> : std::vector&lt;std::unique_ptr&lt;FunctionPreHook&gt;&gt; hooks_;</span>
<span id="L238"><span class="lineNum"> 238</span> : std::shared_ptr&lt;hooks_list&gt; cpp_hooks_list_;</span>
<span id="L239"><span class="lineNum"> 239</span> : </span>
<span id="L240"><span class="lineNum"> 240</span> : // The post_acc_grad_hooks_ field stores only Python hooks</span>
<span id="L241"><span class="lineNum"> 241</span> : // (PyFunctionTensorPostAccGradHooks) that are called after the</span>
<span id="L242"><span class="lineNum"> 242</span> : // .grad field has been accumulated into. This is less complicated</span>
<span id="L243"><span class="lineNum"> 243</span> : // than the hooks_ field, which encapsulates a lot more.</span>
<span id="L244"><span class="lineNum"> 244</span> : std::unique_ptr&lt;PostAccumulateGradHook&gt; post_acc_grad_hooks_ = nullptr;</span>
<span id="L245"><span class="lineNum"> 245</span> : </span>
<span id="L246"><span class="lineNum"> 246</span> : // Only meaningful on leaf variables (must be false otherwise)</span>
<span id="L247"><span class="lineNum"> 247</span> : bool requires_grad_{false};</span>
<span id="L248"><span class="lineNum"> 248</span> : </span>
<span id="L249"><span class="lineNum"> 249</span> : // Only meaningful on non-leaf variables (must be false otherwise)</span>
<span id="L250"><span class="lineNum"> 250</span> : bool retains_grad_{false};</span>
<span id="L251"><span class="lineNum"> 251</span> : </span>
<span id="L252"><span class="lineNum"> 252</span> : bool is_view_{false};</span>
<span id="L253"><span class="lineNum"> 253</span> : </span>
<span id="L254"><span class="lineNum"> 254</span> : // The &quot;output number&quot; of this variable; e.g., if this variable</span>
<span id="L255"><span class="lineNum"> 255</span> : // was the second output of a function, then output_nr == 1.</span>
<span id="L256"><span class="lineNum"> 256</span> : // We use this to make sure we can setup the backwards trace</span>
<span id="L257"><span class="lineNum"> 257</span> : // correctly when this variable is passed to another function.</span>
<span id="L258"><span class="lineNum"> 258</span> : uint32_t output_nr_;</span>
<span id="L259"><span class="lineNum"> 259</span> : </span>
<span id="L260"><span class="lineNum"> 260</span> : // Mutex to ensure that concurrent read operations that modify internal</span>
<span id="L261"><span class="lineNum"> 261</span> : // state are still thread-safe. Used by grad_fn(), grad_accumulator(),</span>
<span id="L262"><span class="lineNum"> 262</span> : // fw_grad() and set_fw_grad()</span>
<span id="L263"><span class="lineNum"> 263</span> : // This is mutable because we need to be able to acquire this from const</span>
<span id="L264"><span class="lineNum"> 264</span> : // version of this class for the functions above</span>
<span id="L265"><span class="lineNum"> 265</span> : mutable std::mutex mutex_;</span>
<span id="L266"><span class="lineNum"> 266</span> : </span>
<span id="L267"><span class="lineNum"> 267</span> : /// Sets the `requires_grad` property of `Variable`. This should be true for</span>
<span id="L268"><span class="lineNum"> 268</span> : /// leaf variables that want to accumulate gradients, and false for all other</span>
<span id="L269"><span class="lineNum"> 269</span> : /// variables.</span>
<span id="L270"><span class="lineNum"> 270</span> <span class="tlaUNC"> 0 : void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl)</span></span>
<span id="L271"><span class="lineNum"> 271</span> : override {</span>
<span id="L272"><span class="lineNum"> 272</span> <span class="tlaUNC"> 0 : TORCH_CHECK(</span></span>
<span id="L273"><span class="lineNum"> 273</span> : !requires_grad ||</span>
<span id="L274"><span class="lineNum"> 274</span> : isDifferentiableType(at::typeMetaToScalarType(self_impl-&gt;dtype())),</span>
<span id="L275"><span class="lineNum"> 275</span> : &quot;Only Tensors of floating point and complex dtype can require gradients&quot;);</span>
<span id="L276"><span class="lineNum"> 276</span> <span class="tlaUNC"> 0 : requires_grad_ = requires_grad;</span></span>
<span id="L277"><span class="lineNum"> 277</span> <span class="tlaUNC"> 0 : }</span></span>
<span id="L278"><span class="lineNum"> 278</span> : </span>
<span id="L279"><span class="lineNum"> 279</span> : bool requires_grad() const override {</span>
<span id="L280"><span class="lineNum"> 280</span> : return requires_grad_ || grad_fn_;</span>
<span id="L281"><span class="lineNum"> 281</span> : }</span>
<span id="L282"><span class="lineNum"> 282</span> : </span>
<span id="L283"><span class="lineNum"> 283</span> : /// Accesses the gradient `Variable` of this `Variable`.</span>
<span id="L284"><span class="lineNum"> 284</span> : Variable&amp; mutable_grad() override {</span>
<span id="L285"><span class="lineNum"> 285</span> : return grad_;</span>
<span id="L286"><span class="lineNum"> 286</span> : }</span>
<span id="L287"><span class="lineNum"> 287</span> : </span>
<span id="L288"><span class="lineNum"> 288</span> : const Variable&amp; grad() const override {</span>
<span id="L289"><span class="lineNum"> 289</span> : return grad_;</span>
<span id="L290"><span class="lineNum"> 290</span> : }</span>
<span id="L291"><span class="lineNum"> 291</span> : </span>
<span id="L292"><span class="lineNum"> 292</span> : const Variable&amp; fw_grad(uint64_t level, const at::TensorBase&amp; self)</span>
<span id="L293"><span class="lineNum"> 293</span> : const override;</span>
<span id="L294"><span class="lineNum"> 294</span> : </span>
<span id="L295"><span class="lineNum"> 295</span> : void set_fw_grad(</span>
<span id="L296"><span class="lineNum"> 296</span> : const at::TensorBase&amp; new_grad,</span>
<span id="L297"><span class="lineNum"> 297</span> : const at::TensorBase&amp; self,</span>
<span id="L298"><span class="lineNum"> 298</span> : uint64_t level,</span>
<span id="L299"><span class="lineNum"> 299</span> : bool is_inplace_op) override;</span>
<span id="L300"><span class="lineNum"> 300</span> : </span>
<span id="L301"><span class="lineNum"> 301</span> <span class="tlaUNC"> 0 : AutogradMeta(</span></span>
<span id="L302"><span class="lineNum"> 302</span> : at::TensorImpl* self_impl = nullptr,</span>
<span id="L303"><span class="lineNum"> 303</span> : bool requires_grad = false,</span>
<span id="L304"><span class="lineNum"> 304</span> : Edge gradient_edge = Edge())</span>
<span id="L305"><span class="lineNum"> 305</span> <span class="tlaUNC"> 0 : : grad_fn_(std::move(gradient_edge.function)),</span></span>
<span id="L306"><span class="lineNum"> 306</span> : </span>
<span id="L307"><span class="lineNum"> 307</span> <span class="tlaUNC"> 0 : output_nr_(gradient_edge.input_nr) {</span></span>
<span id="L308"><span class="lineNum"> 308</span> : // set_requires_grad also checks error conditions.</span>
<span id="L309"><span class="lineNum"> 309</span> <span class="tlaUNC"> 0 : if (requires_grad) {</span></span>
<span id="L310"><span class="lineNum"> 310</span> <span class="tlaUNC"> 0 : TORCH_INTERNAL_ASSERT(self_impl);</span></span>
<span id="L311"><span class="lineNum"> 311</span> : // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)</span>
<span id="L312"><span class="lineNum"> 312</span> <span class="tlaUNC"> 0 : set_requires_grad(requires_grad, self_impl);</span></span>
<span id="L313"><span class="lineNum"> 313</span> : }</span>
<span id="L314"><span class="lineNum"> 314</span> <span class="tlaUNC"> 0 : TORCH_CHECK(</span></span>
<span id="L315"><span class="lineNum"> 315</span> : !grad_fn_ || !requires_grad_,</span>
<span id="L316"><span class="lineNum"> 316</span> : &quot;requires_grad should be false if grad_fn is set&quot;);</span>
<span id="L317"><span class="lineNum"> 317</span> <span class="tlaUNC"> 0 : }</span></span>
<span id="L318"><span class="lineNum"> 318</span> : </span>
<span id="L319"><span class="lineNum"> 319</span> : ~AutogradMeta() override {</span>
<span id="L320"><span class="lineNum"> 320</span> : // If AutogradMeta is being destroyed, it means that there is no other</span>
<span id="L321"><span class="lineNum"> 321</span> : // reference to its corresponding Tensor. It implies that no other thread</span>
<span id="L322"><span class="lineNum"> 322</span> : // can be using this object and so there is no need to lock mutex_ here to</span>
<span id="L323"><span class="lineNum"> 323</span> : // guard the check if fw_grad_ is populated.</span>
<span id="L324"><span class="lineNum"> 324</span> : if (fw_grad_) {</span>
<span id="L325"><span class="lineNum"> 325</span> : // See note [ Using ForwardGrad ]</span>
<span id="L326"><span class="lineNum"> 326</span> : fw_grad_-&gt;clear();</span>
<span id="L327"><span class="lineNum"> 327</span> : }</span>
<span id="L328"><span class="lineNum"> 328</span> : }</span>
<span id="L329"><span class="lineNum"> 329</span> : };</span>
<span id="L330"><span class="lineNum"> 330</span> : </span>
<span id="L331"><span class="lineNum"> 331</span> : struct TORCH_API ViewInfo {</span>
<span id="L332"><span class="lineNum"> 332</span> : /// The base `Variable`</span>
<span id="L333"><span class="lineNum"> 333</span> : /// If this ViewInfo represents a forward (respectively backward) AD gradient,</span>
<span id="L334"><span class="lineNum"> 334</span> : /// then this Tensor cannot be a forward (respectively backward) view.</span>
<span id="L335"><span class="lineNum"> 335</span> : Variable base_;</span>
<span id="L336"><span class="lineNum"> 336</span> : </span>
<span id="L337"><span class="lineNum"> 337</span> : /// By default we use as_strided to recover views which is more efficient.</span>
<span id="L338"><span class="lineNum"> 338</span> : /// view_fn is only saved when as_strided is not supported.</span>
<span id="L339"><span class="lineNum"> 339</span> : /// If view_fn has value, we use it to recover views in backward.</span>
<span id="L340"><span class="lineNum"> 340</span> : std::function&lt;Variable(const Variable&amp;)&gt; view_fn_;</span>
<span id="L341"><span class="lineNum"> 341</span> : </span>
<span id="L342"><span class="lineNum"> 342</span> : /// Accessors for the view function</span>
<span id="L343"><span class="lineNum"> 343</span> : bool has_view_fn() const {</span>
<span id="L344"><span class="lineNum"> 344</span> : return view_fn_ != nullptr;</span>
<span id="L345"><span class="lineNum"> 345</span> : }</span>
<span id="L346"><span class="lineNum"> 346</span> : </span>
<span id="L347"><span class="lineNum"> 347</span> : std::function&lt;Variable(const Variable&amp;)&gt; view_fn() const {</span>
<span id="L348"><span class="lineNum"> 348</span> : TORCH_CHECK(</span>
<span id="L349"><span class="lineNum"> 349</span> : has_view_fn(), &quot;Can only access the view function if it exists.&quot;);</span>
<span id="L350"><span class="lineNum"> 350</span> : return view_fn_;</span>
<span id="L351"><span class="lineNum"> 351</span> : }</span>
<span id="L352"><span class="lineNum"> 352</span> : </span>
<span id="L353"><span class="lineNum"> 353</span> : /// The chain function can be used to build a new ViewInfo for a</span>
<span id="L354"><span class="lineNum"> 354</span> : /// differentiable view function. It will return a new view info that</span>
<span id="L355"><span class="lineNum"> 355</span> : /// accurately represents how &quot;tensor&quot; is a view of this instance's &quot;base_&quot;.</span>
<span id="L356"><span class="lineNum"> 356</span> : /// The &quot;base&quot; and &quot;tensor&quot; are respectively the input and output of the</span>
<span id="L357"><span class="lineNum"> 357</span> : /// differentiable view function that happened. They are required to properly</span>
<span id="L358"><span class="lineNum"> 358</span> : /// set the optional view_fn_ when it is not provided. The &quot;view_func&quot;, if</span>
<span id="L359"><span class="lineNum"> 359</span> : /// provided, should be a function that allows to re-do the view between</span>
<span id="L360"><span class="lineNum"> 360</span> : /// &quot;base&quot; and &quot;tensor&quot;.</span>
<span id="L361"><span class="lineNum"> 361</span> : ViewInfo chain(</span>
<span id="L362"><span class="lineNum"> 362</span> : const Variable&amp; base,</span>
<span id="L363"><span class="lineNum"> 363</span> : const Variable&amp; tensor,</span>
<span id="L364"><span class="lineNum"> 364</span> : std::function&lt;Variable(const Variable&amp;)&gt; view_func = nullptr) const;</span>
<span id="L365"><span class="lineNum"> 365</span> : </span>
<span id="L366"><span class="lineNum"> 366</span> : ViewInfo(Variable base, std::function&lt;Variable(const Variable&amp;)&gt; view_fn)</span>
<span id="L367"><span class="lineNum"> 367</span> : : base_(std::move(base)), view_fn_(std::move(view_fn)) {</span>
<span id="L368"><span class="lineNum"> 368</span> : TORCH_CHECK(base_.defined(), &quot;base is undefined&quot;);</span>
<span id="L369"><span class="lineNum"> 369</span> : }</span>
<span id="L370"><span class="lineNum"> 370</span> : };</span>
<span id="L371"><span class="lineNum"> 371</span> : </span>
<span id="L372"><span class="lineNum"> 372</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L373"><span class="lineNum"> 373</span> : // DifferentiableViewMeta</span>
<span id="L374"><span class="lineNum"> 374</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L375"><span class="lineNum"> 375</span> : </span>
<span id="L376"><span class="lineNum"> 376</span> : /// NOTE [ Autograd View Variables ]</span>
<span id="L377"><span class="lineNum"> 377</span> : ///</span>
<span id="L378"><span class="lineNum"> 378</span> : /// Many operations return Variable that shares storage with an input Variable.</span>
<span id="L379"><span class="lineNum"> 379</span> : /// The returned Variable is called a **view** Variable on the input **base**</span>
<span id="L380"><span class="lineNum"> 380</span> : /// Variable.</span>
<span id="L381"><span class="lineNum"> 381</span> : ///</span>
<span id="L382"><span class="lineNum"> 382</span> : /// In PyTorch, we have two types of views: differentiable views, and</span>
<span id="L383"><span class="lineNum"> 383</span> : /// non-differentiable views. In either type, to support proper version</span>
<span id="L384"><span class="lineNum"> 384</span> : /// checking, the base and view Variables must always share the same</span>
<span id="L385"><span class="lineNum"> 385</span> : /// version_counter.</span>
<span id="L386"><span class="lineNum"> 386</span> : ///</span>
<span id="L387"><span class="lineNum"> 387</span> : ///</span>
<span id="L388"><span class="lineNum"> 388</span> : /// Differentiable Views</span>
<span id="L389"><span class="lineNum"> 389</span> : /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L390"><span class="lineNum"> 390</span> : /// This class allows to track both forward and backward AD differentiable</span>
<span id="L391"><span class="lineNum"> 391</span> : /// views. These views can have different base as non-differentiable view for</span>
<span id="L392"><span class="lineNum"> 392</span> : /// forward and backward mode AD are not the same.</span>
<span id="L393"><span class="lineNum"> 393</span> : ///</span>
<span id="L394"><span class="lineNum"> 394</span> : /// Most function are either both forward and backward differentiable views (for</span>
<span id="L395"><span class="lineNum"> 395</span> : /// example: view, select, narrow, transpose, etc) or both not forward and not</span>
<span id="L396"><span class="lineNum"> 396</span> : /// backward differentiable views (for example: indices, values, eq, lt, etc).</span>
<span id="L397"><span class="lineNum"> 397</span> : /// But there are also functions that are forward but not backward</span>
<span id="L398"><span class="lineNum"> 398</span> : /// differentiable views (only detach for now) or functions that are backward</span>
<span id="L399"><span class="lineNum"> 399</span> : /// but not forward differentiable view (only make_dual and unpack dual for</span>
<span id="L400"><span class="lineNum"> 400</span> : /// now).</span>
<span id="L401"><span class="lineNum"> 401</span> : ///</span>
<span id="L402"><span class="lineNum"> 402</span> : /// A concrete example of two views with different bases is as follow:</span>
<span id="L403"><span class="lineNum"> 403</span> : ///</span>
<span id="L404"><span class="lineNum"> 404</span> : /// # Have:</span>
<span id="L405"><span class="lineNum"> 405</span> : /// # dual is a dual Tensor that is neither a forward or backward view</span>
<span id="L406"><span class="lineNum"> 406</span> : /// detached_dual = dual.detach()</span>
<span id="L407"><span class="lineNum"> 407</span> : /// view = detached_dual.view_as(dual)</span>
<span id="L408"><span class="lineNum"> 408</span> : /// # The forward base of view is dual</span>
<span id="L409"><span class="lineNum"> 409</span> : /// # The backward base of view is detached_dual</span>
<span id="L410"><span class="lineNum"> 410</span> : ///</span>
<span id="L411"><span class="lineNum"> 411</span> : /// - Backward Mode View</span>
<span id="L412"><span class="lineNum"> 412</span> : /// Differentiable views are the view variables where you want gradients to flow</span>
<span id="L413"><span class="lineNum"> 413</span> : /// back to the base variables. Out-of-place operations on views are quite</span>
<span id="L414"><span class="lineNum"> 414</span> : /// straightforward, but in-place ones are very tricky. Even if the base</span>
<span id="L415"><span class="lineNum"> 415</span> : /// variable may not require grad when we create the view, we still need to</span>
<span id="L416"><span class="lineNum"> 416</span> : /// track the view relation because future in-place ops may require back-proping</span>
<span id="L417"><span class="lineNum"> 417</span> : /// through it. For example, we need to support</span>
<span id="L418"><span class="lineNum"> 418</span> : ///</span>
<span id="L419"><span class="lineNum"> 419</span> : /// (1) in-place operation on view, e.g.,</span>
<span id="L420"><span class="lineNum"> 420</span> : ///</span>
<span id="L421"><span class="lineNum"> 421</span> : /// # Have:</span>
<span id="L422"><span class="lineNum"> 422</span> : /// # base.requires_grad = False</span>
<span id="L423"><span class="lineNum"> 423</span> : /// # var.requires_grad = True</span>
<span id="L424"><span class="lineNum"> 424</span> : /// base[1] = var # i.e., base[1].copy_(var)</span>
<span id="L425"><span class="lineNum"> 425</span> : /// torch.autograd.grad(base.sum(), var) &lt;- should return an all ones</span>
<span id="L426"><span class="lineNum"> 426</span> : /// tensor</span>
<span id="L427"><span class="lineNum"> 427</span> : ///</span>
<span id="L428"><span class="lineNum"> 428</span> : /// (2) in-place operation on base after view is created, e.g.,</span>
<span id="L429"><span class="lineNum"> 429</span> : ///</span>
<span id="L430"><span class="lineNum"> 430</span> : /// # Have:</span>
<span id="L431"><span class="lineNum"> 431</span> : /// # base.requires_grad = False</span>
<span id="L432"><span class="lineNum"> 432</span> : /// # var.requires_grad = True</span>
<span id="L433"><span class="lineNum"> 433</span> : /// view = base[1]</span>
<span id="L434"><span class="lineNum"> 434</span> : /// base.copy_(var)</span>
<span id="L435"><span class="lineNum"> 435</span> : /// torch.autograd.grad(view.sum(), var) &lt;- should return a tensor with</span>
<span id="L436"><span class="lineNum"> 436</span> : /// var[1] filled with all ones and</span>
<span id="L437"><span class="lineNum"> 437</span> : /// zeros everywhere else</span>
<span id="L438"><span class="lineNum"> 438</span> : ///</span>
<span id="L439"><span class="lineNum"> 439</span> : /// - Forward Mode View</span>
<span id="L440"><span class="lineNum"> 440</span> : /// Forward differentiable views follow the same semantic as backward ones but</span>
<span id="L441"><span class="lineNum"> 441</span> : /// show up differently as they are computed along with the forward evaluation.</span>
<span id="L442"><span class="lineNum"> 442</span> : /// The hard examples above are thus very similar</span>
<span id="L443"><span class="lineNum"> 443</span> : ///</span>
<span id="L444"><span class="lineNum"> 444</span> : /// (1) in-place operation on view, e.g.,</span>
<span id="L445"><span class="lineNum"> 445</span> : ///</span>
<span id="L446"><span class="lineNum"> 446</span> : /// # Have:</span>
<span id="L447"><span class="lineNum"> 447</span> : /// # base is a regular Tensor</span>
<span id="L448"><span class="lineNum"> 448</span> : /// # var is a dual Tensor whose tangent is all ones</span>
<span id="L449"><span class="lineNum"> 449</span> : /// base[1] = var # i.e., base[1].copy_(var)</span>
<span id="L450"><span class="lineNum"> 450</span> : /// # Now, base is a dual Tensor</span>
<span id="L451"><span class="lineNum"> 451</span> : /// _, fw_grad = fwAD.unpack_dual(base) &lt;- fw_grad should be a tensor with</span>
<span id="L452"><span class="lineNum"> 452</span> : /// fw_grad[1] filled with all ones</span>
<span id="L453"><span class="lineNum"> 453</span> : /// and zeros everywhere else</span>
<span id="L454"><span class="lineNum"> 454</span> : ///</span>
<span id="L455"><span class="lineNum"> 455</span> : /// (2) in-place operation on base after view is created, e.g.,</span>
<span id="L456"><span class="lineNum"> 456</span> : ///</span>
<span id="L457"><span class="lineNum"> 457</span> : /// # Have:</span>
<span id="L458"><span class="lineNum"> 458</span> : /// # base is a regular Tensor</span>
<span id="L459"><span class="lineNum"> 459</span> : /// # var is a dual Tensor whose tangent is all ones</span>
<span id="L460"><span class="lineNum"> 460</span> : /// view = base[1]</span>
<span id="L461"><span class="lineNum"> 461</span> : /// base.copy_(var)</span>
<span id="L462"><span class="lineNum"> 462</span> : /// _, fw_grad = fwAD.unpack_dual(view) &lt;- fw_grad should be an all ones</span>
<span id="L463"><span class="lineNum"> 463</span> : /// tensor</span>
<span id="L464"><span class="lineNum"> 464</span> : ///</span>
<span id="L465"><span class="lineNum"> 465</span> : /// See Note [Forward Grad View/inplace] for more details on how we handle these</span>
<span id="L466"><span class="lineNum"> 466</span> : /// hard cases.</span>
<span id="L467"><span class="lineNum"> 467</span> : ///</span>
<span id="L468"><span class="lineNum"> 468</span> : ///</span>
<span id="L469"><span class="lineNum"> 469</span> : /// DifferentiableViewMeta is created to support gradient tracking of</span>
<span id="L470"><span class="lineNum"> 470</span> : /// such **in-place** operations. In particular,</span>
<span id="L471"><span class="lineNum"> 471</span> : /// + if an in-place op is done on base, the grad_fn field of the view may</span>
<span id="L472"><span class="lineNum"> 472</span> : /// become stale. So accesses should always go through grad_fn(), which</span>
<span id="L473"><span class="lineNum"> 473</span> : /// reconstructs an updated grad_fn if the version_counter has incremented.</span>
<span id="L474"><span class="lineNum"> 474</span> : /// All other fields are always valid.</span>
<span id="L475"><span class="lineNum"> 475</span> : /// + if an in-place op is done on view, in rebase_history() of view, which is</span>
<span id="L476"><span class="lineNum"> 476</span> : /// called after every in-place op in VariableType.cpp, the grad_fn of base</span>
<span id="L477"><span class="lineNum"> 477</span> : /// is updated.</span>
<span id="L478"><span class="lineNum"> 478</span> : /// + if a single autograd Node returns multiple differentiable views, if any</span>
<span id="L479"><span class="lineNum"> 479</span> : /// output is modified by an inplace operation, the autograd engine will</span>
<span id="L480"><span class="lineNum"> 480</span> : /// make an equivalent graph (corresponding to the view operations) without</span>
<span id="L481"><span class="lineNum"> 481</span> : /// using equivalent graph, where each output is treated as if it were</span>
<span id="L482"><span class="lineNum"> 482</span> : /// produced by a distinct view operation. This discards the original (e.g.,</span>
<span id="L483"><span class="lineNum"> 483</span> : /// user provided) grad_fn. If the provided grad_fn does more than the</span>
<span id="L484"><span class="lineNum"> 484</span> : /// backward of the view, then the DifferentiableViewMeta must be created</span>
<span id="L485"><span class="lineNum"> 485</span> : /// with creation_meta= CreationMeta::MULTI_OUTPUT_NODE to prevent the</span>
<span id="L486"><span class="lineNum"> 486</span> : /// engine from ignoring the provided grad_fn.</span>
<span id="L487"><span class="lineNum"> 487</span> : ///</span>
<span id="L488"><span class="lineNum"> 488</span> : /// Interaction with GradMode:</span>
<span id="L489"><span class="lineNum"> 489</span> : /// The particular case that we consider here is:</span>
<span id="L490"><span class="lineNum"> 490</span> : ///</span>
<span id="L491"><span class="lineNum"> 491</span> : /// # Have:</span>
<span id="L492"><span class="lineNum"> 492</span> : /// # base.requires_grad = True or False</span>
<span id="L493"><span class="lineNum"> 493</span> : /// with torch.no_grad():</span>
<span id="L494"><span class="lineNum"> 494</span> : /// view = base[1]</span>
<span id="L495"><span class="lineNum"> 495</span> : /// base.requires_grad_()</span>
<span id="L496"><span class="lineNum"> 496</span> : /// view.copy_(var)</span>
<span id="L497"><span class="lineNum"> 497</span> : /// torch.autograd.grad(base.sum(), var) &lt;- what should it return?</span>
<span id="L498"><span class="lineNum"> 498</span> : ///</span>
<span id="L499"><span class="lineNum"> 499</span> : /// Given that this particular code example is ambiguous and can easily be</span>
<span id="L500"><span class="lineNum"> 500</span> : /// replace by either moving both inside the no_grad block or both outside, we</span>
<span id="L501"><span class="lineNum"> 501</span> : /// explicitly forbid it. For now, it is deprecated by a warning. This is</span>
<span id="L502"><span class="lineNum"> 502</span> : /// achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE for all</span>
<span id="L503"><span class="lineNum"> 503</span> : /// differentiable views created in no_grad mode.</span>
<span id="L504"><span class="lineNum"> 504</span> : ///</span>
<span id="L505"><span class="lineNum"> 505</span> : /// See Note [View + Inplace update for base tensor]</span>
<span id="L506"><span class="lineNum"> 506</span> : /// and Note [View + Inplace update for view tensor] for the details how</span>
<span id="L507"><span class="lineNum"> 507</span> : /// autograd handles inplace update with view ops.</span>
<span id="L508"><span class="lineNum"> 508</span> : ///</span>
<span id="L509"><span class="lineNum"> 509</span> : /// Non-Differentiable Views</span>
<span id="L510"><span class="lineNum"> 510</span> : /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L511"><span class="lineNum"> 511</span> : /// In certain cases, although function outputs share storage with inputs, they</span>
<span id="L512"><span class="lineNum"> 512</span> : /// will **never** require gradient history tracking. Instead of registering the</span>
<span id="L513"><span class="lineNum"> 513</span> : /// view relation via DifferentiableViewMeta in autograd, the views will be</span>
<span id="L514"><span class="lineNum"> 514</span> : /// using usual AutogradMeta and just share the version counters with the base</span>
<span id="L515"><span class="lineNum"> 515</span> : /// Variables.</span>
<span id="L516"><span class="lineNum"> 516</span> : /// Such views include:</span>
<span id="L517"><span class="lineNum"> 517</span> : /// 1. Views created from .detach()</span>
<span id="L518"><span class="lineNum"> 518</span> : /// 2. Views that are non-differentiable by its nature.</span>
<span id="L519"><span class="lineNum"> 519</span> : /// E.g., `sparse_tensor.indices()` is a integral view on a (possibly)</span>
<span id="L520"><span class="lineNum"> 520</span> : /// floating point tensor.</span>
<span id="L521"><span class="lineNum"> 521</span> : /// See top of `derivatives.yaml` on how to specify that outputs of a</span>
<span id="L522"><span class="lineNum"> 522</span> : /// function are non-differentiable.</span>
<span id="L523"><span class="lineNum"> 523</span> : /// These are called non-differentiable views as the gradients do not flow</span>
<span id="L524"><span class="lineNum"> 524</span> : /// through the view relation.</span>
<span id="L525"><span class="lineNum"> 525</span> : ///</span>
<span id="L526"><span class="lineNum"> 526</span> : /// Relevant logic for both differentiable and non-differentiable views is</span>
<span id="L527"><span class="lineNum"> 527</span> : /// implemented in make_variable_(non_)differentiable_view below, and</span>
<span id="L528"><span class="lineNum"> 528</span> : /// wrap_output of gen_variable_type.py.</span>
<span id="L529"><span class="lineNum"> 529</span> : </span>
<span id="L530"><span class="lineNum"> 530</span> : /// NOTE [ View + Inplace detection ]</span>
<span id="L531"><span class="lineNum"> 531</span> : ///</span>
<span id="L532"><span class="lineNum"> 532</span> : /// We want to detect views followed by inplace as they are often forbidden to</span>
<span id="L533"><span class="lineNum"> 533</span> : /// ensure correctness of the computed gradients. But since we want to only</span>
<span id="L534"><span class="lineNum"> 534</span> : /// notify the user when both happen, we tag the DifferentiableViewMeta when the</span>
<span id="L535"><span class="lineNum"> 535</span> : /// view is created via the `make_variable_*_view()` functions. This tag is then</span>
<span id="L536"><span class="lineNum"> 536</span> : /// checked by the `check_inplace()` function from `VariableTypeUtils.h` that</span>
<span id="L537"><span class="lineNum"> 537</span> : /// should be called before every inplace operation and to detect cases where</span>
<span id="L538"><span class="lineNum"> 538</span> : /// other views are modified and this one is rebased by side effect, we also</span>
<span id="L539"><span class="lineNum"> 539</span> : /// check in the `VariableHooks::grad_fn()`.</span>
<span id="L540"><span class="lineNum"> 540</span> : </span>
<span id="L541"><span class="lineNum"> 541</span> : /// Flag that gives more information about when this view was created:</span>
<span id="L542"><span class="lineNum"> 542</span> : /// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom</span>
<span id="L543"><span class="lineNum"> 543</span> : /// autograd Function is returned.</span>
<span id="L544"><span class="lineNum"> 544</span> : /// - NO_GRAD_MODE should be set when a view in created when GradMode is</span>
<span id="L545"><span class="lineNum"> 545</span> : /// disabled</span>
<span id="L546"><span class="lineNum"> 546</span> : /// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code</span>
<span id="L547"><span class="lineNum"> 547</span> : /// returns</span>
<span id="L548"><span class="lineNum"> 548</span> : /// multiple differentiable views</span>
<span id="L549"><span class="lineNum"> 549</span> : /// - Inference_MODE should be set when a view of normal tensor is created in</span>
<span id="L550"><span class="lineNum"> 550</span> : /// InferenceMode.</span>
<span id="L551"><span class="lineNum"> 551</span> : /// - DEFAULT is for all other cases</span>
<span id="L552"><span class="lineNum"> 552</span> : enum class CreationMeta : uint8_t {</span>
<span id="L553"><span class="lineNum"> 553</span> : DEFAULT,</span>
<span id="L554"><span class="lineNum"> 554</span> : IN_CUSTOM_FUNCTION,</span>
<span id="L555"><span class="lineNum"> 555</span> : MULTI_OUTPUT_NODE,</span>
<span id="L556"><span class="lineNum"> 556</span> : NO_GRAD_MODE,</span>
<span id="L557"><span class="lineNum"> 557</span> : INFERENCE_MODE</span>
<span id="L558"><span class="lineNum"> 558</span> : };</span>
<span id="L559"><span class="lineNum"> 559</span> : </span>
<span id="L560"><span class="lineNum"> 560</span> : /// Handles correctly propagating CreationMeta when a new view is created from a</span>
<span id="L561"><span class="lineNum"> 561</span> : /// previous view. In general, we don't want the new view to be _less_</span>
<span id="L562"><span class="lineNum"> 562</span> : /// restrictive than the previous view (it's okay to be _more_ restrictive). A</span>
<span id="L563"><span class="lineNum"> 563</span> : /// CreationMeta value of DEFAULT is currently the least restrictive, as the</span>
<span id="L564"><span class="lineNum"> 564</span> : /// behavior for all other CreationMeta values is to error out for in-place ops.</span>
<span id="L565"><span class="lineNum"> 565</span> : /// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so</span>
<span id="L566"><span class="lineNum"> 566</span> : /// it takes precedence in propagation. If this changes, the logic here will</span>
<span id="L567"><span class="lineNum"> 567</span> : /// need to be updated to properly handle the new semantics.</span>
<span id="L568"><span class="lineNum"> 568</span> : inline CreationMeta propagate_creation_meta(</span>
<span id="L569"><span class="lineNum"> 569</span> : CreationMeta prev_view_creation_meta,</span>
<span id="L570"><span class="lineNum"> 570</span> : CreationMeta new_view_creation_meta) {</span>
<span id="L571"><span class="lineNum"> 571</span> : return (new_view_creation_meta == CreationMeta::DEFAULT)</span>
<span id="L572"><span class="lineNum"> 572</span> : ? prev_view_creation_meta</span>
<span id="L573"><span class="lineNum"> 573</span> : : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE</span>
<span id="L574"><span class="lineNum"> 574</span> : ? prev_view_creation_meta</span>
<span id="L575"><span class="lineNum"> 575</span> : : new_view_creation_meta);</span>
<span id="L576"><span class="lineNum"> 576</span> : }</span>
<span id="L577"><span class="lineNum"> 577</span> : </span>
<span id="L578"><span class="lineNum"> 578</span> : /// Unified function to handle error checking when rebase happens</span>
<span id="L579"><span class="lineNum"> 579</span> : /// indirect=true means that the caller is not doing the inplace, but the</span>
<span id="L580"><span class="lineNum"> 580</span> : /// inplace happened somewhere else.</span>
<span id="L581"><span class="lineNum"> 581</span> : TORCH_API void handle_view_on_rebase(</span>
<span id="L582"><span class="lineNum"> 582</span> : DifferentiableViewMeta* diff_view_meta,</span>
<span id="L583"><span class="lineNum"> 583</span> : bool indirect = false);</span>
<span id="L584"><span class="lineNum"> 584</span> : </span>
<span id="L585"><span class="lineNum"> 585</span> : struct TORCH_API DifferentiableViewMeta : public AutogradMeta {</span>
<span id="L586"><span class="lineNum"> 586</span> : private:</span>
<span id="L587"><span class="lineNum"> 587</span> : /// Informations about the views</span>
<span id="L588"><span class="lineNum"> 588</span> : c10::optional&lt;ViewInfo&gt; backward_info_;</span>
<span id="L589"><span class="lineNum"> 589</span> : c10::optional&lt;ViewInfo&gt; forward_info_;</span>
<span id="L590"><span class="lineNum"> 590</span> : </span>
<span id="L591"><span class="lineNum"> 591</span> : // Optimization to reduce the number of ViewInfo we create.</span>
<span id="L592"><span class="lineNum"> 592</span> : // In the (very common) case where backward_info_ == forward_info_, we only</span>
<span id="L593"><span class="lineNum"> 593</span> : // populate backward_info_ (that should be used as both the forward and</span>
<span id="L594"><span class="lineNum"> 594</span> : // backward view information) and set shared_view_info_ = true. Invariants:</span>
<span id="L595"><span class="lineNum"> 595</span> : // - If shared_view_info_ is false, there is no special constraints on</span>
<span id="L596"><span class="lineNum"> 596</span> : // backward_info_ and forward_info_</span>
<span id="L597"><span class="lineNum"> 597</span> : // - If shared_view_info_ is true, we must have:</span>
<span id="L598"><span class="lineNum"> 598</span> : // - backward_info_.has_value() == true</span>
<span id="L599"><span class="lineNum"> 599</span> : // - forward_info_.has_value() == false</span>
<span id="L600"><span class="lineNum"> 600</span> : bool shared_view_info_;</span>
<span id="L601"><span class="lineNum"> 601</span> : </span>
<span id="L602"><span class="lineNum"> 602</span> : /// The two following fields are extra information that we track to ensure</span>
<span id="L603"><span class="lineNum"> 603</span> : /// that any operation on this backward view is valid.</span>
<span id="L604"><span class="lineNum"> 604</span> : </span>
<span id="L605"><span class="lineNum"> 605</span> : /// The value of the version_counter at the time grad_fn was created. The</span>
<span id="L606"><span class="lineNum"> 606</span> : /// grad_fn field is stale if attr_version_ !=</span>
<span id="L607"><span class="lineNum"> 607</span> : /// version_counter.current_version().</span>
<span id="L608"><span class="lineNum"> 608</span> : uint32_t attr_version_;</span>
<span id="L609"><span class="lineNum"> 609</span> : CreationMeta creation_meta_;</span>
<span id="L610"><span class="lineNum"> 610</span> : </span>
<span id="L611"><span class="lineNum"> 611</span> : public:</span>
<span id="L612"><span class="lineNum"> 612</span> : /// requires_grad is a backward AD field so we only use the view specific</span>
<span id="L613"><span class="lineNum"> 613</span> : /// logic for backward differentiable views</span>
<span id="L614"><span class="lineNum"> 614</span> : bool requires_grad() const override {</span>
<span id="L615"><span class="lineNum"> 615</span> : return requires_grad_ || grad_fn_ ||</span>
<span id="L616"><span class="lineNum"> 616</span> : (has_bw_view() &amp;&amp; get_backward_view().base_.requires_grad());</span>
<span id="L617"><span class="lineNum"> 617</span> : }</span>
<span id="L618"><span class="lineNum"> 618</span> : </span>
<span id="L619"><span class="lineNum"> 619</span> : bool shared_view_info() const {</span>
<span id="L620"><span class="lineNum"> 620</span> : return shared_view_info_;</span>
<span id="L621"><span class="lineNum"> 621</span> : }</span>
<span id="L622"><span class="lineNum"> 622</span> : </span>
<span id="L623"><span class="lineNum"> 623</span> : bool has_bw_view() const {</span>
<span id="L624"><span class="lineNum"> 624</span> : return backward_info_.has_value();</span>
<span id="L625"><span class="lineNum"> 625</span> : }</span>
<span id="L626"><span class="lineNum"> 626</span> : </span>
<span id="L627"><span class="lineNum"> 627</span> : const ViewInfo&amp; get_backward_view() const {</span>
<span id="L628"><span class="lineNum"> 628</span> : TORCH_CHECK(</span>
<span id="L629"><span class="lineNum"> 629</span> : has_bw_view(), &quot;backward view info can only exist for backward views.&quot;);</span>
<span id="L630"><span class="lineNum"> 630</span> : return backward_info_.value();</span>
<span id="L631"><span class="lineNum"> 631</span> : }</span>
<span id="L632"><span class="lineNum"> 632</span> : </span>
<span id="L633"><span class="lineNum"> 633</span> : uint32_t get_attr_version() const {</span>
<span id="L634"><span class="lineNum"> 634</span> : TORCH_CHECK(</span>
<span id="L635"><span class="lineNum"> 635</span> : has_bw_view(), &quot;attr_version can only exist for backward views.&quot;);</span>
<span id="L636"><span class="lineNum"> 636</span> : return attr_version_;</span>
<span id="L637"><span class="lineNum"> 637</span> : }</span>
<span id="L638"><span class="lineNum"> 638</span> : </span>
<span id="L639"><span class="lineNum"> 639</span> : void set_attr_version(uint32_t new_attr_version) {</span>
<span id="L640"><span class="lineNum"> 640</span> : TORCH_CHECK(</span>
<span id="L641"><span class="lineNum"> 641</span> : has_bw_view(), &quot;attr_version can only exist for backward views.&quot;);</span>
<span id="L642"><span class="lineNum"> 642</span> : attr_version_ = new_attr_version;</span>
<span id="L643"><span class="lineNum"> 643</span> : }</span>
<span id="L644"><span class="lineNum"> 644</span> : </span>
<span id="L645"><span class="lineNum"> 645</span> : CreationMeta get_creation_meta() const {</span>
<span id="L646"><span class="lineNum"> 646</span> : TORCH_CHECK(</span>
<span id="L647"><span class="lineNum"> 647</span> : has_bw_view(), &quot;creation_meta can only exist for backward views.&quot;);</span>
<span id="L648"><span class="lineNum"> 648</span> : return creation_meta_;</span>
<span id="L649"><span class="lineNum"> 649</span> : }</span>
<span id="L650"><span class="lineNum"> 650</span> : </span>
<span id="L651"><span class="lineNum"> 651</span> : void set_creation_meta(CreationMeta new_creation_meta) {</span>
<span id="L652"><span class="lineNum"> 652</span> : TORCH_CHECK(</span>
<span id="L653"><span class="lineNum"> 653</span> : has_bw_view(), &quot;creation_meta can only exist for backward views.&quot;);</span>
<span id="L654"><span class="lineNum"> 654</span> : creation_meta_ = new_creation_meta;</span>
<span id="L655"><span class="lineNum"> 655</span> : }</span>
<span id="L656"><span class="lineNum"> 656</span> : </span>
<span id="L657"><span class="lineNum"> 657</span> : bool has_fw_view() const {</span>
<span id="L658"><span class="lineNum"> 658</span> : return shared_view_info_ || forward_info_.has_value();</span>
<span id="L659"><span class="lineNum"> 659</span> : }</span>
<span id="L660"><span class="lineNum"> 660</span> : </span>
<span id="L661"><span class="lineNum"> 661</span> : const ViewInfo&amp; get_forward_view() const {</span>
<span id="L662"><span class="lineNum"> 662</span> : TORCH_CHECK(</span>
<span id="L663"><span class="lineNum"> 663</span> : has_fw_view(), &quot;forward view info can only exist for forward views.&quot;);</span>
<span id="L664"><span class="lineNum"> 664</span> : TORCH_CHECK(</span>
<span id="L665"><span class="lineNum"> 665</span> : !shared_view_info_ || has_bw_view(),</span>
<span id="L666"><span class="lineNum"> 666</span> : &quot;forward view info can only exist for forward views.&quot;);</span>
<span id="L667"><span class="lineNum"> 667</span> : return shared_view_info_ ? backward_info_.value() : forward_info_.value();</span>
<span id="L668"><span class="lineNum"> 668</span> : }</span>
<span id="L669"><span class="lineNum"> 669</span> : </span>
<span id="L670"><span class="lineNum"> 670</span> : DifferentiableViewMeta(</span>
<span id="L671"><span class="lineNum"> 671</span> : at::TensorImpl* self_impl,</span>
<span id="L672"><span class="lineNum"> 672</span> : c10::optional&lt;ViewInfo&gt; backward_info,</span>
<span id="L673"><span class="lineNum"> 673</span> : c10::optional&lt;ViewInfo&gt; forward_info,</span>
<span id="L674"><span class="lineNum"> 674</span> : bool shared_view_info,</span>
<span id="L675"><span class="lineNum"> 675</span> : CreationMeta creation_meta = CreationMeta::DEFAULT);</span>
<span id="L676"><span class="lineNum"> 676</span> : };</span>
<span id="L677"><span class="lineNum"> 677</span> : </span>
<span id="L678"><span class="lineNum"> 678</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L679"><span class="lineNum"> 679</span> : // Variable Implementation</span>
<span id="L680"><span class="lineNum"> 680</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L681"><span class="lineNum"> 681</span> : </span>
<span id="L682"><span class="lineNum"> 682</span> : // Factory Functions</span>
<span id="L683"><span class="lineNum"> 683</span> : //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~</span>
<span id="L684"><span class="lineNum"> 684</span> : </span>
<span id="L685"><span class="lineNum"> 685</span> : /// Creates a `Variable` that is a *view* of another (*base*) variable.</span>
<span id="L686"><span class="lineNum"> 686</span> : /// The `gradient_edge` is an optional (gradient_function, input_number) pair.</span>
<span id="L687"><span class="lineNum"> 687</span> : /// `is_differentiable` is a bool that specifies whether this view is</span>
<span id="L688"><span class="lineNum"> 688</span> : /// differentiable, i.e., whether the relation should be tracked by autograd.</span>
<span id="L689"><span class="lineNum"> 689</span> : /// See NOTE [ Autograd View Variables ] for details.</span>
<span id="L690"><span class="lineNum"> 690</span> : </span>
<span id="L691"><span class="lineNum"> 691</span> : /// NOTE: `allow_tensor_metadata_change` is set to true by default, because</span>
<span id="L692"><span class="lineNum"> 692</span> : /// there are a lot of call sites to these factory functions that need to change</span>
<span id="L693"><span class="lineNum"> 693</span> : /// the variable's size or storage afterwards, and they don't expect the</span>
<span id="L694"><span class="lineNum"> 694</span> : /// original tensor (where the variable is created from) to be updated. Setting</span>
<span id="L695"><span class="lineNum"> 695</span> : /// `allow_tensor_metadata_change_` to false by default would unnecessarily</span>
<span id="L696"><span class="lineNum"> 696</span> : /// prevent those changes from happening and is undesirable.</span>
<span id="L697"><span class="lineNum"> 697</span> : </span>
<span id="L698"><span class="lineNum"> 698</span> : // See NOTE [ Autograd View Variables ] for details.</span>
<span id="L699"><span class="lineNum"> 699</span> : // Differentiable view. Track history with DifferentiableViewMeta.</span>
<span id="L700"><span class="lineNum"> 700</span> : inline Variable make_variable_differentiable_view(</span>
<span id="L701"><span class="lineNum"> 701</span> : const at::Tensor&amp; data,</span>
<span id="L702"><span class="lineNum"> 702</span> : c10::optional&lt;ViewInfo&gt; backward_info,</span>
<span id="L703"><span class="lineNum"> 703</span> : c10::optional&lt;ViewInfo&gt; forward_info,</span>
<span id="L704"><span class="lineNum"> 704</span> : bool shared_view_info,</span>
<span id="L705"><span class="lineNum"> 705</span> : CreationMeta creation_meta,</span>
<span id="L706"><span class="lineNum"> 706</span> : bool allow_tensor_metadata_change = true) {</span>
<span id="L707"><span class="lineNum"> 707</span> : if (data.defined()) {</span>
<span id="L708"><span class="lineNum"> 708</span> : TORCH_CHECK(</span>
<span id="L709"><span class="lineNum"> 709</span> : data.getIntrusivePtr()-&gt;autograd_meta() == nullptr,</span>
<span id="L710"><span class="lineNum"> 710</span> : &quot;Attempted to make a tensor into a differentiable view, but the &quot;</span>
<span id="L711"><span class="lineNum"> 711</span> : &quot;tensor already had autograd metadata associated with it. If you are &quot;</span>
<span id="L712"><span class="lineNum"> 712</span> : &quot;using a __torch_dispatch__ mode, the most common cause for this &quot;</span>
<span id="L713"><span class="lineNum"> 713</span> : &quot;problem is that you used torch.overrides.enable_reentrant_dispatch() &quot;</span>
<span id="L714"><span class="lineNum"> 714</span> : &quot;improperly; tensors created within the extent of reentrant dispatch &quot;</span>
<span id="L715"><span class="lineNum"> 715</span> : &quot;MUST NOT be directly returned from __torch_dispatch__; instead, they &quot;</span>
<span id="L716"><span class="lineNum"> 716</span> : &quot;must be wrapped into fresh tensors that serve as the output. If you &quot;</span>
<span id="L717"><span class="lineNum"> 717</span> : &quot;are not using wrappers, you probably don't need reentrant dispatch. &quot;</span>
<span id="L718"><span class="lineNum"> 718</span> : &quot;If this doesn't seem applicable, please file a bug to PyTorch.&quot;);</span>
<span id="L719"><span class="lineNum"> 719</span> : at::TensorImpl* data_impl = data.unsafeGetTensorImpl();</span>
<span id="L720"><span class="lineNum"> 720</span> : data_impl-&gt;set_allow_tensor_metadata_change(allow_tensor_metadata_change);</span>
<span id="L721"><span class="lineNum"> 721</span> : data_impl-&gt;set_autograd_meta(std::make_unique&lt;DifferentiableViewMeta&gt;(</span>
<span id="L722"><span class="lineNum"> 722</span> : data_impl,</span>
<span id="L723"><span class="lineNum"> 723</span> : std::move(backward_info),</span>
<span id="L724"><span class="lineNum"> 724</span> : std::move(forward_info),</span>
<span id="L725"><span class="lineNum"> 725</span> : shared_view_info,</span>
<span id="L726"><span class="lineNum"> 726</span> : creation_meta));</span>
<span id="L727"><span class="lineNum"> 727</span> : return data;</span>
<span id="L728"><span class="lineNum"> 728</span> : }</span>
<span id="L729"><span class="lineNum"> 729</span> : return Variable();</span>
<span id="L730"><span class="lineNum"> 730</span> : }</span>
<span id="L731"><span class="lineNum"> 731</span> : </span>
<span id="L732"><span class="lineNum"> 732</span> : // See NOTE [ Autograd View Variables ] for details.</span>
<span id="L733"><span class="lineNum"> 733</span> : // Non-differentiable view. Just share version counter.</span>
<span id="L734"><span class="lineNum"> 734</span> : inline Variable make_variable_non_differentiable_view(</span>
<span id="L735"><span class="lineNum"> 735</span> : Variable base,</span>
<span id="L736"><span class="lineNum"> 736</span> : const at::Tensor&amp; data,</span>
<span id="L737"><span class="lineNum"> 737</span> : bool allow_tensor_metadata_change = true) {</span>
<span id="L738"><span class="lineNum"> 738</span> : if (data.defined()) {</span>
<span id="L739"><span class="lineNum"> 739</span> : // Currently all of non-differentiable view ops(detach/_indices/_values)</span>
<span id="L740"><span class="lineNum"> 740</span> : // share the same TensorImpl as their base Tensor. Thus a new TensorImpl</span>
<span id="L741"><span class="lineNum"> 741</span> : // allocation here is required.</span>
<span id="L742"><span class="lineNum"> 742</span> : auto data_impl_copy = data.getIntrusivePtr()-&gt;shallow_copy_and_detach(</span>
<span id="L743"><span class="lineNum"> 743</span> : /*version_counter=*/impl::version_counter(base),</span>
<span id="L744"><span class="lineNum"> 744</span> : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);</span>
<span id="L745"><span class="lineNum"> 745</span> : data_impl_copy-&gt;set_autograd_meta(nullptr);</span>
<span id="L746"><span class="lineNum"> 746</span> : return Variable(data_impl_copy);</span>
<span id="L747"><span class="lineNum"> 747</span> : }</span>
<span id="L748"><span class="lineNum"> 748</span> : return Variable();</span>
<span id="L749"><span class="lineNum"> 749</span> : }</span>
<span id="L750"><span class="lineNum"> 750</span> : </span>
<span id="L751"><span class="lineNum"> 751</span> : /// Creates a `Variable` from the given `Tensor`, copying its underlying</span>
<span id="L752"><span class="lineNum"> 752</span> : /// `TensorImpl`. `requires_grad` should be set only for leaves, and determines</span>
<span id="L753"><span class="lineNum"> 753</span> : /// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be</span>
<span id="L754"><span class="lineNum"> 754</span> : /// a `Variable` already. Its dynamic type *must* be `Tensor`.</span>
<span id="L755"><span class="lineNum"> 755</span> : ///</span>
<span id="L756"><span class="lineNum"> 756</span> : /// TODO: Eliminate this function as much as possible, as it can be expressed</span>
<span id="L757"><span class="lineNum"> 757</span> : /// more clearly as detach() or a no-op in most call sites (especially when</span>
<span id="L758"><span class="lineNum"> 758</span> : /// there is only one use of the variable).</span>
<span id="L759"><span class="lineNum"> 759</span> <span class="tlaGNC tlaBgGNC"> 418844 : inline Variable make_variable(</span></span>
<span id="L760"><span class="lineNum"> 760</span> : at::Tensor data,</span>
<span id="L761"><span class="lineNum"> 761</span> : bool requires_grad = false,</span>
<span id="L762"><span class="lineNum"> 762</span> : bool allow_tensor_metadata_change = true) {</span>
<span id="L763"><span class="lineNum"> 763</span> <span class="tlaGNC"> 418844 : if (data.defined()) {</span></span>
<span id="L764"><span class="lineNum"> 764</span> <span class="tlaGNC"> 458920 : if (data.getIntrusivePtr().use_count() == 1 &amp;&amp;</span></span>
<span id="L765"><span class="lineNum"> 765</span> <span class="tlaGNC"> 40076 : data.getIntrusivePtr()-&gt;unique_version()) {</span></span>
<span id="L766"><span class="lineNum"> 766</span> <span class="tlaGNC"> 40076 : auto data_impl = data.unsafeReleaseIntrusivePtr();</span></span>
<span id="L767"><span class="lineNum"> 767</span> <span class="tlaGNC"> 40076 : data_impl-&gt;set_allow_tensor_metadata_change(allow_tensor_metadata_change);</span></span>
<span id="L768"><span class="lineNum"> 768</span> : // NOLINTNEXTLINE(bugprone-branch-clone)</span>
<span id="L769"><span class="lineNum"> 769</span> <span class="tlaGNC"> 40076 : if (requires_grad) {</span></span>
<span id="L770"><span class="lineNum"> 770</span> <span class="tlaUNC tlaBgUNC"> 0 : data_impl-&gt;set_autograd_meta(</span></span>
<span id="L771"><span class="lineNum"> 771</span> <span class="tlaUNC"> 0 : std::make_unique&lt;AutogradMeta&gt;(data_impl.get(), requires_grad));</span></span>
<span id="L772"><span class="lineNum"> 772</span> : } else {</span>
<span id="L773"><span class="lineNum"> 773</span> <span class="tlaGNC tlaBgGNC"> 40076 : data_impl-&gt;set_autograd_meta(nullptr);</span></span>
<span id="L774"><span class="lineNum"> 774</span> : }</span>
<span id="L775"><span class="lineNum"> 775</span> <span class="tlaGNC"> 40076 : return Variable(std::move(data_impl));</span></span>
<span id="L776"><span class="lineNum"> 776</span> <span class="tlaGNC"> 40076 : } else {</span></span>
<span id="L777"><span class="lineNum"> 777</span> <span class="tlaGNC"> 378768 : auto data_impl_copy = data.getIntrusivePtr()-&gt;shallow_copy_and_detach(</span></span>
<span id="L778"><span class="lineNum"> 778</span> : /*version_counter=*/0,</span>
<span id="L779"><span class="lineNum"> 779</span> <span class="tlaGNC"> 378768 : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);</span></span>
<span id="L780"><span class="lineNum"> 780</span> : // NOLINTNEXTLINE(bugprone-branch-clone)</span>
<span id="L781"><span class="lineNum"> 781</span> <span class="tlaGNC"> 378768 : if (requires_grad) {</span></span>
<span id="L782"><span class="lineNum"> 782</span> <span class="tlaUNC tlaBgUNC"> 0 : data_impl_copy-&gt;set_autograd_meta(std::make_unique&lt;AutogradMeta&gt;(</span></span>
<span id="L783"><span class="lineNum"> 783</span> <span class="tlaUNC"> 0 : data_impl_copy.get(), requires_grad));</span></span>
<span id="L784"><span class="lineNum"> 784</span> : } else {</span>
<span id="L785"><span class="lineNum"> 785</span> <span class="tlaGNC tlaBgGNC"> 378768 : data_impl_copy-&gt;set_autograd_meta(nullptr);</span></span>
<span id="L786"><span class="lineNum"> 786</span> : }</span>
<span id="L787"><span class="lineNum"> 787</span> <span class="tlaGNC"> 378768 : return Variable(data_impl_copy);</span></span>
<span id="L788"><span class="lineNum"> 788</span> <span class="tlaGNC"> 378768 : }</span></span>
<span id="L789"><span class="lineNum"> 789</span> : }</span>
<span id="L790"><span class="lineNum"> 790</span> <span class="tlaUNC tlaBgUNC"> 0 : return Variable();</span></span>
<span id="L791"><span class="lineNum"> 791</span> : }</span>
<span id="L792"><span class="lineNum"> 792</span> : </span>
<span id="L793"><span class="lineNum"> 793</span> : /// Creates a `Variable` from the given `Tensor`, copying its underlying</span>
<span id="L794"><span class="lineNum"> 794</span> : /// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair</span>
<span id="L795"><span class="lineNum"> 795</span> : /// specifying the function in the autograd graph, and what particular input of</span>
<span id="L796"><span class="lineNum"> 796</span> : /// that function, this variable is connected to.</span>
<span id="L797"><span class="lineNum"> 797</span> : inline Variable make_variable(</span>
<span id="L798"><span class="lineNum"> 798</span> : at::Tensor data,</span>
<span id="L799"><span class="lineNum"> 799</span> : Edge gradient_edge,</span>
<span id="L800"><span class="lineNum"> 800</span> : bool allow_tensor_metadata_change = true) {</span>
<span id="L801"><span class="lineNum"> 801</span> : if (data.defined()) {</span>
<span id="L802"><span class="lineNum"> 802</span> : auto data_impl_copy = data.getIntrusivePtr()-&gt;shallow_copy_and_detach(</span>
<span id="L803"><span class="lineNum"> 803</span> : /*version_counter=*/0,</span>
<span id="L804"><span class="lineNum"> 804</span> : /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);</span>
<span id="L805"><span class="lineNum"> 805</span> : data_impl_copy-&gt;set_autograd_meta(std::make_unique&lt;AutogradMeta&gt;(</span>
<span id="L806"><span class="lineNum"> 806</span> : data_impl_copy.get(), false, std::move(gradient_edge)));</span>
<span id="L807"><span class="lineNum"> 807</span> : return Variable(data_impl_copy);</span>
<span id="L808"><span class="lineNum"> 808</span> : }</span>
<span id="L809"><span class="lineNum"> 809</span> : return Variable();</span>
<span id="L810"><span class="lineNum"> 810</span> : }</span>
<span id="L811"><span class="lineNum"> 811</span> : </span>
<span id="L812"><span class="lineNum"> 812</span> : struct VariableHooks final : at::impl::VariableHooksInterface {</span>
<span id="L813"><span class="lineNum"> 813</span> : at::TensorBase tensor_data(const at::TensorBase&amp;) const override;</span>
<span id="L814"><span class="lineNum"> 814</span> : at::TensorBase variable_data(const at::TensorBase&amp;) const override;</span>
<span id="L815"><span class="lineNum"> 815</span> : const std::shared_ptr&lt;torch::autograd::Node&gt;&amp; grad_fn(</span>
<span id="L816"><span class="lineNum"> 816</span> : const at::TensorBase&amp;) const override;</span>
<span id="L817"><span class="lineNum"> 817</span> : unsigned _register_hook(</span>
<span id="L818"><span class="lineNum"> 818</span> : const at::TensorBase&amp;,</span>
<span id="L819"><span class="lineNum"> 819</span> : std::function&lt;at::TensorBase(const at::TensorBase&amp;)&gt; hook) const override;</span>
<span id="L820"><span class="lineNum"> 820</span> : void remove_hook(const at::TensorBase&amp;, unsigned pos) const override;</span>
<span id="L821"><span class="lineNum"> 821</span> : bool is_view(const at::TensorBase&amp;) const override;</span>
<span id="L822"><span class="lineNum"> 822</span> : const at::TensorBase&amp; base(const at::TensorBase&amp;) const override;</span>
<span id="L823"><span class="lineNum"> 823</span> : const std::string&amp; name(const at::TensorBase&amp;) const override;</span>
<span id="L824"><span class="lineNum"> 824</span> : bool is_leaf(const at::TensorBase&amp;) const override;</span>
<span id="L825"><span class="lineNum"> 825</span> : int64_t output_nr(const at::TensorBase&amp;) const override;</span>
<span id="L826"><span class="lineNum"> 826</span> : void set_data(const at::TensorBase&amp; self, const at::TensorBase&amp; new_data)</span>
<span id="L827"><span class="lineNum"> 827</span> : const override;</span>
<span id="L828"><span class="lineNum"> 828</span> : at::TensorBase data(const at::TensorBase&amp; self) const override;</span>
<span id="L829"><span class="lineNum"> 829</span> : int64_t _version(const at::TensorBase&amp; self) const override;</span>
<span id="L830"><span class="lineNum"> 830</span> : void retain_grad(const at::TensorBase&amp; self) const override;</span>
<span id="L831"><span class="lineNum"> 831</span> : bool retains_grad(const at::TensorBase&amp; self) const override;</span>
<span id="L832"><span class="lineNum"> 832</span> : void _backward(</span>
<span id="L833"><span class="lineNum"> 833</span> : const at::Tensor&amp; self,</span>
<span id="L834"><span class="lineNum"> 834</span> : at::TensorList inputs,</span>
<span id="L835"><span class="lineNum"> 835</span> : const c10::optional&lt;at::Tensor&gt;&amp; gradient,</span>
<span id="L836"><span class="lineNum"> 836</span> : c10::optional&lt;bool&gt; keep_graph,</span>
<span id="L837"><span class="lineNum"> 837</span> : bool create_graph) const override;</span>
<span id="L838"><span class="lineNum"> 838</span> : void requires_grad_(const at::TensorBase&amp; self, bool _requires_grad)</span>
<span id="L839"><span class="lineNum"> 839</span> : const override;</span>
<span id="L840"><span class="lineNum"> 840</span> : void basic_autograd_not_implemented_fallback(</span>
<span id="L841"><span class="lineNum"> 841</span> : const c10::OperatorHandle&amp; op,</span>
<span id="L842"><span class="lineNum"> 842</span> : c10::DispatchKeySet dispatch_keys,</span>
<span id="L843"><span class="lineNum"> 843</span> : torch::jit::Stack* stack) const override;</span>
<span id="L844"><span class="lineNum"> 844</span> : };</span>
<span id="L845"><span class="lineNum"> 845</span> : </span>
<span id="L846"><span class="lineNum"> 846</span> : namespace utils {</span>
<span id="L847"><span class="lineNum"> 847</span> : </span>
<span id="L848"><span class="lineNum"> 848</span> : TORCH_API bool has_same_meta(const Variable&amp; base, const Variable&amp; other);</span>
<span id="L849"><span class="lineNum"> 849</span> : </span>
<span id="L850"><span class="lineNum"> 850</span> : } // namespace utils</span>
<span id="L851"><span class="lineNum"> 851</span> : } // namespace autograd</span>
<span id="L852"><span class="lineNum"> 852</span> : } // namespace torch</span>
<span id="L853"><span class="lineNum"> 853</span> : </span>
<span id="L854"><span class="lineNum"> 854</span> : #endif /* DOXYGEN_SHOULD_SKIP_THIS */</span>
</pre>
</td>
</tr>
</table>
<br>
<table width="100%" border=0 cellspacing=0 cellpadding=0>
<tr><td class="ruler"><img src="../../../../../glass.png" width=3 height=3 alt=""></td></tr>
<tr><td class="versionInfo">Generated by: <a href="https://github.com//linux-test-project/lcov" target="_parent">LCOV version 2.0-1</a></td></tr>
</table>
<br>
</body>
</html>