809 lines
82 KiB
HTML
809 lines
82 KiB
HTML
|
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
|
||
|
|
||
|
<html lang="en">
|
||
|
|
||
|
<head>
|
||
|
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
|
||
|
<title>LCOV - coverage.info - libtorch/include/ATen/TensorIndexing.h</title>
|
||
|
<link rel="stylesheet" type="text/css" href="../../../gcov.css">
|
||
|
</head>
|
||
|
|
||
|
<body>
|
||
|
|
||
|
<table width="100%" border=0 cellspacing=0 cellpadding=0>
|
||
|
<tr><td class="title">LCOV - code coverage report</td></tr>
|
||
|
<tr><td class="ruler"><img src="../../../glass.png" width=3 height=3 alt=""></td></tr>
|
||
|
|
||
|
<tr>
|
||
|
<td width="100%">
|
||
|
<table cellpadding=1 border=0 width="100%">
|
||
|
<tr>
|
||
|
<td width="10%" class="headerItem">Current view:</td>
|
||
|
<td width="10%" class="headerValue"><a href="../../../index.html">top level</a> - <a href="index.html">libtorch/include/ATen</a> - TensorIndexing.h<span style="font-size: 80%;"> (source / <a href="TensorIndexing.h.func-c.html">functions</a>)</span></td>
|
||
|
<td width="5%"></td>
|
||
|
<td width="5%"></td>
|
||
|
<td width="5%" class="headerCovTableHead">Coverage</td>
|
||
|
<td width="5%" class="headerCovTableHead" title="Covered + Uncovered code">Total</td>
|
||
|
<td width="5%" class="headerCovTableHead" title="Exercised code only">Hit</td>
|
||
|
</tr>
|
||
|
<tr>
|
||
|
<td class="headerItem">Test:</td>
|
||
|
<td class="headerValue">coverage.info</td>
|
||
|
<td></td>
|
||
|
<td class="headerItem">Lines:</td>
|
||
|
<td class="headerCovTableEntryHi">96.0 %</td>
|
||
|
<td class="headerCovTableEntry">25</td>
|
||
|
<td class="headerCovTableEntry">24</td>
|
||
|
</tr>
|
||
|
<tr>
|
||
|
<td class="headerItem">Test Date:</td>
|
||
|
<td class="headerValue">2024-04-30 13:17:26</td>
|
||
|
<td></td>
|
||
|
<td class="headerItem">Functions:</td>
|
||
|
<td class="headerCovTableEntryHi">100.0 %</td>
|
||
|
<td class="headerCovTableEntry">7</td>
|
||
|
<td class="headerCovTableEntry">7</td>
|
||
|
</tr>
|
||
|
<tr><td><img src="../../../glass.png" width=3 height=3 alt=""></td></tr>
|
||
|
</table>
|
||
|
</td>
|
||
|
</tr>
|
||
|
|
||
|
<tr><td class="ruler"><img src="../../../glass.png" width=3 height=3 alt=""></td></tr>
|
||
|
</table>
|
||
|
|
||
|
<table cellpadding=0 cellspacing=0 border=0>
|
||
|
<tr>
|
||
|
<td><br></td>
|
||
|
</tr>
|
||
|
<tr>
|
||
|
<td>
|
||
|
<pre class="sourceHeading"> Line data Source code</pre>
|
||
|
<pre class="source">
|
||
|
<span id="L1"><span class="lineNum"> 1</span> : #pragma once</span>
|
||
|
<span id="L2"><span class="lineNum"> 2</span> : </span>
|
||
|
<span id="L3"><span class="lineNum"> 3</span> : #include <ATen/ExpandUtils.h></span>
|
||
|
<span id="L4"><span class="lineNum"> 4</span> : #include <ATen/ScalarOps.h></span>
|
||
|
<span id="L5"><span class="lineNum"> 5</span> : #include <ATen/core/Tensor.h></span>
|
||
|
<span id="L6"><span class="lineNum"> 6</span> : #include <ATen/core/TensorBody.h></span>
|
||
|
<span id="L7"><span class="lineNum"> 7</span> : #include <c10/core/SymInt.h></span>
|
||
|
<span id="L8"><span class="lineNum"> 8</span> : #include <c10/util/Optional.h></span>
|
||
|
<span id="L9"><span class="lineNum"> 9</span> : #include <c10/util/irange.h></span>
|
||
|
<span id="L10"><span class="lineNum"> 10</span> : </span>
|
||
|
<span id="L11"><span class="lineNum"> 11</span> : #ifndef AT_PER_OPERATOR_HEADERS</span>
|
||
|
<span id="L12"><span class="lineNum"> 12</span> : #include <ATen/Functions.h></span>
|
||
|
<span id="L13"><span class="lineNum"> 13</span> : #include <ATen/NativeFunctions.h></span>
|
||
|
<span id="L14"><span class="lineNum"> 14</span> : #else</span>
|
||
|
<span id="L15"><span class="lineNum"> 15</span> : #include <ATen/ops/alias.h></span>
|
||
|
<span id="L16"><span class="lineNum"> 16</span> : #include <ATen/ops/empty.h></span>
|
||
|
<span id="L17"><span class="lineNum"> 17</span> : #include <ATen/ops/scalar_tensor.h></span>
|
||
|
<span id="L18"><span class="lineNum"> 18</span> : #include <ATen/ops/zeros.h></span>
|
||
|
<span id="L19"><span class="lineNum"> 19</span> : #endif</span>
|
||
|
<span id="L20"><span class="lineNum"> 20</span> : </span>
|
||
|
<span id="L21"><span class="lineNum"> 21</span> : #include <ATen/core/List.h></span>
|
||
|
<span id="L22"><span class="lineNum"> 22</span> : </span>
|
||
|
<span id="L23"><span class="lineNum"> 23</span> : #include <utility></span>
|
||
|
<span id="L24"><span class="lineNum"> 24</span> : </span>
|
||
|
<span id="L25"><span class="lineNum"> 25</span> : namespace at {</span>
|
||
|
<span id="L26"><span class="lineNum"> 26</span> : namespace indexing {</span>
|
||
|
<span id="L27"><span class="lineNum"> 27</span> : </span>
|
||
|
<span id="L28"><span class="lineNum"> 28</span> : const int64_t INDEX_MIN = c10::SymInt::min_representable_int();</span>
|
||
|
<span id="L29"><span class="lineNum"> 29</span> : const int64_t INDEX_MAX = -(INDEX_MIN + 1);</span>
|
||
|
<span id="L30"><span class="lineNum"> 30</span> : </span>
|
||
|
<span id="L31"><span class="lineNum"> 31</span> : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };</span>
|
||
|
<span id="L32"><span class="lineNum"> 32</span> : </span>
|
||
|
<span id="L33"><span class="lineNum"> 33</span> : constexpr c10::nullopt_t None = c10::nullopt;</span>
|
||
|
<span id="L34"><span class="lineNum"> 34</span> : </span>
|
||
|
<span id="L35"><span class="lineNum"> 35</span> : struct TORCH_API EllipsisIndexType final {</span>
|
||
|
<span id="L36"><span class="lineNum"> 36</span> : EllipsisIndexType() = default;</span>
|
||
|
<span id="L37"><span class="lineNum"> 37</span> : };</span>
|
||
|
<span id="L38"><span class="lineNum"> 38</span> : TORCH_API extern const EllipsisIndexType Ellipsis;</span>
|
||
|
<span id="L39"><span class="lineNum"> 39</span> : </span>
|
||
|
<span id="L40"><span class="lineNum"> 40</span> : struct TORCH_API Slice final {</span>
|
||
|
<span id="L41"><span class="lineNum"> 41</span> : public:</span>
|
||
|
<span id="L42"><span class="lineNum"> 42</span> <span class="tlaGNC tlaBgGNC"> 36975204 : Slice(</span></span>
|
||
|
<span id="L43"><span class="lineNum"> 43</span> : c10::optional<c10::SymInt> start_index = c10::nullopt,</span>
|
||
|
<span id="L44"><span class="lineNum"> 44</span> : c10::optional<c10::SymInt> stop_index = c10::nullopt,</span>
|
||
|
<span id="L45"><span class="lineNum"> 45</span> <span class="tlaGNC"> 36975204 : c10::optional<c10::SymInt> step_index = c10::nullopt) {</span></span>
|
||
|
<span id="L46"><span class="lineNum"> 46</span> <span class="tlaGNC"> 36975204 : if (!step_index.has_value()) {</span></span>
|
||
|
<span id="L47"><span class="lineNum"> 47</span> <span class="tlaGNC"> 36975204 : step_ = c10::SymInt(1);</span></span>
|
||
|
<span id="L48"><span class="lineNum"> 48</span> : } else {</span>
|
||
|
<span id="L49"><span class="lineNum"> 49</span> <span class="tlaUNC tlaBgUNC"> 0 : step_ = std::move(step_index).value();</span></span>
|
||
|
<span id="L50"><span class="lineNum"> 50</span> : }</span>
|
||
|
<span id="L51"><span class="lineNum"> 51</span> : </span>
|
||
|
<span id="L52"><span class="lineNum"> 52</span> <span class="tlaGNC tlaBgGNC"> 36975204 : TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");</span></span>
|
||
|
<span id="L53"><span class="lineNum"> 53</span> : </span>
|
||
|
<span id="L54"><span class="lineNum"> 54</span> <span class="tlaGNC"> 36975204 : if (!start_index.has_value()) {</span></span>
|
||
|
<span id="L55"><span class="lineNum"> 55</span> <span class="tlaGNC"> 36974694 : start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);</span></span>
|
||
|
<span id="L56"><span class="lineNum"> 56</span> : } else {</span>
|
||
|
<span id="L57"><span class="lineNum"> 57</span> <span class="tlaGNC"> 510 : start_ = std::move(start_index).value();</span></span>
|
||
|
<span id="L58"><span class="lineNum"> 58</span> : }</span>
|
||
|
<span id="L59"><span class="lineNum"> 59</span> : </span>
|
||
|
<span id="L60"><span class="lineNum"> 60</span> <span class="tlaGNC"> 36975204 : if (!stop_index.has_value()) {</span></span>
|
||
|
<span id="L61"><span class="lineNum"> 61</span> <span class="tlaGNC"> 36974694 : stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);</span></span>
|
||
|
<span id="L62"><span class="lineNum"> 62</span> : } else {</span>
|
||
|
<span id="L63"><span class="lineNum"> 63</span> <span class="tlaGNC"> 510 : stop_ = std::move(stop_index).value();</span></span>
|
||
|
<span id="L64"><span class="lineNum"> 64</span> : }</span>
|
||
|
<span id="L65"><span class="lineNum"> 65</span> <span class="tlaGNC"> 36975204 : }</span></span>
|
||
|
<span id="L66"><span class="lineNum"> 66</span> : </span>
|
||
|
<span id="L67"><span class="lineNum"> 67</span> : inline c10::SymInt start() const {</span>
|
||
|
<span id="L68"><span class="lineNum"> 68</span> : return start_;</span>
|
||
|
<span id="L69"><span class="lineNum"> 69</span> : }</span>
|
||
|
<span id="L70"><span class="lineNum"> 70</span> : </span>
|
||
|
<span id="L71"><span class="lineNum"> 71</span> : inline c10::SymInt stop() const {</span>
|
||
|
<span id="L72"><span class="lineNum"> 72</span> : return stop_;</span>
|
||
|
<span id="L73"><span class="lineNum"> 73</span> : }</span>
|
||
|
<span id="L74"><span class="lineNum"> 74</span> : </span>
|
||
|
<span id="L75"><span class="lineNum"> 75</span> : inline c10::SymInt step() const {</span>
|
||
|
<span id="L76"><span class="lineNum"> 76</span> : return step_;</span>
|
||
|
<span id="L77"><span class="lineNum"> 77</span> : }</span>
|
||
|
<span id="L78"><span class="lineNum"> 78</span> : </span>
|
||
|
<span id="L79"><span class="lineNum"> 79</span> : private:</span>
|
||
|
<span id="L80"><span class="lineNum"> 80</span> : c10::SymInt start_;</span>
|
||
|
<span id="L81"><span class="lineNum"> 81</span> : c10::SymInt stop_;</span>
|
||
|
<span id="L82"><span class="lineNum"> 82</span> : c10::SymInt step_;</span>
|
||
|
<span id="L83"><span class="lineNum"> 83</span> : };</span>
|
||
|
<span id="L84"><span class="lineNum"> 84</span> : </span>
|
||
|
<span id="L85"><span class="lineNum"> 85</span> : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);</span>
|
||
|
<span id="L86"><span class="lineNum"> 86</span> : </span>
|
||
|
<span id="L87"><span class="lineNum"> 87</span> : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as</span>
|
||
|
<span id="L88"><span class="lineNum"> 88</span> : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`</span>
|
||
|
<span id="L89"><span class="lineNum"> 89</span> : // into its equivalent `std::vector<TensorIndex>`, so that further tensor</span>
|
||
|
<span id="L90"><span class="lineNum"> 90</span> : // indexing operations can be performed using the supplied indices.</span>
|
||
|
<span id="L91"><span class="lineNum"> 91</span> : //</span>
|
||
|
<span id="L92"><span class="lineNum"> 92</span> : // There is one-to-one correspondence between Python and C++ tensor index types:</span>
|
||
|
<span id="L93"><span class="lineNum"> 93</span> : // Python | C++</span>
|
||
|
<span id="L94"><span class="lineNum"> 94</span> : // -----------------------------------------------------</span>
|
||
|
<span id="L95"><span class="lineNum"> 95</span> : // `None` | `at::indexing::None`</span>
|
||
|
<span id="L96"><span class="lineNum"> 96</span> : // `Ellipsis` | `at::indexing::Ellipsis`</span>
|
||
|
<span id="L97"><span class="lineNum"> 97</span> : // `...` | `"..."`</span>
|
||
|
<span id="L98"><span class="lineNum"> 98</span> : // `123` | `123`</span>
|
||
|
<span id="L99"><span class="lineNum"> 99</span> : // `True` / `False` | `true` / `false`</span>
|
||
|
<span id="L100"><span class="lineNum"> 100</span> : // `:` | `Slice()` / `Slice(None, None)`</span>
|
||
|
<span id="L101"><span class="lineNum"> 101</span> : // `::` | `Slice()` / `Slice(None, None, None)`</span>
|
||
|
<span id="L102"><span class="lineNum"> 102</span> : // `1:` | `Slice(1, None)`</span>
|
||
|
<span id="L103"><span class="lineNum"> 103</span> : // `1::` | `Slice(1, None, None)`</span>
|
||
|
<span id="L104"><span class="lineNum"> 104</span> : // `:3` | `Slice(None, 3)`</span>
|
||
|
<span id="L105"><span class="lineNum"> 105</span> : // `:3:` | `Slice(None, 3, None)`</span>
|
||
|
<span id="L106"><span class="lineNum"> 106</span> : // `::2` | `Slice(None, None, 2)`</span>
|
||
|
<span id="L107"><span class="lineNum"> 107</span> : // `1:3` | `Slice(1, 3)`</span>
|
||
|
<span id="L108"><span class="lineNum"> 108</span> : // `1::2` | `Slice(1, None, 2)`</span>
|
||
|
<span id="L109"><span class="lineNum"> 109</span> : // `:3:2` | `Slice(None, 3, 2)`</span>
|
||
|
<span id="L110"><span class="lineNum"> 110</span> : // `1:3:2` | `Slice(1, 3, 2)`</span>
|
||
|
<span id="L111"><span class="lineNum"> 111</span> : // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`</span>
|
||
|
<span id="L112"><span class="lineNum"> 112</span> : struct TORCH_API TensorIndex final {</span>
|
||
|
<span id="L113"><span class="lineNum"> 113</span> : // Case 1: `at::indexing::None`</span>
|
||
|
<span id="L114"><span class="lineNum"> 114</span> : TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}</span>
|
||
|
<span id="L115"><span class="lineNum"> 115</span> : </span>
|
||
|
<span id="L116"><span class="lineNum"> 116</span> : // Case 2: "..." / `at::indexing::Ellipsis`</span>
|
||
|
<span id="L117"><span class="lineNum"> 117</span> <span class="tlaGNC"> 776042 : TensorIndex(at::indexing::EllipsisIndexType)</span></span>
|
||
|
<span id="L118"><span class="lineNum"> 118</span> <span class="tlaGNC"> 776042 : : type_(TensorIndexType::Ellipsis) {}</span></span>
|
||
|
<span id="L119"><span class="lineNum"> 119</span> <span class="tlaGNC"> 776042 : TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {</span></span>
|
||
|
<span id="L120"><span class="lineNum"> 120</span> <span class="tlaGNC"> 776042 : TORCH_CHECK_VALUE(</span></span>
|
||
|
<span id="L121"><span class="lineNum"> 121</span> : strcmp(str, "...") == 0,</span>
|
||
|
<span id="L122"><span class="lineNum"> 122</span> : "Expected \"...\" to represent an ellipsis index, but got \"",</span>
|
||
|
<span id="L123"><span class="lineNum"> 123</span> : str,</span>
|
||
|
<span id="L124"><span class="lineNum"> 124</span> : "\"");</span>
|
||
|
<span id="L125"><span class="lineNum"> 125</span> <span class="tlaGNC"> 776042 : }</span></span>
|
||
|
<span id="L126"><span class="lineNum"> 126</span> : </span>
|
||
|
<span id="L127"><span class="lineNum"> 127</span> : // Case 3: (Sym) Integer value</span>
|
||
|
<span id="L128"><span class="lineNum"> 128</span> <span class="tlaGNC"> 36171516 : TensorIndex(SymInt integer)</span></span>
|
||
|
<span id="L129"><span class="lineNum"> 129</span> <span class="tlaGNC"> 36171516 : : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}</span></span>
|
||
|
<span id="L130"><span class="lineNum"> 130</span> : TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}</span>
|
||
|
<span id="L131"><span class="lineNum"> 131</span> <span class="tlaGNC"> 36171516 : TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}</span></span>
|
||
|
<span id="L132"><span class="lineNum"> 132</span> : </span>
|
||
|
<span id="L133"><span class="lineNum"> 133</span> : // Case 4: Boolean value</span>
|
||
|
<span id="L134"><span class="lineNum"> 134</span> : template <</span>
|
||
|
<span id="L135"><span class="lineNum"> 135</span> : class T,</span>
|
||
|
<span id="L136"><span class="lineNum"> 136</span> : class = typename std::enable_if<std::is_same<bool, T>::value>::type></span>
|
||
|
<span id="L137"><span class="lineNum"> 137</span> : TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}</span>
|
||
|
<span id="L138"><span class="lineNum"> 138</span> : </span>
|
||
|
<span id="L139"><span class="lineNum"> 139</span> : // Case 5: Slice represented in `at::indexing::Slice` form</span>
|
||
|
<span id="L140"><span class="lineNum"> 140</span> <span class="tlaGNC"> 510 : TensorIndex(Slice slice)</span></span>
|
||
|
<span id="L141"><span class="lineNum"> 141</span> <span class="tlaGNC"> 510 : : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}</span></span>
|
||
|
<span id="L142"><span class="lineNum"> 142</span> : </span>
|
||
|
<span id="L143"><span class="lineNum"> 143</span> : // Case 6: Tensor value</span>
|
||
|
<span id="L144"><span class="lineNum"> 144</span> <span class="tlaGNC"> 27136 : TensorIndex(Tensor tensor)</span></span>
|
||
|
<span id="L145"><span class="lineNum"> 145</span> <span class="tlaGNC"> 27136 : : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}</span></span>
|
||
|
<span id="L146"><span class="lineNum"> 146</span> : </span>
|
||
|
<span id="L147"><span class="lineNum"> 147</span> : inline bool is_none() const {</span>
|
||
|
<span id="L148"><span class="lineNum"> 148</span> : return type_ == TensorIndexType::None;</span>
|
||
|
<span id="L149"><span class="lineNum"> 149</span> : }</span>
|
||
|
<span id="L150"><span class="lineNum"> 150</span> : </span>
|
||
|
<span id="L151"><span class="lineNum"> 151</span> : inline bool is_ellipsis() const {</span>
|
||
|
<span id="L152"><span class="lineNum"> 152</span> : return type_ == TensorIndexType::Ellipsis;</span>
|
||
|
<span id="L153"><span class="lineNum"> 153</span> : }</span>
|
||
|
<span id="L154"><span class="lineNum"> 154</span> : </span>
|
||
|
<span id="L155"><span class="lineNum"> 155</span> : inline bool is_integer() const {</span>
|
||
|
<span id="L156"><span class="lineNum"> 156</span> : return type_ == TensorIndexType::SymInt;</span>
|
||
|
<span id="L157"><span class="lineNum"> 157</span> : }</span>
|
||
|
<span id="L158"><span class="lineNum"> 158</span> : </span>
|
||
|
<span id="L159"><span class="lineNum"> 159</span> : inline SymInt integer() const {</span>
|
||
|
<span id="L160"><span class="lineNum"> 160</span> : return integer_;</span>
|
||
|
<span id="L161"><span class="lineNum"> 161</span> : }</span>
|
||
|
<span id="L162"><span class="lineNum"> 162</span> : </span>
|
||
|
<span id="L163"><span class="lineNum"> 163</span> : inline bool is_boolean() const {</span>
|
||
|
<span id="L164"><span class="lineNum"> 164</span> : return type_ == TensorIndexType::Boolean;</span>
|
||
|
<span id="L165"><span class="lineNum"> 165</span> : }</span>
|
||
|
<span id="L166"><span class="lineNum"> 166</span> : </span>
|
||
|
<span id="L167"><span class="lineNum"> 167</span> : inline bool boolean() const {</span>
|
||
|
<span id="L168"><span class="lineNum"> 168</span> : return boolean_;</span>
|
||
|
<span id="L169"><span class="lineNum"> 169</span> : }</span>
|
||
|
<span id="L170"><span class="lineNum"> 170</span> : </span>
|
||
|
<span id="L171"><span class="lineNum"> 171</span> : inline bool is_slice() const {</span>
|
||
|
<span id="L172"><span class="lineNum"> 172</span> : return type_ == TensorIndexType::Slice;</span>
|
||
|
<span id="L173"><span class="lineNum"> 173</span> : }</span>
|
||
|
<span id="L174"><span class="lineNum"> 174</span> : </span>
|
||
|
<span id="L175"><span class="lineNum"> 175</span> : inline const Slice& slice() const {</span>
|
||
|
<span id="L176"><span class="lineNum"> 176</span> : return slice_;</span>
|
||
|
<span id="L177"><span class="lineNum"> 177</span> : }</span>
|
||
|
<span id="L178"><span class="lineNum"> 178</span> : </span>
|
||
|
<span id="L179"><span class="lineNum"> 179</span> : inline bool is_tensor() const {</span>
|
||
|
<span id="L180"><span class="lineNum"> 180</span> : return type_ == TensorIndexType::Tensor;</span>
|
||
|
<span id="L181"><span class="lineNum"> 181</span> : }</span>
|
||
|
<span id="L182"><span class="lineNum"> 182</span> : </span>
|
||
|
<span id="L183"><span class="lineNum"> 183</span> : inline const Tensor& tensor() const {</span>
|
||
|
<span id="L184"><span class="lineNum"> 184</span> : return tensor_;</span>
|
||
|
<span id="L185"><span class="lineNum"> 185</span> : }</span>
|
||
|
<span id="L186"><span class="lineNum"> 186</span> : </span>
|
||
|
<span id="L187"><span class="lineNum"> 187</span> : private:</span>
|
||
|
<span id="L188"><span class="lineNum"> 188</span> : SymInt integer_ = 0;</span>
|
||
|
<span id="L189"><span class="lineNum"> 189</span> : bool boolean_ = false;</span>
|
||
|
<span id="L190"><span class="lineNum"> 190</span> : Slice slice_;</span>
|
||
|
<span id="L191"><span class="lineNum"> 191</span> : Tensor tensor_;</span>
|
||
|
<span id="L192"><span class="lineNum"> 192</span> : TensorIndexType type_;</span>
|
||
|
<span id="L193"><span class="lineNum"> 193</span> : };</span>
|
||
|
<span id="L194"><span class="lineNum"> 194</span> : </span>
|
||
|
<span id="L195"><span class="lineNum"> 195</span> : TORCH_API std::ostream& operator<<(</span>
|
||
|
<span id="L196"><span class="lineNum"> 196</span> : std::ostream& stream,</span>
|
||
|
<span id="L197"><span class="lineNum"> 197</span> : const TensorIndex& tensor_index);</span>
|
||
|
<span id="L198"><span class="lineNum"> 198</span> : TORCH_API std::ostream& operator<<(</span>
|
||
|
<span id="L199"><span class="lineNum"> 199</span> : std::ostream& stream,</span>
|
||
|
<span id="L200"><span class="lineNum"> 200</span> : const std::vector<TensorIndex>& tensor_indices);</span>
|
||
|
<span id="L201"><span class="lineNum"> 201</span> : </span>
|
||
|
<span id="L202"><span class="lineNum"> 202</span> : namespace impl {</span>
|
||
|
<span id="L203"><span class="lineNum"> 203</span> : static inline Tensor applySlice(</span>
|
||
|
<span id="L204"><span class="lineNum"> 204</span> : const Tensor& self,</span>
|
||
|
<span id="L205"><span class="lineNum"> 205</span> : int64_t dim,</span>
|
||
|
<span id="L206"><span class="lineNum"> 206</span> : c10::SymInt start,</span>
|
||
|
<span id="L207"><span class="lineNum"> 207</span> : c10::SymInt stop,</span>
|
||
|
<span id="L208"><span class="lineNum"> 208</span> : c10::SymInt step,</span>
|
||
|
<span id="L209"><span class="lineNum"> 209</span> : bool disable_slice_optimization,</span>
|
||
|
<span id="L210"><span class="lineNum"> 210</span> : const at::Device& self_device,</span>
|
||
|
<span id="L211"><span class="lineNum"> 211</span> : const c10::optional<SymIntArrayRef>& self_sizes) {</span>
|
||
|
<span id="L212"><span class="lineNum"> 212</span> : // TODO: implement negative step</span>
|
||
|
<span id="L213"><span class="lineNum"> 213</span> : TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");</span>
|
||
|
<span id="L214"><span class="lineNum"> 214</span> : </span>
|
||
|
<span id="L215"><span class="lineNum"> 215</span> : // See NOTE [nested tensor size for indexing]</span>
|
||
|
<span id="L216"><span class="lineNum"> 216</span> : if (self_sizes.has_value()) {</span>
|
||
|
<span id="L217"><span class="lineNum"> 217</span> : // Skip this optimization if we are tracing, as the trace may be polymorphic</span>
|
||
|
<span id="L218"><span class="lineNum"> 218</span> : // over the shape of the `self` tensor, and we still want to record</span>
|
||
|
<span id="L219"><span class="lineNum"> 219</span> : // the slice.</span>
|
||
|
<span id="L220"><span class="lineNum"> 220</span> : SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)</span>
|
||
|
<span id="L221"><span class="lineNum"> 221</span> : ? (*self_sizes)[dim]</span>
|
||
|
<span id="L222"><span class="lineNum"> 222</span> : : self.sym_size(dim);</span>
|
||
|
<span id="L223"><span class="lineNum"> 223</span> : if (!disable_slice_optimization && start == 0 && length == stop &&</span>
|
||
|
<span id="L224"><span class="lineNum"> 224</span> : step == 1) {</span>
|
||
|
<span id="L225"><span class="lineNum"> 225</span> : return self;</span>
|
||
|
<span id="L226"><span class="lineNum"> 226</span> : }</span>
|
||
|
<span id="L227"><span class="lineNum"> 227</span> : }</span>
|
||
|
<span id="L228"><span class="lineNum"> 228</span> : return self.slice_symint(dim, start, stop, std::move(step));</span>
|
||
|
<span id="L229"><span class="lineNum"> 229</span> : }</span>
|
||
|
<span id="L230"><span class="lineNum"> 230</span> : </span>
|
||
|
<span id="L231"><span class="lineNum"> 231</span> : static inline Tensor applySelect(</span>
|
||
|
<span id="L232"><span class="lineNum"> 232</span> : const Tensor& self,</span>
|
||
|
<span id="L233"><span class="lineNum"> 233</span> : int64_t dim,</span>
|
||
|
<span id="L234"><span class="lineNum"> 234</span> : SymInt index,</span>
|
||
|
<span id="L235"><span class="lineNum"> 235</span> : int64_t real_dim,</span>
|
||
|
<span id="L236"><span class="lineNum"> 236</span> : const at::Device& /*self_device*/,</span>
|
||
|
<span id="L237"><span class="lineNum"> 237</span> : const c10::optional<SymIntArrayRef>& self_sizes) {</span>
|
||
|
<span id="L238"><span class="lineNum"> 238</span> : // See NOTE [nested tensor size for indexing]</span>
|
||
|
<span id="L239"><span class="lineNum"> 239</span> : if (self_sizes.has_value()) {</span>
|
||
|
<span id="L240"><span class="lineNum"> 240</span> : auto maybe_index = index.maybe_as_int();</span>
|
||
|
<span id="L241"><span class="lineNum"> 241</span> : if (maybe_index.has_value()) {</span>
|
||
|
<span id="L242"><span class="lineNum"> 242</span> : TORCH_CHECK_INDEX(</span>
|
||
|
<span id="L243"><span class="lineNum"> 243</span> : !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),</span>
|
||
|
<span id="L244"><span class="lineNum"> 244</span> : "invalid index of a 0-dim tensor. ",</span>
|
||
|
<span id="L245"><span class="lineNum"> 245</span> : "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");</span>
|
||
|
<span id="L246"><span class="lineNum"> 246</span> : }</span>
|
||
|
<span id="L247"><span class="lineNum"> 247</span> : </span>
|
||
|
<span id="L248"><span class="lineNum"> 248</span> : auto size = (*self_sizes)[dim];</span>
|
||
|
<span id="L249"><span class="lineNum"> 249</span> : TORCH_CHECK_INDEX(</span>
|
||
|
<span id="L250"><span class="lineNum"> 250</span> : size >= -index && size > index,</span>
|
||
|
<span id="L251"><span class="lineNum"> 251</span> : "index ",</span>
|
||
|
<span id="L252"><span class="lineNum"> 252</span> : index,</span>
|
||
|
<span id="L253"><span class="lineNum"> 253</span> : " is out of bounds for dimension ",</span>
|
||
|
<span id="L254"><span class="lineNum"> 254</span> : real_dim,</span>
|
||
|
<span id="L255"><span class="lineNum"> 255</span> : " with size ",</span>
|
||
|
<span id="L256"><span class="lineNum"> 256</span> : size);</span>
|
||
|
<span id="L257"><span class="lineNum"> 257</span> : }</span>
|
||
|
<span id="L258"><span class="lineNum"> 258</span> : </span>
|
||
|
<span id="L259"><span class="lineNum"> 259</span> : // if the index is negative, do not normalize it because that would fix the</span>
|
||
|
<span id="L260"><span class="lineNum"> 260</span> : // index on the current tensor size in the tracer. aten::select also works on</span>
|
||
|
<span id="L261"><span class="lineNum"> 261</span> : // negative indices</span>
|
||
|
<span id="L262"><span class="lineNum"> 262</span> : return self.select_symint(dim, index);</span>
|
||
|
<span id="L263"><span class="lineNum"> 263</span> : }</span>
|
||
|
<span id="L264"><span class="lineNum"> 264</span> : </span>
|
||
|
<span id="L265"><span class="lineNum"> 265</span> : static inline Tensor boolToIndexingTensorCPUOrCUDA(</span>
|
||
|
<span id="L266"><span class="lineNum"> 266</span> : const Tensor& self,</span>
|
||
|
<span id="L267"><span class="lineNum"> 267</span> : bool value) {</span>
|
||
|
<span id="L268"><span class="lineNum"> 268</span> : // booleans add a dimension of size 1. true indexes this dimension as if 0:,</span>
|
||
|
<span id="L269"><span class="lineNum"> 269</span> : // false as empty.</span>
|
||
|
<span id="L270"><span class="lineNum"> 270</span> : if (value) {</span>
|
||
|
<span id="L271"><span class="lineNum"> 271</span> : return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);</span>
|
||
|
<span id="L272"><span class="lineNum"> 272</span> : } else {</span>
|
||
|
<span id="L273"><span class="lineNum"> 273</span> : return at::empty({0}, {}, self.options().dtype(kLong));</span>
|
||
|
<span id="L274"><span class="lineNum"> 274</span> : }</span>
|
||
|
<span id="L275"><span class="lineNum"> 275</span> : }</span>
|
||
|
<span id="L276"><span class="lineNum"> 276</span> : </span>
|
||
|
<span id="L277"><span class="lineNum"> 277</span> : static inline Tensor boolToIndexingTensorNonNativeDeviceType(</span>
|
||
|
<span id="L278"><span class="lineNum"> 278</span> : const Tensor& self,</span>
|
||
|
<span id="L279"><span class="lineNum"> 279</span> : bool value) {</span>
|
||
|
<span id="L280"><span class="lineNum"> 280</span> : // booleans add a dimension of size 1. true indexes this dimension as if 0:,</span>
|
||
|
<span id="L281"><span class="lineNum"> 281</span> : // false as empty.</span>
|
||
|
<span id="L282"><span class="lineNum"> 282</span> : if (value) {</span>
|
||
|
<span id="L283"><span class="lineNum"> 283</span> : return at::zeros({1}, {}, self.options().dtype(kLong));</span>
|
||
|
<span id="L284"><span class="lineNum"> 284</span> : } else {</span>
|
||
|
<span id="L285"><span class="lineNum"> 285</span> : return at::empty({0}, {}, self.options().dtype(kLong));</span>
|
||
|
<span id="L286"><span class="lineNum"> 286</span> : }</span>
|
||
|
<span id="L287"><span class="lineNum"> 287</span> : }</span>
|
||
|
<span id="L288"><span class="lineNum"> 288</span> : </span>
|
||
|
<span id="L289"><span class="lineNum"> 289</span> : static inline Tensor boolToIndexingTensor(</span>
|
||
|
<span id="L290"><span class="lineNum"> 290</span> : const Tensor& self,</span>
|
||
|
<span id="L291"><span class="lineNum"> 291</span> : bool value,</span>
|
||
|
<span id="L292"><span class="lineNum"> 292</span> : const at::Device& self_device) {</span>
|
||
|
<span id="L293"><span class="lineNum"> 293</span> : if (self_device == at::kCPU || self_device == at::kCUDA) {</span>
|
||
|
<span id="L294"><span class="lineNum"> 294</span> : return boolToIndexingTensorCPUOrCUDA(self, value);</span>
|
||
|
<span id="L295"><span class="lineNum"> 295</span> : } else {</span>
|
||
|
<span id="L296"><span class="lineNum"> 296</span> : return boolToIndexingTensorNonNativeDeviceType(self, value);</span>
|
||
|
<span id="L297"><span class="lineNum"> 297</span> : }</span>
|
||
|
<span id="L298"><span class="lineNum"> 298</span> : }</span>
|
||
|
<span id="L299"><span class="lineNum"> 299</span> : </span>
|
||
|
<span id="L300"><span class="lineNum"> 300</span> : static inline Tensor scalarToTensorNonNativeDeviceType(</span>
|
||
|
<span id="L301"><span class="lineNum"> 301</span> : const Scalar& v,</span>
|
||
|
<span id="L302"><span class="lineNum"> 302</span> : const TensorOptions& options) {</span>
|
||
|
<span id="L303"><span class="lineNum"> 303</span> : return at::scalar_tensor(v, options);</span>
|
||
|
<span id="L304"><span class="lineNum"> 304</span> : }</span>
|
||
|
<span id="L305"><span class="lineNum"> 305</span> : </span>
|
||
|
<span id="L306"><span class="lineNum"> 306</span> : static inline void recordTensorIndex(</span>
|
||
|
<span id="L307"><span class="lineNum"> 307</span> : const Tensor& tensor,</span>
|
||
|
<span id="L308"><span class="lineNum"> 308</span> : std::vector<Tensor>& outIndices,</span>
|
||
|
<span id="L309"><span class="lineNum"> 309</span> : int64_t* dim_ptr) {</span>
|
||
|
<span id="L310"><span class="lineNum"> 310</span> : // TODO: check scalarType</span>
|
||
|
<span id="L311"><span class="lineNum"> 311</span> : outIndices.resize(*dim_ptr + 1);</span>
|
||
|
<span id="L312"><span class="lineNum"> 312</span> : outIndices[*dim_ptr] = tensor;</span>
|
||
|
<span id="L313"><span class="lineNum"> 313</span> : (*dim_ptr)++;</span>
|
||
|
<span id="L314"><span class="lineNum"> 314</span> : };</span>
|
||
|
<span id="L315"><span class="lineNum"> 315</span> : </span>
|
||
|
<span id="L316"><span class="lineNum"> 316</span> : static inline c10::List&lt;c10::optional&lt;Tensor&gt;&gt; typeConvertIndices(</span>
|
||
|
<span id="L317"><span class="lineNum"> 317</span> : const Tensor& /*self*/,</span>
|
||
|
<span id="L318"><span class="lineNum"> 318</span> : std::vector&lt;Tensor&gt;&amp;&amp; indices) {</span>
|
||
|
<span id="L319"><span class="lineNum"> 319</span> : c10::List&lt;c10::optional&lt;Tensor&gt;&gt; converted_inds;</span>
|
||
|
<span id="L320"><span class="lineNum"> 320</span> : converted_inds.reserve(indices.size());</span>
|
||
|
<span id="L321"><span class="lineNum"> 321</span> : for (const auto& i : indices) {</span>
|
||
|
<span id="L322"><span class="lineNum"> 322</span> : converted_inds.push_back(std::move(i));</span>
|
||
|
<span id="L323"><span class="lineNum"> 323</span> : }</span>
|
||
|
<span id="L324"><span class="lineNum"> 324</span> : return converted_inds;</span>
|
||
|
<span id="L325"><span class="lineNum"> 325</span> : }</span>
|
||
|
<span id="L326"><span class="lineNum"> 326</span> : </span>
|
||
|
<span id="L327"><span class="lineNum"> 327</span> : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`</span>
|
||
|
<span id="L328"><span class="lineNum"> 328</span> : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because</span>
|
||
|
<span id="L329"><span class="lineNum"> 329</span> : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim</span>
|
||
|
<span id="L330"><span class="lineNum"> 330</span> : // indexing (i.e. it's called by `applySlicing` which is called by</span>
|
||
|
<span id="L331"><span class="lineNum"> 331</span> : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more</span>
|
||
|
<span id="L332"><span class="lineNum"> 332</span> : // than one dimension). If we were to merge the Python/C++</span>
|
||
|
<span id="L333"><span class="lineNum"> 333</span> : // `count_specified_dimensions` function, on the Python side we would have to</span>
|
||
|
<span id="L334"><span class="lineNum"> 334</span> : // construct a `std::vector` container to be consumed by the C++</span>
|
||
|
<span id="L335"><span class="lineNum"> 335</span> : // `count_specified_dimensions` function, which adds 100s of nanoseconds</span>
|
||
|
<span id="L336"><span class="lineNum"> 336</span> : // overhead and is undesirable.</span>
|
||
|
<span id="L337"><span class="lineNum"> 337</span> : static inline int64_t count_specified_dimensions(</span>
|
||
|
<span id="L338"><span class="lineNum"> 338</span> : const ArrayRef&lt;TensorIndex&gt;&amp; indices) {</span>
|
||
|
<span id="L339"><span class="lineNum"> 339</span> : // Count the number of indexed dimensions (everything but ellipsis and None)</span>
|
||
|
<span id="L340"><span class="lineNum"> 340</span> : int64_t count = 0;</span>
|
||
|
<span id="L341"><span class="lineNum"> 341</span> : for (auto& obj : indices) {</span>
|
||
|
<span id="L342"><span class="lineNum"> 342</span> : if (obj.is_tensor()) {</span>
|
||
|
<span id="L343"><span class="lineNum"> 343</span> : auto& tensor = obj.tensor();</span>
|
||
|
<span id="L344"><span class="lineNum"> 344</span> : if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {</span>
|
||
|
<span id="L345"><span class="lineNum"> 345</span> : count += tensor.dim();</span>
|
||
|
<span id="L346"><span class="lineNum"> 346</span> : } else {</span>
|
||
|
<span id="L347"><span class="lineNum"> 347</span> : count++;</span>
|
||
|
<span id="L348"><span class="lineNum"> 348</span> : }</span>
|
||
|
<span id="L349"><span class="lineNum"> 349</span> : } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {</span>
|
||
|
<span id="L350"><span class="lineNum"> 350</span> : count++;</span>
|
||
|
<span id="L351"><span class="lineNum"> 351</span> : }</span>
|
||
|
<span id="L352"><span class="lineNum"> 352</span> : }</span>
|
||
|
<span id="L353"><span class="lineNum"> 353</span> : return count;</span>
|
||
|
<span id="L354"><span class="lineNum"> 354</span> : }</span>
|
||
|
<span id="L355"><span class="lineNum"> 355</span> : } // namespace impl</span>
|
||
|
<span id="L356"><span class="lineNum"> 356</span> : </span>
|
||
|
<span id="L357"><span class="lineNum"> 357</span> : // NOTE: Many functions below are only for consumption from Python indexing</span>
|
||
|
<span id="L358"><span class="lineNum"> 358</span> : // implementation, they include:</span>
|
||
|
<span id="L359"><span class="lineNum"> 359</span> : //</span>
|
||
|
<span id="L360"><span class="lineNum"> 360</span> : // - `Tensor scalarToTensor(...)`</span>
|
||
|
<span id="L361"><span class="lineNum"> 361</span> : // - `IntArrayRef slicePrefix1sSize(...)`</span>
|
||
|
<span id="L362"><span class="lineNum"> 362</span> : // - `void copy_to(...)`</span>
|
||
|
<span id="L363"><span class="lineNum"> 363</span> : // - `Tensor handleDimInMultiDimIndexing(...)`</span>
|
||
|
<span id="L364"><span class="lineNum"> 364</span> : // - `Tensor dispatch_index(...)`</span>
|
||
|
<span id="L365"><span class="lineNum"> 365</span> : // - `Tensor dispatch_index_put_(...)`</span>
|
||
|
<span id="L366"><span class="lineNum"> 366</span> : // - `Tensor get_item(...)`</span>
|
||
|
<span id="L367"><span class="lineNum"> 367</span> : // - `void set_item(...)`</span>
|
||
|
<span id="L368"><span class="lineNum"> 368</span> : //</span>
|
||
|
<span id="L369"><span class="lineNum"> 369</span> : // The rest of the functions are in `at::indexing::impl` namespace, signifying</span>
|
||
|
<span id="L370"><span class="lineNum"> 370</span> : // that they shouldn't be used from Python indexing implementation.</span>
|
||
|
<span id="L371"><span class="lineNum"> 371</span> : static inline Tensor scalarToTensor(</span>
|
||
|
<span id="L372"><span class="lineNum"> 372</span> : const Scalar& v,</span>
|
||
|
<span id="L373"><span class="lineNum"> 373</span> : const TensorOptions& options,</span>
|
||
|
<span id="L374"><span class="lineNum"> 374</span> : const at::Device& self_device) {</span>
|
||
|
<span id="L375"><span class="lineNum"> 375</span> : if (self_device == at::kCPU) {</span>
|
||
|
<span id="L376"><span class="lineNum"> 376</span> : return at::detail::scalar_tensor_static(</span>
|
||
|
<span id="L377"><span class="lineNum"> 377</span> : v, options.dtype_opt()->toScalarType(), self_device);</span>
|
||
|
<span id="L378"><span class="lineNum"> 378</span> : } else {</span>
|
||
|
<span id="L379"><span class="lineNum"> 379</span> : return impl::scalarToTensorNonNativeDeviceType(v, options);</span>
|
||
|
<span id="L380"><span class="lineNum"> 380</span> : }</span>
|
||
|
<span id="L381"><span class="lineNum"> 381</span> : }</span>
|
||
|
<span id="L382"><span class="lineNum"> 382</span> : </span>
|
||
|
<span id="L383"><span class="lineNum"> 383</span> : // To match numpy semantics:</span>
|
||
|
<span id="L384"><span class="lineNum"> 384</span> : // As a special case for backwards compatibility,</span>
|
||
|
<span id="L385"><span class="lineNum"> 385</span> : // strip away unit dimensions from the left of 'src'</span>
|
||
|
<span id="L386"><span class="lineNum"> 386</span> : static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {</span>
|
||
|
<span id="L387"><span class="lineNum"> 387</span> : size_t first_non1_src = sizes.size();</span>
|
||
|
<span id="L388"><span class="lineNum"> 388</span> : for (const auto i : c10::irange(sizes.size())) {</span>
|
||
|
<span id="L389"><span class="lineNum"> 389</span> : // Unbacked SymInt has different behavior, but this is sound because</span>
|
||
|
<span id="L390"><span class="lineNum"> 390</span> : // failing to slice will only ever cause an error, not divergent</span>
|
||
|
<span id="L391"><span class="lineNum"> 391</span> : // behavior</span>
|
||
|
<span id="L392"><span class="lineNum"> 392</span> : if (!sizes[i].has_hint() || sizes[i] != 1) {</span>
|
||
|
<span id="L393"><span class="lineNum"> 393</span> : first_non1_src = i;</span>
|
||
|
<span id="L394"><span class="lineNum"> 394</span> : break;</span>
|
||
|
<span id="L395"><span class="lineNum"> 395</span> : }</span>
|
||
|
<span id="L396"><span class="lineNum"> 396</span> : }</span>
|
||
|
<span id="L397"><span class="lineNum"> 397</span> : </span>
|
||
|
<span id="L398"><span class="lineNum"> 398</span> : return sizes.slice(first_non1_src);</span>
|
||
|
<span id="L399"><span class="lineNum"> 399</span> : }</span>
|
||
|
<span id="L400"><span class="lineNum"> 400</span> : </span>
|
||
|
<span id="L401"><span class="lineNum"> 401</span> : static inline void copy_to(const Tensor& dst, const Tensor& src) {</span>
|
||
|
<span id="L402"><span class="lineNum"> 402</span> : if (dst.sym_sizes().equals(src.sym_sizes())) {</span>
|
||
|
<span id="L403"><span class="lineNum"> 403</span> : // A shortcut to avoid generating hard-coded constant sizes during tracing.</span>
|
||
|
<span id="L404"><span class="lineNum"> 404</span> : // This is not a perfect solution: when src & dst have different shapes,</span>
|
||
|
<span id="L405"><span class="lineNum"> 405</span> : // constants will still appear. Users can workaround that case by</span>
|
||
|
<span id="L406"><span class="lineNum"> 406</span> : // dst[index..] = src.reshape(..)</span>
|
||
|
<span id="L407"><span class="lineNum"> 407</span> : dst.copy_(src);</span>
|
||
|
<span id="L408"><span class="lineNum"> 408</span> : return;</span>
|
||
|
<span id="L409"><span class="lineNum"> 409</span> : } else if (src.dim() == 0 && src.device().type() == at::kCPU) {</span>
|
||
|
<span id="L410"><span class="lineNum"> 410</span> : dst.fill_(src);</span>
|
||
|
<span id="L411"><span class="lineNum"> 411</span> : return;</span>
|
||
|
<span id="L412"><span class="lineNum"> 412</span> : }</span>
|
||
|
<span id="L413"><span class="lineNum"> 413</span> : auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));</span>
|
||
|
<span id="L414"><span class="lineNum"> 414</span> : c10::MaybeOwned&lt;Tensor&gt; b_src = expand_inplace(dst, src_view, "setitem");</span>
|
||
|
<span id="L415"><span class="lineNum"> 415</span> : dst.copy_(*b_src);</span>
|
||
|
<span id="L416"><span class="lineNum"> 416</span> : }</span>
|
||
|
<span id="L417"><span class="lineNum"> 417</span> : </span>
|
||
|
<span id="L418"><span class="lineNum"> 418</span> : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor</span>
|
||
|
<span id="L419"><span class="lineNum"> 419</span> : // indexing functions from Python ]</span>
|
||
|
<span id="L420"><span class="lineNum"> 420</span> : static inline Tensor handleDimInMultiDimIndexing(</span>
|
||
|
<span id="L421"><span class="lineNum"> 421</span> : const Tensor& prev_dim_result,</span>
|
||
|
<span id="L422"><span class="lineNum"> 422</span> : const Tensor& original_tensor,</span>
|
||
|
<span id="L423"><span class="lineNum"> 423</span> : const TensorIndex& index,</span>
|
||
|
<span id="L424"><span class="lineNum"> 424</span> : int64_t* dim_ptr,</span>
|
||
|
<span id="L425"><span class="lineNum"> 425</span> : int64_t* specified_dims_ptr,</span>
|
||
|
<span id="L426"><span class="lineNum"> 426</span> : int64_t real_dim,</span>
|
||
|
<span id="L427"><span class="lineNum"> 427</span> : std::vector&lt;Tensor&gt;&amp; outIndices,</span>
|
||
|
<span id="L428"><span class="lineNum"> 428</span> : bool disable_slice_optimization,</span>
|
||
|
<span id="L429"><span class="lineNum"> 429</span> : const at::Device& original_tensor_device,</span>
|
||
|
<span id="L430"><span class="lineNum"> 430</span> : const c10::optional&lt;SymIntArrayRef&gt;&amp; prev_dim_result_sizes) {</span>
|
||
|
<span id="L431"><span class="lineNum"> 431</span> : if (index.is_integer()) {</span>
|
||
|
<span id="L432"><span class="lineNum"> 432</span> : return impl::applySelect(</span>
|
||
|
<span id="L433"><span class="lineNum"> 433</span> : prev_dim_result,</span>
|
||
|
<span id="L434"><span class="lineNum"> 434</span> : *dim_ptr,</span>
|
||
|
<span id="L435"><span class="lineNum"> 435</span> : index.integer(),</span>
|
||
|
<span id="L436"><span class="lineNum"> 436</span> : real_dim,</span>
|
||
|
<span id="L437"><span class="lineNum"> 437</span> : original_tensor_device,</span>
|
||
|
<span id="L438"><span class="lineNum"> 438</span> : prev_dim_result_sizes);</span>
|
||
|
<span id="L439"><span class="lineNum"> 439</span> : } else if (index.is_slice()) {</span>
|
||
|
<span id="L440"><span class="lineNum"> 440</span> : Tensor result = impl::applySlice(</span>
|
||
|
<span id="L441"><span class="lineNum"> 441</span> : prev_dim_result,</span>
|
||
|
<span id="L442"><span class="lineNum"> 442</span> : *dim_ptr,</span>
|
||
|
<span id="L443"><span class="lineNum"> 443</span> : index.slice().start(),</span>
|
||
|
<span id="L444"><span class="lineNum"> 444</span> : index.slice().stop(),</span>
|
||
|
<span id="L445"><span class="lineNum"> 445</span> : index.slice().step(),</span>
|
||
|
<span id="L446"><span class="lineNum"> 446</span> : /*disable_slice_optimization=*/disable_slice_optimization,</span>
|
||
|
<span id="L447"><span class="lineNum"> 447</span> : original_tensor_device,</span>
|
||
|
<span id="L448"><span class="lineNum"> 448</span> : prev_dim_result_sizes);</span>
|
||
|
<span id="L449"><span class="lineNum"> 449</span> : (*dim_ptr)++;</span>
|
||
|
<span id="L450"><span class="lineNum"> 450</span> : return result;</span>
|
||
|
<span id="L451"><span class="lineNum"> 451</span> : } else if (index.is_ellipsis()) {</span>
|
||
|
<span id="L452"><span class="lineNum"> 452</span> : (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);</span>
|
||
|
<span id="L453"><span class="lineNum"> 453</span> : return prev_dim_result;</span>
|
||
|
<span id="L454"><span class="lineNum"> 454</span> : } else if (index.is_none()) {</span>
|
||
|
<span id="L455"><span class="lineNum"> 455</span> : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);</span>
|
||
|
<span id="L456"><span class="lineNum"> 456</span> : (*dim_ptr)++;</span>
|
||
|
<span id="L457"><span class="lineNum"> 457</span> : return result;</span>
|
||
|
<span id="L458"><span class="lineNum"> 458</span> : } else if (index.is_boolean()) {</span>
|
||
|
<span id="L459"><span class="lineNum"> 459</span> : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);</span>
|
||
|
<span id="L460"><span class="lineNum"> 460</span> : impl::recordTensorIndex(</span>
|
||
|
<span id="L461"><span class="lineNum"> 461</span> : impl::boolToIndexingTensor(</span>
|
||
|
<span id="L462"><span class="lineNum"> 462</span> : result, index.boolean(), original_tensor_device),</span>
|
||
|
<span id="L463"><span class="lineNum"> 463</span> : outIndices,</span>
|
||
|
<span id="L464"><span class="lineNum"> 464</span> : dim_ptr);</span>
|
||
|
<span id="L465"><span class="lineNum"> 465</span> : return result;</span>
|
||
|
<span id="L466"><span class="lineNum"> 466</span> : } else if (index.is_tensor()) {</span>
|
||
|
<span id="L467"><span class="lineNum"> 467</span> : Tensor result = prev_dim_result;</span>
|
||
|
<span id="L468"><span class="lineNum"> 468</span> : const Tensor& tensor = index.tensor();</span>
|
||
|
<span id="L469"><span class="lineNum"> 469</span> : auto scalar_type = tensor.scalar_type();</span>
|
||
|
<span id="L470"><span class="lineNum"> 470</span> : if (tensor.dim() == 0 &&</span>
|
||
|
<span id="L471"><span class="lineNum"> 471</span> : at::isIntegralType(scalar_type, /*includeBool=*/true)) {</span>
|
||
|
<span id="L472"><span class="lineNum"> 472</span> : if (scalar_type != at::kByte && scalar_type != at::kBool) {</span>
|
||
|
<span id="L473"><span class="lineNum"> 473</span> : result = impl::applySelect(</span>
|
||
|
<span id="L474"><span class="lineNum"> 474</span> : result,</span>
|
||
|
<span id="L475"><span class="lineNum"> 475</span> : *dim_ptr,</span>
|
||
|
<span id="L476"><span class="lineNum"> 476</span> : tensor.item&lt;int64_t&gt;(),</span>
|
||
|
<span id="L477"><span class="lineNum"> 477</span> : real_dim,</span>
|
||
|
<span id="L478"><span class="lineNum"> 478</span> : original_tensor_device,</span>
|
||
|
<span id="L479"><span class="lineNum"> 479</span> : prev_dim_result_sizes);</span>
|
||
|
<span id="L480"><span class="lineNum"> 480</span> : } else {</span>
|
||
|
<span id="L481"><span class="lineNum"> 481</span> : result = result.unsqueeze(*dim_ptr);</span>
|
||
|
<span id="L482"><span class="lineNum"> 482</span> : if (scalar_type == at::kBool) {</span>
|
||
|
<span id="L483"><span class="lineNum"> 483</span> : impl::recordTensorIndex(</span>
|
||
|
<span id="L484"><span class="lineNum"> 484</span> : impl::boolToIndexingTensor(</span>
|
||
|
<span id="L485"><span class="lineNum"> 485</span> : result, tensor.item&lt;bool&gt;() != 0, original_tensor_device),</span>
|
||
|
<span id="L486"><span class="lineNum"> 486</span> : outIndices,</span>
|
||
|
<span id="L487"><span class="lineNum"> 487</span> : dim_ptr);</span>
|
||
|
<span id="L488"><span class="lineNum"> 488</span> : } else {</span>
|
||
|
<span id="L489"><span class="lineNum"> 489</span> : impl::recordTensorIndex(</span>
|
||
|
<span id="L490"><span class="lineNum"> 490</span> : impl::boolToIndexingTensor(</span>
|
||
|
<span id="L491"><span class="lineNum"> 491</span> : result, tensor.item&lt;uint8_t&gt;() != 0, original_tensor_device),</span>
|
||
|
<span id="L492"><span class="lineNum"> 492</span> : outIndices,</span>
|
||
|
<span id="L493"><span class="lineNum"> 493</span> : dim_ptr);</span>
|
||
|
<span id="L494"><span class="lineNum"> 494</span> : }</span>
|
||
|
<span id="L495"><span class="lineNum"> 495</span> : }</span>
|
||
|
<span id="L496"><span class="lineNum"> 496</span> : } else {</span>
|
||
|
<span id="L497"><span class="lineNum"> 497</span> : impl::recordTensorIndex(tensor, outIndices, dim_ptr);</span>
|
||
|
<span id="L498"><span class="lineNum"> 498</span> : }</span>
|
||
|
<span id="L499"><span class="lineNum"> 499</span> : return result;</span>
|
||
|
<span id="L500"><span class="lineNum"> 500</span> : } else {</span>
|
||
|
<span id="L501"><span class="lineNum"> 501</span> : TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");</span>
|
||
|
<span id="L502"><span class="lineNum"> 502</span> : }</span>
|
||
|
<span id="L503"><span class="lineNum"> 503</span> : }</span>
|
||
|
<span id="L504"><span class="lineNum"> 504</span> : </span>
|
||
|
<span id="L505"><span class="lineNum"> 505</span> : namespace impl {</span>
|
||
|
<span id="L506"><span class="lineNum"> 506</span> : // This mirrors `applySlicing` in</span>
|
||
|
<span id="L507"><span class="lineNum"> 507</span> : // torch/csrc/autograd/python_variable_indexing.cpp</span>
|
||
|
<span id="L508"><span class="lineNum"> 508</span> : static inline Tensor applySlicing(</span>
|
||
|
<span id="L509"><span class="lineNum"> 509</span> : const Tensor& self,</span>
|
||
|
<span id="L510"><span class="lineNum"> 510</span> : const ArrayRef&lt;TensorIndex&gt;&amp; indices,</span>
|
||
|
<span id="L511"><span class="lineNum"> 511</span> : std::vector&lt;Tensor&gt;&amp; outIndices,</span>
|
||
|
<span id="L512"><span class="lineNum"> 512</span> : bool disable_slice_optimization,</span>
|
||
|
<span id="L513"><span class="lineNum"> 513</span> : const at::Device& self_device,</span>
|
||
|
<span id="L514"><span class="lineNum"> 514</span> : const c10::optional&lt;SymIntArrayRef&gt;&amp; self_sizes) {</span>
|
||
|
<span id="L515"><span class="lineNum"> 515</span> : int64_t dim = 0;</span>
|
||
|
<span id="L516"><span class="lineNum"> 516</span> : int64_t specified_dims = impl::count_specified_dimensions(indices);</span>
|
||
|
<span id="L517"><span class="lineNum"> 517</span> : </span>
|
||
|
<span id="L518"><span class="lineNum"> 518</span> : // See NOTE [nested tensor size for indexing]</span>
|
||
|
<span id="L519"><span class="lineNum"> 519</span> : if (self_sizes.has_value()) {</span>
|
||
|
<span id="L520"><span class="lineNum"> 520</span> : TORCH_CHECK_INDEX(</span>
|
||
|
<span id="L521"><span class="lineNum"> 521</span> : specified_dims <= (int64_t)self_sizes->size(),</span>
|
||
|
<span id="L522"><span class="lineNum"> 522</span> : "too many indices for tensor of dimension ",</span>
|
||
|
<span id="L523"><span class="lineNum"> 523</span> : (int)self_sizes->size());</span>
|
||
|
<span id="L524"><span class="lineNum"> 524</span> : }</span>
|
||
|
<span id="L525"><span class="lineNum"> 525</span> : </span>
|
||
|
<span id="L526"><span class="lineNum"> 526</span> : Tensor result = self;</span>
|
||
|
<span id="L527"><span class="lineNum"> 527</span> : for (const auto i : c10::irange(indices.size())) {</span>
|
||
|
<span id="L528"><span class="lineNum"> 528</span> : auto& obj = indices[i];</span>
|
||
|
<span id="L529"><span class="lineNum"> 529</span> : // See NOTE [nested tensor size for indexing]</span>
|
||
|
<span id="L530"><span class="lineNum"> 530</span> : c10::optional&lt;SymIntArrayRef&gt; result_sizes = result.is_nested()</span>
|
||
|
<span id="L531"><span class="lineNum"> 531</span> : ? c10::optional&lt;SymIntArrayRef&gt;(c10::nullopt)</span>
|
||
|
<span id="L532"><span class="lineNum"> 532</span> : : c10::optional&lt;SymIntArrayRef&gt;(result.sym_sizes());</span>
|
||
|
<span id="L533"><span class="lineNum"> 533</span> : result = handleDimInMultiDimIndexing(</span>
|
||
|
<span id="L534"><span class="lineNum"> 534</span> : /*prev_dim_result=*/result,</span>
|
||
|
<span id="L535"><span class="lineNum"> 535</span> : /*original_tensor=*/self,</span>
|
||
|
<span id="L536"><span class="lineNum"> 536</span> : /*index=*/obj,</span>
|
||
|
<span id="L537"><span class="lineNum"> 537</span> : /*dim=*/&dim,</span>
|
||
|
<span id="L538"><span class="lineNum"> 538</span> : /*specified_dims=*/&specified_dims,</span>
|
||
|
<span id="L539"><span class="lineNum"> 539</span> : /*real_dim=*/i,</span>
|
||
|
<span id="L540"><span class="lineNum"> 540</span> : /*outIndices=*/outIndices,</span>
|
||
|
<span id="L541"><span class="lineNum"> 541</span> : /*disable_slice_optimization=*/disable_slice_optimization,</span>
|
||
|
<span id="L542"><span class="lineNum"> 542</span> : /*original_tensor_device=*/self_device,</span>
|
||
|
<span id="L543"><span class="lineNum"> 543</span> : /*prev_dim_result_sizes=*/result_sizes);</span>
|
||
|
<span id="L544"><span class="lineNum"> 544</span> : }</span>
|
||
|
<span id="L545"><span class="lineNum"> 545</span> : return result;</span>
|
||
|
<span id="L546"><span class="lineNum"> 546</span> : }</span>
|
||
|
<span id="L547"><span class="lineNum"> 547</span> : } // namespace impl</span>
|
||
|
<span id="L548"><span class="lineNum"> 548</span> : </span>
|
||
|
<span id="L549"><span class="lineNum"> 549</span> : static inline Tensor dispatch_index(</span>
|
||
|
<span id="L550"><span class="lineNum"> 550</span> : const Tensor& self,</span>
|
||
|
<span id="L551"><span class="lineNum"> 551</span> : std::vector&lt;Tensor&gt;&amp;&amp; indices) {</span>
|
||
|
<span id="L552"><span class="lineNum"> 552</span> : return self.index(impl::typeConvertIndices(self, std::move(indices)));</span>
|
||
|
<span id="L553"><span class="lineNum"> 553</span> : }</span>
|
||
|
<span id="L554"><span class="lineNum"> 554</span> : </span>
|
||
|
<span id="L555"><span class="lineNum"> 555</span> : static inline Tensor dispatch_index_put_(</span>
|
||
|
<span id="L556"><span class="lineNum"> 556</span> : Tensor& self,</span>
|
||
|
<span id="L557"><span class="lineNum"> 557</span> : std::vector&lt;Tensor&gt;&amp;&amp; indices,</span>
|
||
|
<span id="L558"><span class="lineNum"> 558</span> : const Tensor& value) {</span>
|
||
|
<span id="L559"><span class="lineNum"> 559</span> : return self.index_put_(</span>
|
||
|
<span id="L560"><span class="lineNum"> 560</span> : impl::typeConvertIndices(self, std::move(indices)), value);</span>
|
||
|
<span id="L561"><span class="lineNum"> 561</span> : }</span>
|
||
|
<span id="L562"><span class="lineNum"> 562</span> : </span>
|
||
|
<span id="L563"><span class="lineNum"> 563</span> : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing</span>
|
||
|
<span id="L564"><span class="lineNum"> 564</span> : // functions from Python ]</span>
|
||
|
<span id="L565"><span class="lineNum"> 565</span> : //</span>
|
||
|
<span id="L566"><span class="lineNum"> 566</span> : // Question: When should we set `disable_slice_optimization` to `true` when</span>
|
||
|
<span id="L567"><span class="lineNum"> 567</span> : // calling C++ tensor indexing functions from Python indexing code?</span>
|
||
|
<span id="L568"><span class="lineNum"> 568</span> : //</span>
|
||
|
<span id="L569"><span class="lineNum"> 569</span> : // Answer: What "slice optimization" means: when we have a slicing expression</span>
|
||
|
<span id="L570"><span class="lineNum"> 570</span> : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we</span>
|
||
|
<span id="L571"><span class="lineNum"> 571</span> : // would skip dispatching the actual slice call as an optimization. However,</span>
|
||
|
<span id="L572"><span class="lineNum"> 572</span> : // here are the cases where we DON'T want this optimization:</span>
|
||
|
<span id="L573"><span class="lineNum"> 573</span> : //</span>
|
||
|
<span id="L574"><span class="lineNum"> 574</span> : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).</span>
|
||
|
<span id="L575"><span class="lineNum"> 575</span> : // Reason: we always return a shallow copy for expressions such as</span>
|
||
|
<span id="L576"><span class="lineNum"> 576</span> : // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,</span>
|
||
|
<span id="L577"><span class="lineNum"> 577</span> : // :]`, we return an alias of `tensor` by doing the following:</span>
|
||
|
<span id="L578"><span class="lineNum"> 578</span> : // ```</span>
|
||
|
<span id="L579"><span class="lineNum"> 579</span> : // Tensor sliced = impl::applySlicing(self, indices, tensorIndices,</span>
|
||
|
<span id="L580"><span class="lineNum"> 580</span> : // disable_slice_optimization, self_device, self_sizes); if</span>
|
||
|
<span id="L581"><span class="lineNum"> 581</span> : // (tensorIndices.empty()) {</span>
|
||
|
<span id="L582"><span class="lineNum"> 582</span> : // if (sliced.is_same(self)) {</span>
|
||
|
<span id="L583"><span class="lineNum"> 583</span> : // // ensure we return a shallow copy for things like x[...]</span>
|
||
|
<span id="L584"><span class="lineNum"> 584</span> : // sliced = at::alias(sliced);</span>
|
||
|
<span id="L585"><span class="lineNum"> 585</span> : // }</span>
|
||
|
<span id="L586"><span class="lineNum"> 586</span> : // return sliced;</span>
|
||
|
<span id="L587"><span class="lineNum"> 587</span> : // }</span>
|
||
|
<span id="L588"><span class="lineNum"> 588</span> : // ```)</span>
|
||
|
<span id="L589"><span class="lineNum"> 589</span> : // 2. When we are doing JIT tracing.</span>
|
||
|
<span id="L590"><span class="lineNum"> 590</span> : // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the</span>
|
||
|
<span id="L591"><span class="lineNum"> 591</span> : // slice operation.</span>
|
||
|
<span id="L592"><span class="lineNum"> 592</span> : </span>
|
||
|
<span id="L593"><span class="lineNum"> 593</span> : // This mirrors `THPVariable_getitem` in</span>
|
||
|
<span id="L594"><span class="lineNum"> 594</span> : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting</span>
|
||
|
<span id="L595"><span class="lineNum"> 595</span> : // `disable_slice_optimization` when calling C++ tensor indexing functions from</span>
|
||
|
<span id="L596"><span class="lineNum"> 596</span> : // Python ]</span>
|
||
|
<span id="L597"><span class="lineNum"> 597</span> : static inline Tensor get_item(</span>
|
||
|
<span id="L598"><span class="lineNum"> 598</span> : const Tensor& self,</span>
|
||
|
<span id="L599"><span class="lineNum"> 599</span> : const ArrayRef&lt;TensorIndex&gt;&amp; indices,</span>
|
||
|
<span id="L600"><span class="lineNum"> 600</span> : bool disable_slice_optimization = false) {</span>
|
||
|
<span id="L601"><span class="lineNum"> 601</span> : at::Device self_device = self.device();</span>
|
||
|
<span id="L602"><span class="lineNum"> 602</span> : // NOTE [nested tensor size for indexing]</span>
|
||
|
<span id="L603"><span class="lineNum"> 603</span> : // nested tensor does not have a size (yet) so for now we represent its size</span>
|
||
|
<span id="L604"><span class="lineNum"> 604</span> : // as null may need to be changed after we reach a better solution for nested</span>
|
||
|
<span id="L605"><span class="lineNum"> 605</span> : // tensor size</span>
|
||
|
<span id="L606"><span class="lineNum"> 606</span> : c10::optional&lt;SymIntArrayRef&gt; self_sizes = self.is_nested()</span>
|
||
|
<span id="L607"><span class="lineNum"> 607</span> : ? c10::optional&lt;SymIntArrayRef&gt;(c10::nullopt)</span>
|
||
|
<span id="L608"><span class="lineNum"> 608</span> : : c10::optional&lt;SymIntArrayRef&gt;(self.sym_sizes());</span>
|
||
|
<span id="L609"><span class="lineNum"> 609</span> : </span>
|
||
|
<span id="L610"><span class="lineNum"> 610</span> : // handle simple types: integers, slices, none, ellipsis, bool</span>
|
||
|
<span id="L611"><span class="lineNum"> 611</span> : if (indices.size() == 1) {</span>
|
||
|
<span id="L612"><span class="lineNum"> 612</span> : const TensorIndex& index = indices[0];</span>
|
||
|
<span id="L613"><span class="lineNum"> 613</span> : if (index.is_integer()) {</span>
|
||
|
<span id="L614"><span class="lineNum"> 614</span> : return impl::applySelect(</span>
|
||
|
<span id="L615"><span class="lineNum"> 615</span> : self, 0, index.integer(), 0, self_device, self_sizes);</span>
|
||
|
<span id="L616"><span class="lineNum"> 616</span> : } else if (index.is_slice()) {</span>
|
||
|
<span id="L617"><span class="lineNum"> 617</span> : return impl::applySlice(</span>
|
||
|
<span id="L618"><span class="lineNum"> 618</span> : self,</span>
|
||
|
<span id="L619"><span class="lineNum"> 619</span> : 0,</span>
|
||
|
<span id="L620"><span class="lineNum"> 620</span> : index.slice().start(),</span>
|
||
|
<span id="L621"><span class="lineNum"> 621</span> : index.slice().stop(),</span>
|
||
|
<span id="L622"><span class="lineNum"> 622</span> : index.slice().step(),</span>
|
||
|
<span id="L623"><span class="lineNum"> 623</span> : /*disable_slice_optimization=*/true,</span>
|
||
|
<span id="L624"><span class="lineNum"> 624</span> : self_device,</span>
|
||
|
<span id="L625"><span class="lineNum"> 625</span> : self_sizes);</span>
|
||
|
<span id="L626"><span class="lineNum"> 626</span> : } else if (index.is_none()) {</span>
|
||
|
<span id="L627"><span class="lineNum"> 627</span> : return self.unsqueeze(0);</span>
|
||
|
<span id="L628"><span class="lineNum"> 628</span> : } else if (index.is_ellipsis()) {</span>
|
||
|
<span id="L629"><span class="lineNum"> 629</span> : return at::alias(self);</span>
|
||
|
<span id="L630"><span class="lineNum"> 630</span> : } else if (index.is_boolean()) {</span>
|
||
|
<span id="L631"><span class="lineNum"> 631</span> : Tensor result = self.unsqueeze(0);</span>
|
||
|
<span id="L632"><span class="lineNum"> 632</span> : return dispatch_index(</span>
|
||
|
<span id="L633"><span class="lineNum"> 633</span> : result,</span>
|
||
|
<span id="L634"><span class="lineNum"> 634</span> : std::vector&lt;Tensor&gt;{impl::boolToIndexingTensor(</span>
|
||
|
<span id="L635"><span class="lineNum"> 635</span> : result, index.boolean(), self_device)});</span>
|
||
|
<span id="L636"><span class="lineNum"> 636</span> : }</span>
|
||
|
<span id="L637"><span class="lineNum"> 637</span> : }</span>
|
||
|
<span id="L638"><span class="lineNum"> 638</span> : </span>
|
||
|
<span id="L639"><span class="lineNum"> 639</span> : std::vector&lt;Tensor&gt; tensorIndices;</span>
|
||
|
<span id="L640"><span class="lineNum"> 640</span> : Tensor sliced = impl::applySlicing(</span>
|
||
|
<span id="L641"><span class="lineNum"> 641</span> : self,</span>
|
||
|
<span id="L642"><span class="lineNum"> 642</span> : indices,</span>
|
||
|
<span id="L643"><span class="lineNum"> 643</span> : tensorIndices,</span>
|
||
|
<span id="L644"><span class="lineNum"> 644</span> : disable_slice_optimization,</span>
|
||
|
<span id="L645"><span class="lineNum"> 645</span> : self_device,</span>
|
||
|
<span id="L646"><span class="lineNum"> 646</span> : self_sizes);</span>
|
||
|
<span id="L647"><span class="lineNum"> 647</span> : if (tensorIndices.empty()) {</span>
|
||
|
<span id="L648"><span class="lineNum"> 648</span> : if (sliced.is_same(self)) {</span>
|
||
|
<span id="L649"><span class="lineNum"> 649</span> : // ensure we return a shallow copy for things like x[...]</span>
|
||
|
<span id="L650"><span class="lineNum"> 650</span> : sliced = at::alias(sliced);</span>
|
||
|
<span id="L651"><span class="lineNum"> 651</span> : }</span>
|
||
|
<span id="L652"><span class="lineNum"> 652</span> : return sliced;</span>
|
||
|
<span id="L653"><span class="lineNum"> 653</span> : }</span>
|
||
|
<span id="L654"><span class="lineNum"> 654</span> : </span>
|
||
|
<span id="L655"><span class="lineNum"> 655</span> : // indexing by tensors ("advanced" indexing)</span>
|
||
|
<span id="L656"><span class="lineNum"> 656</span> : return dispatch_index(sliced, std::move(tensorIndices));</span>
|
||
|
<span id="L657"><span class="lineNum"> 657</span> : }</span>
|
||
|
<span id="L658"><span class="lineNum"> 658</span> : </span>
|
||
|
<span id="L659"><span class="lineNum"> 659</span> : // This mirrors `THPVariable_setitem` in</span>
|
||
|
<span id="L660"><span class="lineNum"> 660</span> : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a</span>
|
||
|
<span id="L661"><span class="lineNum"> 661</span> : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++</span>
|
||
|
<span id="L662"><span class="lineNum"> 662</span> : // tensor indexing functions from Python ]</span>
|
||
|
<span id="L663"><span class="lineNum"> 663</span> : static inline void set_item(</span>
|
||
|
<span id="L664"><span class="lineNum"> 664</span> : const Tensor& self,</span>
|
||
|
<span id="L665"><span class="lineNum"> 665</span> : const ArrayRef&lt;TensorIndex&gt;&amp; indices,</span>
|
||
|
<span id="L666"><span class="lineNum"> 666</span> : const Tensor& value,</span>
|
||
|
<span id="L667"><span class="lineNum"> 667</span> : bool disable_slice_optimization = false) {</span>
|
||
|
<span id="L668"><span class="lineNum"> 668</span> : at::Device self_device = self.device();</span>
|
||
|
<span id="L669"><span class="lineNum"> 669</span> : SymIntArrayRef self_sizes = self.sym_sizes();</span>
|
||
|
<span id="L670"><span class="lineNum"> 670</span> : </span>
|
||
|
<span id="L671"><span class="lineNum"> 671</span> : // handle simple types: integers, slices, ellipsis, bool</span>
|
||
|
<span id="L672"><span class="lineNum"> 672</span> : if (indices.size() == 1) {</span>
|
||
|
<span id="L673"><span class="lineNum"> 673</span> : const TensorIndex& index = indices[0];</span>
|
||
|
<span id="L674"><span class="lineNum"> 674</span> : if (index.is_boolean() && !index.boolean()) {</span>
|
||
|
<span id="L675"><span class="lineNum"> 675</span> : // do nothing for false (technically we should check the size, but we</span>
|
||
|
<span id="L676"><span class="lineNum"> 676</span> : // don't have real 0-sized shapes.</span>
|
||
|
<span id="L677"><span class="lineNum"> 677</span> : return;</span>
|
||
|
<span id="L678"><span class="lineNum"> 678</span> : } else if (index.is_ellipsis()) {</span>
|
||
|
<span id="L679"><span class="lineNum"> 679</span> : copy_to(self, value);</span>
|
||
|
<span id="L680"><span class="lineNum"> 680</span> : return;</span>
|
||
|
<span id="L681"><span class="lineNum"> 681</span> : } else if (index.is_none() || (index.is_boolean() && index.boolean())) {</span>
|
||
|
<span id="L682"><span class="lineNum"> 682</span> : copy_to(self.unsqueeze(0), value);</span>
|
||
|
<span id="L683"><span class="lineNum"> 683</span> : return;</span>
|
||
|
<span id="L684"><span class="lineNum"> 684</span> : } else if (index.is_integer()) {</span>
|
||
|
<span id="L685"><span class="lineNum"> 685</span> : copy_to(</span>
|
||
|
<span id="L686"><span class="lineNum"> 686</span> : impl::applySelect(</span>
|
||
|
<span id="L687"><span class="lineNum"> 687</span> : self, 0, index.integer(), 0, self_device, self_sizes),</span>
|
||
|
<span id="L688"><span class="lineNum"> 688</span> : value);</span>
|
||
|
<span id="L689"><span class="lineNum"> 689</span> : return;</span>
|
||
|
<span id="L690"><span class="lineNum"> 690</span> : } else if (index.is_slice()) {</span>
|
||
|
<span id="L691"><span class="lineNum"> 691</span> : copy_to(</span>
|
||
|
<span id="L692"><span class="lineNum"> 692</span> : impl::applySlice(</span>
|
||
|
<span id="L693"><span class="lineNum"> 693</span> : self,</span>
|
||
|
<span id="L694"><span class="lineNum"> 694</span> : 0,</span>
|
||
|
<span id="L695"><span class="lineNum"> 695</span> : index.slice().start(),</span>
|
||
|
<span id="L696"><span class="lineNum"> 696</span> : index.slice().stop(),</span>
|
||
|
<span id="L697"><span class="lineNum"> 697</span> : index.slice().step(),</span>
|
||
|
<span id="L698"><span class="lineNum"> 698</span> : /*disable_slice_optimization=*/disable_slice_optimization,</span>
|
||
|
<span id="L699"><span class="lineNum"> 699</span> : self_device,</span>
|
||
|
<span id="L700"><span class="lineNum"> 700</span> : self_sizes),</span>
|
||
|
<span id="L701"><span class="lineNum"> 701</span> : value);</span>
|
||
|
<span id="L702"><span class="lineNum"> 702</span> : return;</span>
|
||
|
<span id="L703"><span class="lineNum"> 703</span> : }</span>
|
||
|
<span id="L704"><span class="lineNum"> 704</span> : }</span>
|
||
|
<span id="L705"><span class="lineNum"> 705</span> : </span>
|
||
|
<span id="L706"><span class="lineNum"> 706</span> : std::vector&lt;Tensor&gt; tensorIndices;</span>
|
||
|
<span id="L707"><span class="lineNum"> 707</span> : Tensor sliced = impl::applySlicing(</span>
|
||
|
<span id="L708"><span class="lineNum"> 708</span> : self,</span>
|
||
|
<span id="L709"><span class="lineNum"> 709</span> : indices,</span>
|
||
|
<span id="L710"><span class="lineNum"> 710</span> : tensorIndices,</span>
|
||
|
<span id="L711"><span class="lineNum"> 711</span> : disable_slice_optimization,</span>
|
||
|
<span id="L712"><span class="lineNum"> 712</span> : self_device,</span>
|
||
|
<span id="L713"><span class="lineNum"> 713</span> : self_sizes);</span>
|
||
|
<span id="L714"><span class="lineNum"> 714</span> : if (tensorIndices.empty()) {</span>
|
||
|
<span id="L715"><span class="lineNum"> 715</span> : copy_to(sliced, value);</span>
|
||
|
<span id="L716"><span class="lineNum"> 716</span> : return;</span>
|
||
|
<span id="L717"><span class="lineNum"> 717</span> : }</span>
|
||
|
<span id="L718"><span class="lineNum"> 718</span> : </span>
|
||
|
<span id="L719"><span class="lineNum"> 719</span> : SymIntArrayRef valueSizes = value.sym_sizes();</span>
|
||
|
<span id="L720"><span class="lineNum"> 720</span> : SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);</span>
|
||
|
<span id="L721"><span class="lineNum"> 721</span> : Tensor valuesSliced;</span>
|
||
|
<span id="L722"><span class="lineNum"> 722</span> : if (!valueSizes.equals(slicedValueSizes)) {</span>
|
||
|
<span id="L723"><span class="lineNum"> 723</span> : valuesSliced = value.view_symint(slicedValueSizes);</span>
|
||
|
<span id="L724"><span class="lineNum"> 724</span> : } else {</span>
|
||
|
<span id="L725"><span class="lineNum"> 725</span> : valuesSliced = value;</span>
|
||
|
<span id="L726"><span class="lineNum"> 726</span> : }</span>
|
||
|
<span id="L727"><span class="lineNum"> 727</span> : dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);</span>
|
||
|
<span id="L728"><span class="lineNum"> 728</span> : return;</span>
|
||
|
<span id="L729"><span class="lineNum"> 729</span> : }</span>
|
||
|
<span id="L730"><span class="lineNum"> 730</span> : </span>
|
||
|
<span id="L731"><span class="lineNum"> 731</span> : } // namespace indexing</span>
|
||
|
<span id="L732"><span class="lineNum"> 732</span> : } // namespace at</span>
|
||
|
</pre>
|
||
|
</td>
|
||
|
</tr>
|
||
|
</table>
|
||
|
<br>
|
||
|
|
||
|
<table width="100%" border=0 cellspacing=0 cellpadding=0>
|
||
|
<tr><td class="ruler"><img src="../../../glass.png" width=3 height=3 alt=""></td></tr>
|
||
|
<tr><td class="versionInfo">Generated by: <a href="https://github.com/linux-test-project/lcov" target="_parent">LCOV version 2.0-1</a></td></tr>
|
||
|
</table>
|
||
|
<br>
|
||
|
|
||
|
</body>
|
||
|
</html>
|