583 lines
61 KiB
HTML
583 lines
61 KiB
HTML
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
|
|
|
|
<html lang="en">
|
|
|
|
<head>
|
|
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
|
|
<title>LCOV - coverage.info - libtorch/include/c10/util/Half.h</title>
|
|
<link rel="stylesheet" type="text/css" href="../../../../gcov.css">
|
|
</head>
|
|
|
|
<body>
|
|
|
|
<table width="100%" border=0 cellspacing=0 cellpadding=0>
|
|
<tr><td class="title">LCOV - code coverage report</td></tr>
|
|
<tr><td class="ruler"><img src="../../../../glass.png" width=3 height=3 alt=""></td></tr>
|
|
|
|
<tr>
|
|
<td width="100%">
|
|
<table cellpadding=1 border=0 width="100%">
|
|
<tr>
|
|
<td width="10%" class="headerItem">Current view:</td>
|
|
<td width="10%" class="headerValue"><a href="../../../../index.html">top level</a> - <a href="index.html">libtorch/include/c10/util</a> - Half.h<span style="font-size: 80%;"> (source / <a href="Half.h.func-c.html">functions</a>)</span></td>
|
|
<td width="5%"></td>
|
|
<td width="5%"></td>
|
|
<td width="5%" class="headerCovTableHead">Coverage</td>
|
|
<td width="5%" class="headerCovTableHead" title="Covered + Uncovered code">Total</td>
|
|
<td width="5%" class="headerCovTableHead" title="Exercised code only">Hit</td>
|
|
</tr>
|
|
<tr>
|
|
<td class="headerItem">Test:</td>
|
|
<td class="headerValue">coverage.info</td>
|
|
<td></td>
|
|
<td class="headerItem">Lines:</td>
|
|
<td class="headerCovTableEntryLo">15.4 %</td>
|
|
<td class="headerCovTableEntry">13</td>
|
|
<td class="headerCovTableEntry">2</td>
|
|
</tr>
|
|
<tr>
|
|
<td class="headerItem">Test Date:</td>
|
|
<td class="headerValue">2024-04-30 13:17:26</td>
|
|
<td></td>
|
|
<td class="headerItem">Functions:</td>
|
|
<td class="headerCovTableEntryLo">25.0 %</td>
|
|
<td class="headerCovTableEntry">4</td>
|
|
<td class="headerCovTableEntry">1</td>
|
|
</tr>
|
|
<tr><td><img src="../../../../glass.png" width=3 height=3 alt=""></td></tr>
|
|
</table>
|
|
</td>
|
|
</tr>
|
|
|
|
<tr><td class="ruler"><img src="../../../../glass.png" width=3 height=3 alt=""></td></tr>
|
|
</table>
|
|
|
|
<table cellpadding=0 cellspacing=0 border=0>
|
|
<tr>
|
|
<td><br></td>
|
|
</tr>
|
|
<tr>
|
|
<td>
|
|
<pre class="sourceHeading"> Line data Source code</pre>
|
|
<pre class="source">
|
|
<span id="L1"><span class="lineNum"> 1</span> : #pragma once</span>
|
|
<span id="L2"><span class="lineNum"> 2</span> : </span>
|
|
<span id="L3"><span class="lineNum"> 3</span> : /// Defines the Half type (half-precision floating-point) including conversions</span>
|
|
<span id="L4"><span class="lineNum"> 4</span> : /// to standard C types and basic arithmetic operations. Note that arithmetic</span>
|
|
<span id="L5"><span class="lineNum"> 5</span> : /// operations are implemented by converting to floating point and</span>
|
|
<span id="L6"><span class="lineNum"> 6</span> : /// performing the operation in float32, instead of using CUDA half intrinsics.</span>
|
|
<span id="L7"><span class="lineNum"> 7</span> : /// Most uses of this type within ATen are memory bound, including the</span>
|
|
<span id="L8"><span class="lineNum"> 8</span> : /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.</span>
|
|
<span id="L9"><span class="lineNum"> 9</span> : /// If you are writing a compute bound kernel, you can use the CUDA half</span>
|
|
<span id="L10"><span class="lineNum"> 10</span> : /// intrinsics directly on the Half type from device code.</span>
|
|
<span id="L11"><span class="lineNum"> 11</span> : </span>
|
|
<span id="L12"><span class="lineNum"> 12</span> : #include &lt;c10/macros/Macros.h&gt;</span>
|
|
<span id="L13"><span class="lineNum"> 13</span> : #include &lt;c10/util/C++17.h&gt;</span>
|
|
<span id="L14"><span class="lineNum"> 14</span> : #include &lt;c10/util/TypeSafeSignMath.h&gt;</span>
|
|
<span id="L15"><span class="lineNum"> 15</span> : #include &lt;c10/util/complex.h&gt;</span>
|
|
<span id="L16"><span class="lineNum"> 16</span> : #include &lt;c10/util/floating_point_utils.h&gt;</span>
|
|
<span id="L17"><span class="lineNum"> 17</span> : #include &lt;type_traits&gt;</span>
|
|
<span id="L18"><span class="lineNum"> 18</span> : </span>
|
|
<span id="L19"><span class="lineNum"> 19</span> : #if defined(__cplusplus) && (__cplusplus >= 201103L)</span>
|
|
<span id="L20"><span class="lineNum"> 20</span> : #include &lt;cmath&gt;</span>
|
|
<span id="L21"><span class="lineNum"> 21</span> : #include &lt;cstdint&gt;</span>
|
|
<span id="L22"><span class="lineNum"> 22</span> : #elif !defined(__OPENCL_VERSION__)</span>
|
|
<span id="L23"><span class="lineNum"> 23</span> : #include &lt;math.h&gt;</span>
|
|
<span id="L24"><span class="lineNum"> 24</span> : #include &lt;stdint.h&gt;</span>
|
|
<span id="L25"><span class="lineNum"> 25</span> : #endif</span>
|
|
<span id="L26"><span class="lineNum"> 26</span> : </span>
|
|
<span id="L27"><span class="lineNum"> 27</span> : #ifdef _MSC_VER</span>
|
|
<span id="L28"><span class="lineNum"> 28</span> : #include &lt;intrin.h&gt;</span>
|
|
<span id="L29"><span class="lineNum"> 29</span> : #endif</span>
|
|
<span id="L30"><span class="lineNum"> 30</span> : </span>
|
|
<span id="L31"><span class="lineNum"> 31</span> : #include &lt;complex&gt;</span>
|
|
<span id="L32"><span class="lineNum"> 32</span> : #include &lt;cstdint&gt;</span>
|
|
<span id="L33"><span class="lineNum"> 33</span> : #include &lt;cstring&gt;</span>
|
|
<span id="L34"><span class="lineNum"> 34</span> : #include &lt;iosfwd&gt;</span>
|
|
<span id="L35"><span class="lineNum"> 35</span> : #include &lt;limits&gt;</span>
|
|
<span id="L36"><span class="lineNum"> 36</span> : #include &lt;sstream&gt;</span>
|
|
<span id="L37"><span class="lineNum"> 37</span> : #include &lt;stdexcept&gt;</span>
|
|
<span id="L38"><span class="lineNum"> 38</span> : #include &lt;string&gt;</span>
|
|
<span id="L39"><span class="lineNum"> 39</span> : #include &lt;utility&gt;</span>
|
|
<span id="L40"><span class="lineNum"> 40</span> : </span>
|
|
<span id="L41"><span class="lineNum"> 41</span> : #ifdef __CUDACC__</span>
|
|
<span id="L42"><span class="lineNum"> 42</span> : #include &lt;cuda_fp16.h&gt;</span>
|
|
<span id="L43"><span class="lineNum"> 43</span> : #endif</span>
|
|
<span id="L44"><span class="lineNum"> 44</span> : </span>
|
|
<span id="L45"><span class="lineNum"> 45</span> : #ifdef __HIPCC__</span>
|
|
<span id="L46"><span class="lineNum"> 46</span> : #include &lt;hip/hip_fp16.h&gt;</span>
|
|
<span id="L47"><span class="lineNum"> 47</span> : #endif</span>
|
|
<span id="L48"><span class="lineNum"> 48</span> : </span>
|
|
<span id="L49"><span class="lineNum"> 49</span> : #if defined(CL_SYCL_LANGUAGE_VERSION)</span>
|
|
<span id="L50"><span class="lineNum"> 50</span> : #include &lt;CL/sycl.hpp&gt; // for SYCL 1.2.1</span>
|
|
<span id="L51"><span class="lineNum"> 51</span> : #elif defined(SYCL_LANGUAGE_VERSION)</span>
|
|
<span id="L52"><span class="lineNum"> 52</span> : #include &lt;sycl/sycl.hpp&gt; // for SYCL 2020</span>
|
|
<span id="L53"><span class="lineNum"> 53</span> : #endif</span>
|
|
<span id="L54"><span class="lineNum"> 54</span> : </span>
|
|
<span id="L55"><span class="lineNum"> 55</span> : #include &lt;typeinfo&gt; // operator typeid</span>
|
|
<span id="L56"><span class="lineNum"> 56</span> : </span>
|
|
<span id="L57"><span class="lineNum"> 57</span> : namespace c10 {</span>
|
|
<span id="L58"><span class="lineNum"> 58</span> : </span>
|
|
<span id="L59"><span class="lineNum"> 59</span> : namespace detail {</span>
|
|
<span id="L60"><span class="lineNum"> 60</span> : </span>
|
|
<span id="L61"><span class="lineNum"> 61</span> : /*</span>
|
|
<span id="L62"><span class="lineNum"> 62</span> : * Convert a 16-bit floating-point number in IEEE half-precision format, in bit</span>
|
|
<span id="L63"><span class="lineNum"> 63</span> : * representation, to a 32-bit floating-point number in IEEE single-precision</span>
|
|
<span id="L64"><span class="lineNum"> 64</span> : * format, in bit representation.</span>
|
|
<span id="L65"><span class="lineNum"> 65</span> : *</span>
|
|
<span id="L66"><span class="lineNum"> 66</span> : * @note The implementation doesn't use any floating-point operations.</span>
|
|
<span id="L67"><span class="lineNum"> 67</span> : */</span>
|
|
<span id="L68"><span class="lineNum"> 68</span> : inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {</span>
|
|
<span id="L69"><span class="lineNum"> 69</span> : /*</span>
|
|
<span id="L70"><span class="lineNum"> 70</span> : * Extend the half-precision floating-point number to 32 bits and shift to the</span>
|
|
<span id="L71"><span class="lineNum"> 71</span> : * upper part of the 32-bit word:</span>
|
|
<span id="L72"><span class="lineNum"> 72</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L73"><span class="lineNum"> 73</span> : * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|</span>
|
|
<span id="L74"><span class="lineNum"> 74</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L75"><span class="lineNum"> 75</span> : * Bits 31 26-30 16-25 0-15</span>
|
|
<span id="L76"><span class="lineNum"> 76</span> : *</span>
|
|
<span id="L77"><span class="lineNum"> 77</span> : * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0</span>
|
|
<span id="L78"><span class="lineNum"> 78</span> : * - zero bits.</span>
|
|
<span id="L79"><span class="lineNum"> 79</span> : */</span>
|
|
<span id="L80"><span class="lineNum"> 80</span> : const uint32_t w = (uint32_t)h << 16;</span>
|
|
<span id="L81"><span class="lineNum"> 81</span> : /*</span>
|
|
<span id="L82"><span class="lineNum"> 82</span> : * Extract the sign of the input number into the high bit of the 32-bit word:</span>
|
|
<span id="L83"><span class="lineNum"> 83</span> : *</span>
|
|
<span id="L84"><span class="lineNum"> 84</span> : * +---+----------------------------------+</span>
|
|
<span id="L85"><span class="lineNum"> 85</span> : * | S |0000000 00000000 00000000 00000000|</span>
|
|
<span id="L86"><span class="lineNum"> 86</span> : * +---+----------------------------------+</span>
|
|
<span id="L87"><span class="lineNum"> 87</span> : * Bits 31 0-31</span>
|
|
<span id="L88"><span class="lineNum"> 88</span> : */</span>
|
|
<span id="L89"><span class="lineNum"> 89</span> : const uint32_t sign = w & UINT32_C(0x80000000);</span>
|
|
<span id="L90"><span class="lineNum"> 90</span> : /*</span>
|
|
<span id="L91"><span class="lineNum"> 91</span> : * Extract mantissa and biased exponent of the input number into the bits 0-30</span>
|
|
<span id="L92"><span class="lineNum"> 92</span> : * of the 32-bit word:</span>
|
|
<span id="L93"><span class="lineNum"> 93</span> : *</span>
|
|
<span id="L94"><span class="lineNum"> 94</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L95"><span class="lineNum"> 95</span> : * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|</span>
|
|
<span id="L96"><span class="lineNum"> 96</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L97"><span class="lineNum"> 97</span> : * Bits 30 27-31 17-26 0-16</span>
|
|
<span id="L98"><span class="lineNum"> 98</span> : */</span>
|
|
<span id="L99"><span class="lineNum"> 99</span> : const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);</span>
|
|
<span id="L100"><span class="lineNum"> 100</span> : /*</span>
|
|
<span id="L101"><span class="lineNum"> 101</span> : * Renorm shift is the number of bits to shift mantissa left to make the</span>
|
|
<span id="L102"><span class="lineNum"> 102</span> : * half-precision number normalized. If the initial number is normalized, some</span>
|
|
<span id="L103"><span class="lineNum"> 103</span> : * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case</span>
|
|
<span id="L104"><span class="lineNum"> 104</span> : * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note</span>
|
|
<span id="L105"><span class="lineNum"> 105</span> : * that if we shift denormalized nonsign by renorm_shift, the unit bit of</span>
|
|
<span id="L106"><span class="lineNum"> 106</span> : * mantissa will shift into exponent, turning the biased exponent into 1, and</span>
|
|
<span id="L107"><span class="lineNum"> 107</span> : * making mantissa normalized (i.e. without leading 1).</span>
|
|
<span id="L108"><span class="lineNum"> 108</span> : */</span>
|
|
<span id="L109"><span class="lineNum"> 109</span> : #ifdef _MSC_VER</span>
|
|
<span id="L110"><span class="lineNum"> 110</span> : unsigned long nonsign_bsr;</span>
|
|
<span id="L111"><span class="lineNum"> 111</span> : _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);</span>
|
|
<span id="L112"><span class="lineNum"> 112</span> : uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;</span>
|
|
<span id="L113"><span class="lineNum"> 113</span> : #else</span>
|
|
<span id="L114"><span class="lineNum"> 114</span> : uint32_t renorm_shift = __builtin_clz(nonsign);</span>
|
|
<span id="L115"><span class="lineNum"> 115</span> : #endif</span>
|
|
<span id="L116"><span class="lineNum"> 116</span> : renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;</span>
|
|
<span id="L117"><span class="lineNum"> 117</span> : /*</span>
|
|
<span id="L118"><span class="lineNum"> 118</span> : * Iff half-precision number has exponent of 15, the addition overflows</span>
|
|
<span id="L119"><span class="lineNum"> 119</span> : * it into bit 31, and the subsequent shift turns the high 9 bits</span>
|
|
<span id="L120"><span class="lineNum"> 120</span> : * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number</span>
|
|
<span id="L121"><span class="lineNum"> 121</span> : * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise</span>
|
|
<span id="L122"><span class="lineNum"> 122</span> : */</span>
|
|
<span id="L123"><span class="lineNum"> 123</span> : const int32_t inf_nan_mask =</span>
|
|
<span id="L124"><span class="lineNum"> 124</span> : ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);</span>
|
|
<span id="L125"><span class="lineNum"> 125</span> : /*</span>
|
|
<span id="L126"><span class="lineNum"> 126</span> : * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31</span>
|
|
<span id="L127"><span class="lineNum"> 127</span> : * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31</span>
|
|
<span id="L128"><span class="lineNum"> 128</span> : * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==</span>
|
|
<span id="L129"><span class="lineNum"> 129</span> : * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)</span>
|
|
<span id="L130"><span class="lineNum"> 130</span> : * 0x00000000 otherwise</span>
|
|
<span id="L131"><span class="lineNum"> 131</span> : */</span>
|
|
<span id="L132"><span class="lineNum"> 132</span> : const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;</span>
|
|
<span id="L133"><span class="lineNum"> 133</span> : /*</span>
|
|
<span id="L134"><span class="lineNum"> 134</span> : * 1. Shift nonsign left by renorm_shift to normalize it (if the input</span>
|
|
<span id="L135"><span class="lineNum"> 135</span> : * was denormal)</span>
|
|
<span id="L136"><span class="lineNum"> 136</span> : * 2. Shift nonsign right by 3 so the exponent (5 bits originally)</span>
|
|
<span id="L137"><span class="lineNum"> 137</span> : * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high</span>
|
|
<span id="L138"><span class="lineNum"> 138</span> : * bits of the 23-bit mantissa of IEEE single-precision number.</span>
|
|
<span id="L139"><span class="lineNum"> 139</span> : * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the</span>
|
|
<span id="L140"><span class="lineNum"> 140</span> : * different in exponent bias (0x7F for single-precision number less 0xF</span>
|
|
<span id="L141"><span class="lineNum"> 141</span> : * for half-precision number).</span>
|
|
<span id="L142"><span class="lineNum"> 142</span> : * 4. Subtract renorm_shift from the exponent (starting at bit 23) to</span>
|
|
<span id="L143"><span class="lineNum"> 143</span> : * account for renormalization. As renorm_shift is less than 0x70, this</span>
|
|
<span id="L144"><span class="lineNum"> 144</span> : * can be combined with step 3.</span>
|
|
<span id="L145"><span class="lineNum"> 145</span> : * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the</span>
|
|
<span id="L146"><span class="lineNum"> 146</span> : * input was NaN or infinity.</span>
|
|
<span id="L147"><span class="lineNum"> 147</span> : * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent</span>
|
|
<span id="L148"><span class="lineNum"> 148</span> : * into zero if the input was zero.</span>
|
|
<span id="L149"><span class="lineNum"> 149</span> : * 7. Combine with the sign of the input number.</span>
|
|
<span id="L150"><span class="lineNum"> 150</span> : */</span>
|
|
<span id="L151"><span class="lineNum"> 151</span> : return sign |</span>
|
|
<span id="L152"><span class="lineNum"> 152</span> : ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |</span>
|
|
<span id="L153"><span class="lineNum"> 153</span> : inf_nan_mask) &</span>
|
|
<span id="L154"><span class="lineNum"> 154</span> : ~zero_mask);</span>
|
|
<span id="L155"><span class="lineNum"> 155</span> : }</span>
|
|
<span id="L156"><span class="lineNum"> 156</span> : </span>
|
|
<span id="L157"><span class="lineNum"> 157</span> : /*</span>
|
|
<span id="L158"><span class="lineNum"> 158</span> : * Convert a 16-bit floating-point number in IEEE half-precision format, in bit</span>
|
|
<span id="L159"><span class="lineNum"> 159</span> : * representation, to a 32-bit floating-point number in IEEE single-precision</span>
|
|
<span id="L160"><span class="lineNum"> 160</span> : * format.</span>
|
|
<span id="L161"><span class="lineNum"> 161</span> : *</span>
|
|
<span id="L162"><span class="lineNum"> 162</span> : * @note The implementation relies on IEEE-like (no assumption about rounding</span>
|
|
<span id="L163"><span class="lineNum"> 163</span> : * mode and no operations on denormals) floating-point operations and bitcasts</span>
|
|
<span id="L164"><span class="lineNum"> 164</span> : * between integer and floating-point variables.</span>
|
|
<span id="L165"><span class="lineNum"> 165</span> : */</span>
|
|
<span id="L166"><span class="lineNum"> 166</span> : C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {</span>
|
|
<span id="L167"><span class="lineNum"> 167</span> : /*</span>
|
|
<span id="L168"><span class="lineNum"> 168</span> : * Extend the half-precision floating-point number to 32 bits and shift to the</span>
|
|
<span id="L169"><span class="lineNum"> 169</span> : * upper part of the 32-bit word:</span>
|
|
<span id="L170"><span class="lineNum"> 170</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L171"><span class="lineNum"> 171</span> : * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|</span>
|
|
<span id="L172"><span class="lineNum"> 172</span> : * +---+-----+------------+-------------------+</span>
|
|
<span id="L173"><span class="lineNum"> 173</span> : * Bits 31 26-30 16-25 0-15</span>
|
|
<span id="L174"><span class="lineNum"> 174</span> : *</span>
|
|
<span id="L175"><span class="lineNum"> 175</span> : * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0</span>
|
|
<span id="L176"><span class="lineNum"> 176</span> : * - zero bits.</span>
|
|
<span id="L177"><span class="lineNum"> 177</span> : */</span>
|
|
<span id="L178"><span class="lineNum"> 178</span> : const uint32_t w = (uint32_t)h << 16;</span>
|
|
<span id="L179"><span class="lineNum"> 179</span> : /*</span>
|
|
<span id="L180"><span class="lineNum"> 180</span> : * Extract the sign of the input number into the high bit of the 32-bit word:</span>
|
|
<span id="L181"><span class="lineNum"> 181</span> : *</span>
|
|
<span id="L182"><span class="lineNum"> 182</span> : * +---+----------------------------------+</span>
|
|
<span id="L183"><span class="lineNum"> 183</span> : * | S |0000000 00000000 00000000 00000000|</span>
|
|
<span id="L184"><span class="lineNum"> 184</span> : * +---+----------------------------------+</span>
|
|
<span id="L185"><span class="lineNum"> 185</span> : * Bits 31 0-31</span>
|
|
<span id="L186"><span class="lineNum"> 186</span> : */</span>
|
|
<span id="L187"><span class="lineNum"> 187</span> : const uint32_t sign = w & UINT32_C(0x80000000);</span>
|
|
<span id="L188"><span class="lineNum"> 188</span> : /*</span>
|
|
<span id="L189"><span class="lineNum"> 189</span> : * Extract mantissa and biased exponent of the input number into the high bits</span>
|
|
<span id="L190"><span class="lineNum"> 190</span> : * of the 32-bit word:</span>
|
|
<span id="L191"><span class="lineNum"> 191</span> : *</span>
|
|
<span id="L192"><span class="lineNum"> 192</span> : * +-----+------------+---------------------+</span>
|
|
<span id="L193"><span class="lineNum"> 193</span> : * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|</span>
|
|
<span id="L194"><span class="lineNum"> 194</span> : * +-----+------------+---------------------+</span>
|
|
<span id="L195"><span class="lineNum"> 195</span> : * Bits 27-31 17-26 0-16</span>
|
|
<span id="L196"><span class="lineNum"> 196</span> : */</span>
|
|
<span id="L197"><span class="lineNum"> 197</span> : const uint32_t two_w = w + w;</span>
|
|
<span id="L198"><span class="lineNum"> 198</span> : </span>
|
|
<span id="L199"><span class="lineNum"> 199</span> : /*</span>
|
|
<span id="L200"><span class="lineNum"> 200</span> : * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become</span>
|
|
<span id="L201"><span class="lineNum"> 201</span> : * mantissa and exponent of a single-precision floating-point number:</span>
|
|
<span id="L202"><span class="lineNum"> 202</span> : *</span>
|
|
<span id="L203"><span class="lineNum"> 203</span> : * S|Exponent | Mantissa</span>
|
|
<span id="L204"><span class="lineNum"> 204</span> : * +-+---+-----+------------+----------------+</span>
|
|
<span id="L205"><span class="lineNum"> 205</span> : * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|</span>
|
|
<span id="L206"><span class="lineNum"> 206</span> : * +-+---+-----+------------+----------------+</span>
|
|
<span id="L207"><span class="lineNum"> 207</span> : * Bits | 23-31 | 0-22</span>
|
|
<span id="L208"><span class="lineNum"> 208</span> : *</span>
|
|
<span id="L209"><span class="lineNum"> 209</span> : * Next, there are some adjustments to the exponent:</span>
|
|
<span id="L210"><span class="lineNum"> 210</span> : * - The exponent needs to be corrected by the difference in exponent bias</span>
|
|
<span id="L211"><span class="lineNum"> 211</span> : * between single-precision and half-precision formats (0x7F - 0xF = 0x70)</span>
|
|
<span id="L212"><span class="lineNum"> 212</span> : * - Inf and NaN values in the inputs should become Inf and NaN values after</span>
|
|
<span id="L213"><span class="lineNum"> 213</span> : * conversion to the single-precision number. Therefore, if the biased</span>
|
|
<span id="L214"><span class="lineNum"> 214</span> : * exponent of the half-precision input was 0x1F (max possible value), the</span>
|
|
<span id="L215"><span class="lineNum"> 215</span> : * biased exponent of the single-precision output must be 0xFF (max possible</span>
|
|
<span id="L216"><span class="lineNum"> 216</span> : * value). We do this correction in two steps:</span>
|
|
<span id="L217"><span class="lineNum"> 217</span> : * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset</span>
|
|
<span id="L218"><span class="lineNum"> 218</span> : * below) rather than by 0x70 suggested by the difference in the exponent bias</span>
|
|
<span id="L219"><span class="lineNum"> 219</span> : * (see above).</span>
|
|
<span id="L220"><span class="lineNum"> 220</span> : * - Then we multiply the single-precision result of exponent adjustment by</span>
|
|
<span id="L221"><span class="lineNum"> 221</span> : * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the</span>
|
|
<span id="L222"><span class="lineNum"> 222</span> : * necessary exponent adjustment by 0x70 due to difference in exponent bias.</span>
|
|
<span id="L223"><span class="lineNum"> 223</span> : * The floating-point multiplication hardware would ensure than Inf and</span>
|
|
<span id="L224"><span class="lineNum"> 224</span> : * NaN would retain their value on at least partially IEEE754-compliant</span>
|
|
<span id="L225"><span class="lineNum"> 225</span> : * implementations.</span>
|
|
<span id="L226"><span class="lineNum"> 226</span> : *</span>
|
|
<span id="L227"><span class="lineNum"> 227</span> : * Note that the above operations do not handle denormal inputs (where biased</span>
|
|
<span id="L228"><span class="lineNum"> 228</span> : * exponent == 0). However, they also do not operate on denormal inputs, and</span>
|
|
<span id="L229"><span class="lineNum"> 229</span> : * do not produce denormal results.</span>
|
|
<span id="L230"><span class="lineNum"> 230</span> : */</span>
|
|
<span id="L231"><span class="lineNum"> 231</span> : constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;</span>
|
|
<span id="L232"><span class="lineNum"> 232</span> : // const float exp_scale = 0x1.0p-112f;</span>
|
|
<span id="L233"><span class="lineNum"> 233</span> : constexpr uint32_t scale_bits = (uint32_t)15 << 23;</span>
|
|
<span id="L234"><span class="lineNum"> 234</span> : float exp_scale_val;</span>
|
|
<span id="L235"><span class="lineNum"> 235</span> : std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));</span>
|
|
<span id="L236"><span class="lineNum"> 236</span> : const float exp_scale = exp_scale_val;</span>
|
|
<span id="L237"><span class="lineNum"> 237</span> : const float normalized_value =</span>
|
|
<span id="L238"><span class="lineNum"> 238</span> : fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;</span>
|
|
<span id="L239"><span class="lineNum"> 239</span> : </span>
|
|
<span id="L240"><span class="lineNum"> 240</span> : /*</span>
|
|
<span id="L241"><span class="lineNum"> 241</span> : * Convert denormalized half-precision inputs into single-precision results</span>
|
|
<span id="L242"><span class="lineNum"> 242</span> : * (always normalized). Zero inputs are also handled here.</span>
|
|
<span id="L243"><span class="lineNum"> 243</span> : *</span>
|
|
<span id="L244"><span class="lineNum"> 244</span> : * In a denormalized number the biased exponent is zero, and mantissa has</span>
|
|
<span id="L245"><span class="lineNum"> 245</span> : * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.</span>
|
|
<span id="L246"><span class="lineNum"> 246</span> : *</span>
|
|
<span id="L247"><span class="lineNum"> 247</span> : * zeros | mantissa</span>
|
|
<span id="L248"><span class="lineNum"> 248</span> : * +---------------------------+------------+</span>
|
|
<span id="L249"><span class="lineNum"> 249</span> : * |0000 0000 0000 0000 0000 00|MM MMMM MMMM|</span>
|
|
<span id="L250"><span class="lineNum"> 250</span> : * +---------------------------+------------+</span>
|
|
<span id="L251"><span class="lineNum"> 251</span> : * Bits 10-31 0-9</span>
|
|
<span id="L252"><span class="lineNum"> 252</span> : *</span>
|
|
<span id="L253"><span class="lineNum"> 253</span> : * Now, remember that denormalized half-precision numbers are represented as:</span>
|
|
<span id="L254"><span class="lineNum"> 254</span> : * FP16 = mantissa * 2**(-24).</span>
|
|
<span id="L255"><span class="lineNum"> 255</span> : * The trick is to construct a normalized single-precision number with the</span>
|
|
<span id="L256"><span class="lineNum"> 256</span> : * same mantissa and thehalf-precision input and with an exponent which would</span>
|
|
<span id="L257"><span class="lineNum"> 257</span> : * scale the corresponding mantissa bits to 2**(-24). A normalized</span>
|
|
<span id="L258"><span class="lineNum"> 258</span> : * single-precision floating-point number is represented as: FP32 = (1 +</span>
|
|
<span id="L259"><span class="lineNum"> 259</span> : * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased</span>
|
|
<span id="L260"><span class="lineNum"> 260</span> : * exponent is 126, a unit change in the mantissa of the input denormalized</span>
|
|
<span id="L261"><span class="lineNum"> 261</span> : * half-precision number causes a change of the constructed single-precision</span>
|
|
<span id="L262"><span class="lineNum"> 262</span> : * number by 2**(-24), i.e. the same amount.</span>
|
|
<span id="L263"><span class="lineNum"> 263</span> : *</span>
|
|
<span id="L264"><span class="lineNum"> 264</span> : * The last step is to adjust the bias of the constructed single-precision</span>
|
|
<span id="L265"><span class="lineNum"> 265</span> : * number. When the input half-precision number is zero, the constructed</span>
|
|
<span id="L266"><span class="lineNum"> 266</span> : * single-precision number has the value of FP32 = 1 * 2**(126 - 127) =</span>
|
|
<span id="L267"><span class="lineNum"> 267</span> : * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed</span>
|
|
<span id="L268"><span class="lineNum"> 268</span> : * single-precision number to get the numerical equivalent of the input</span>
|
|
<span id="L269"><span class="lineNum"> 269</span> : * half-precision number.</span>
|
|
<span id="L270"><span class="lineNum"> 270</span> : */</span>
|
|
<span id="L271"><span class="lineNum"> 271</span> : constexpr uint32_t magic_mask = UINT32_C(126) << 23;</span>
|
|
<span id="L272"><span class="lineNum"> 272</span> : constexpr float magic_bias = 0.5f;</span>
|
|
<span id="L273"><span class="lineNum"> 273</span> : const float denormalized_value =</span>
|
|
<span id="L274"><span class="lineNum"> 274</span> : fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;</span>
|
|
<span id="L275"><span class="lineNum"> 275</span> : </span>
|
|
<span id="L276"><span class="lineNum"> 276</span> : /*</span>
|
|
<span id="L277"><span class="lineNum"> 277</span> : * - Choose either results of conversion of input as a normalized number, or</span>
|
|
<span id="L278"><span class="lineNum"> 278</span> : * as a denormalized number, depending on the input exponent. The variable</span>
|
|
<span id="L279"><span class="lineNum"> 279</span> : * two_w contains input exponent in bits 27-31, therefore if its smaller than</span>
|
|
<span id="L280"><span class="lineNum"> 280</span> : * 2**27, the input is either a denormal number, or zero.</span>
|
|
<span id="L281"><span class="lineNum"> 281</span> : * - Combine the result of conversion of exponent and mantissa with the sign</span>
|
|
<span id="L282"><span class="lineNum"> 282</span> : * of the input number.</span>
|
|
<span id="L283"><span class="lineNum"> 283</span> : */</span>
|
|
<span id="L284"><span class="lineNum"> 284</span> : constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;</span>
|
|
<span id="L285"><span class="lineNum"> 285</span> : const uint32_t result = sign |</span>
|
|
<span id="L286"><span class="lineNum"> 286</span> : (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)</span>
|
|
<span id="L287"><span class="lineNum"> 287</span> : : fp32_to_bits(normalized_value));</span>
|
|
<span id="L288"><span class="lineNum"> 288</span> : return fp32_from_bits(result);</span>
|
|
<span id="L289"><span class="lineNum"> 289</span> : }</span>
|
|
<span id="L290"><span class="lineNum"> 290</span> : </span>
|
|
<span id="L291"><span class="lineNum"> 291</span> : /*</span>
|
|
<span id="L292"><span class="lineNum"> 292</span> : * Convert a 32-bit floating-point number in IEEE single-precision format to a</span>
|
|
<span id="L293"><span class="lineNum"> 293</span> : * 16-bit floating-point number in IEEE half-precision format, in bit</span>
|
|
<span id="L294"><span class="lineNum"> 294</span> : * representation.</span>
|
|
<span id="L295"><span class="lineNum"> 295</span> : *</span>
|
|
<span id="L296"><span class="lineNum"> 296</span> : * @note The implementation relies on IEEE-like (no assumption about rounding</span>
|
|
<span id="L297"><span class="lineNum"> 297</span> : * mode and no operations on denormals) floating-point operations and bitcasts</span>
|
|
<span id="L298"><span class="lineNum"> 298</span> : * between integer and floating-point variables.</span>
|
|
<span id="L299"><span class="lineNum"> 299</span> : */</span>
|
|
<span id="L300"><span class="lineNum"> 300</span> : inline uint16_t fp16_ieee_from_fp32_value(float f) {</span>
|
|
<span id="L301"><span class="lineNum"> 301</span> : // const float scale_to_inf = 0x1.0p+112f;</span>
|
|
<span id="L302"><span class="lineNum"> 302</span> : // const float scale_to_zero = 0x1.0p-110f;</span>
|
|
<span id="L303"><span class="lineNum"> 303</span> : constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;</span>
|
|
<span id="L304"><span class="lineNum"> 304</span> : constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;</span>
|
|
<span id="L305"><span class="lineNum"> 305</span> : float scale_to_inf_val, scale_to_zero_val;</span>
|
|
<span id="L306"><span class="lineNum"> 306</span> : std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));</span>
|
|
<span id="L307"><span class="lineNum"> 307</span> : std::memcpy(</span>
|
|
<span id="L308"><span class="lineNum"> 308</span> : &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));</span>
|
|
<span id="L309"><span class="lineNum"> 309</span> : const float scale_to_inf = scale_to_inf_val;</span>
|
|
<span id="L310"><span class="lineNum"> 310</span> : const float scale_to_zero = scale_to_zero_val;</span>
|
|
<span id="L311"><span class="lineNum"> 311</span> : </span>
|
|
<span id="L312"><span class="lineNum"> 312</span> : #if defined(_MSC_VER) && _MSC_VER == 1916</span>
|
|
<span id="L313"><span class="lineNum"> 313</span> : float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;</span>
|
|
<span id="L314"><span class="lineNum"> 314</span> : #else</span>
|
|
<span id="L315"><span class="lineNum"> 315</span> : float base = (fabsf(f) * scale_to_inf) * scale_to_zero;</span>
|
|
<span id="L316"><span class="lineNum"> 316</span> : #endif</span>
|
|
<span id="L317"><span class="lineNum"> 317</span> : </span>
|
|
<span id="L318"><span class="lineNum"> 318</span> : const uint32_t w = fp32_to_bits(f);</span>
|
|
<span id="L319"><span class="lineNum"> 319</span> : const uint32_t shl1_w = w + w;</span>
|
|
<span id="L320"><span class="lineNum"> 320</span> : const uint32_t sign = w & UINT32_C(0x80000000);</span>
|
|
<span id="L321"><span class="lineNum"> 321</span> : uint32_t bias = shl1_w & UINT32_C(0xFF000000);</span>
|
|
<span id="L322"><span class="lineNum"> 322</span> : if (bias < UINT32_C(0x71000000)) {</span>
|
|
<span id="L323"><span class="lineNum"> 323</span> : bias = UINT32_C(0x71000000);</span>
|
|
<span id="L324"><span class="lineNum"> 324</span> : }</span>
|
|
<span id="L325"><span class="lineNum"> 325</span> : </span>
|
|
<span id="L326"><span class="lineNum"> 326</span> : base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;</span>
|
|
<span id="L327"><span class="lineNum"> 327</span> : const uint32_t bits = fp32_to_bits(base);</span>
|
|
<span id="L328"><span class="lineNum"> 328</span> : const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);</span>
|
|
<span id="L329"><span class="lineNum"> 329</span> : const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);</span>
|
|
<span id="L330"><span class="lineNum"> 330</span> : const uint32_t nonsign = exp_bits + mantissa_bits;</span>
|
|
<span id="L331"><span class="lineNum"> 331</span> : return static_cast<uint16_t>(</span>
|
|
<span id="L332"><span class="lineNum"> 332</span> : (sign >> 16) |</span>
|
|
<span id="L333"><span class="lineNum"> 333</span> : (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));</span>
|
|
<span id="L334"><span class="lineNum"> 334</span> : }</span>
|
|
<span id="L335"><span class="lineNum"> 335</span> : </span>
|
|
<span id="L336"><span class="lineNum"> 336</span> : } // namespace detail</span>
|
|
<span id="L337"><span class="lineNum"> 337</span> : </span>
|
|
<span id="L338"><span class="lineNum"> 338</span> : struct alignas(2) Half {</span>
|
|
<span id="L339"><span class="lineNum"> 339</span> : unsigned short x;</span>
|
|
<span id="L340"><span class="lineNum"> 340</span> : </span>
|
|
<span id="L341"><span class="lineNum"> 341</span> : struct from_bits_t {};</span>
|
|
<span id="L342"><span class="lineNum"> 342</span> : C10_HOST_DEVICE static constexpr from_bits_t from_bits() {</span>
|
|
<span id="L343"><span class="lineNum"> 343</span> : return from_bits_t();</span>
|
|
<span id="L344"><span class="lineNum"> 344</span> : }</span>
|
|
<span id="L345"><span class="lineNum"> 345</span> : </span>
|
|
<span id="L346"><span class="lineNum"> 346</span> : // HIP wants __host__ __device__ tag, CUDA does not</span>
|
|
<span id="L347"><span class="lineNum"> 347</span> : #if defined(USE_ROCM)</span>
|
|
<span id="L348"><span class="lineNum"> 348</span> : C10_HOST_DEVICE Half() = default;</span>
|
|
<span id="L349"><span class="lineNum"> 349</span> : #else</span>
|
|
<span id="L350"><span class="lineNum"> 350</span> : Half() = default;</span>
|
|
<span id="L351"><span class="lineNum"> 351</span> : #endif</span>
|
|
<span id="L352"><span class="lineNum"> 352</span> : </span>
|
|
<span id="L353"><span class="lineNum"> 353</span> : constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){};</span>
|
|
<span id="L354"><span class="lineNum"> 354</span> : inline C10_HOST_DEVICE Half(float value);</span>
|
|
<span id="L355"><span class="lineNum"> 355</span> : inline C10_HOST_DEVICE operator float() const;</span>
|
|
<span id="L356"><span class="lineNum"> 356</span> : </span>
|
|
<span id="L357"><span class="lineNum"> 357</span> : #if defined(__CUDACC__) || defined(__HIPCC__)</span>
|
|
<span id="L358"><span class="lineNum"> 358</span> : inline C10_HOST_DEVICE Half(const __half& value);</span>
|
|
<span id="L359"><span class="lineNum"> 359</span> : inline C10_HOST_DEVICE operator __half() const;</span>
|
|
<span id="L360"><span class="lineNum"> 360</span> : #endif</span>
|
|
<span id="L361"><span class="lineNum"> 361</span> : #ifdef SYCL_LANGUAGE_VERSION</span>
|
|
<span id="L362"><span class="lineNum"> 362</span> : inline C10_HOST_DEVICE Half(const sycl::half& value);</span>
|
|
<span id="L363"><span class="lineNum"> 363</span> : inline C10_HOST_DEVICE operator sycl::half() const;</span>
|
|
<span id="L364"><span class="lineNum"> 364</span> : #endif</span>
|
|
<span id="L365"><span class="lineNum"> 365</span> : };</span>
|
|
<span id="L366"><span class="lineNum"> 366</span> : </span>
|
|
<span id="L367"><span class="lineNum"> 367</span> : // TODO : move to complex.h</span>
|
|
<span id="L368"><span class="lineNum"> 368</span> : template <></span>
|
|
<span id="L369"><span class="lineNum"> 369</span> : struct alignas(4) complex<Half> {</span>
|
|
<span id="L370"><span class="lineNum"> 370</span> : Half real_;</span>
|
|
<span id="L371"><span class="lineNum"> 371</span> : Half imag_;</span>
|
|
<span id="L372"><span class="lineNum"> 372</span> : </span>
|
|
<span id="L373"><span class="lineNum"> 373</span> : // Constructors</span>
|
|
<span id="L374"><span class="lineNum"> 374</span> : complex() = default;</span>
|
|
<span id="L375"><span class="lineNum"> 375</span> : // Half constructor is not constexpr so the following constructor can't</span>
|
|
<span id="L376"><span class="lineNum"> 376</span> : // be constexpr</span>
|
|
<span id="L377"><span class="lineNum"> 377</span> : C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)</span>
|
|
<span id="L378"><span class="lineNum"> 378</span> : : real_(real), imag_(imag) {}</span>
|
|
<span id="L379"><span class="lineNum"> 379</span> : C10_HOST_DEVICE inline complex(const c10::complex<float>& value)</span>
|
|
<span id="L380"><span class="lineNum"> 380</span> : : real_(value.real()), imag_(value.imag()) {}</span>
|
|
<span id="L381"><span class="lineNum"> 381</span> : </span>
|
|
<span id="L382"><span class="lineNum"> 382</span> : // Conversion operator</span>
|
|
<span id="L383"><span class="lineNum"> 383</span> : inline C10_HOST_DEVICE operator c10::complex<float>() const {</span>
|
|
<span id="L384"><span class="lineNum"> 384</span> : return {real_, imag_};</span>
|
|
<span id="L385"><span class="lineNum"> 385</span> : }</span>
|
|
<span id="L386"><span class="lineNum"> 386</span> : </span>
|
|
<span id="L387"><span class="lineNum"> 387</span> : constexpr C10_HOST_DEVICE Half real() const {</span>
|
|
<span id="L388"><span class="lineNum"> 388</span> : return real_;</span>
|
|
<span id="L389"><span class="lineNum"> 389</span> : }</span>
|
|
<span id="L390"><span class="lineNum"> 390</span> : constexpr C10_HOST_DEVICE Half imag() const {</span>
|
|
<span id="L391"><span class="lineNum"> 391</span> : return imag_;</span>
|
|
<span id="L392"><span class="lineNum"> 392</span> : }</span>
|
|
<span id="L393"><span class="lineNum"> 393</span> : </span>
|
|
<span id="L394"><span class="lineNum"> 394</span> : C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {</span>
|
|
<span id="L395"><span class="lineNum"> 395</span> : real_ = static_cast<float>(real_) + static_cast<float>(other.real_);</span>
|
|
<span id="L396"><span class="lineNum"> 396</span> : imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);</span>
|
|
<span id="L397"><span class="lineNum"> 397</span> : return *this;</span>
|
|
<span id="L398"><span class="lineNum"> 398</span> : }</span>
|
|
<span id="L399"><span class="lineNum"> 399</span> : </span>
|
|
<span id="L400"><span class="lineNum"> 400</span> : C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {</span>
|
|
<span id="L401"><span class="lineNum"> 401</span> : real_ = static_cast<float>(real_) - static_cast<float>(other.real_);</span>
|
|
<span id="L402"><span class="lineNum"> 402</span> : imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);</span>
|
|
<span id="L403"><span class="lineNum"> 403</span> : return *this;</span>
|
|
<span id="L404"><span class="lineNum"> 404</span> : }</span>
|
|
<span id="L405"><span class="lineNum"> 405</span> : </span>
|
|
<span id="L406"><span class="lineNum"> 406</span> : C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {</span>
|
|
<span id="L407"><span class="lineNum"> 407</span> : auto a = static_cast<float>(real_);</span>
|
|
<span id="L408"><span class="lineNum"> 408</span> : auto b = static_cast<float>(imag_);</span>
|
|
<span id="L409"><span class="lineNum"> 409</span> : auto c = static_cast<float>(other.real());</span>
|
|
<span id="L410"><span class="lineNum"> 410</span> : auto d = static_cast<float>(other.imag());</span>
|
|
<span id="L411"><span class="lineNum"> 411</span> : real_ = a * c - b * d;</span>
|
|
<span id="L412"><span class="lineNum"> 412</span> : imag_ = a * d + b * c;</span>
|
|
<span id="L413"><span class="lineNum"> 413</span> : return *this;</span>
|
|
<span id="L414"><span class="lineNum"> 414</span> : }</span>
|
|
<span id="L415"><span class="lineNum"> 415</span> : };</span>
|
|
<span id="L416"><span class="lineNum"> 416</span> : </span>
|
|
<span id="L417"><span class="lineNum"> 417</span> : // In some versions of MSVC, there will be a compiler error when building.</span>
|
|
<span id="L418"><span class="lineNum"> 418</span> : // C4146: unary minus operator applied to unsigned type, result still unsigned</span>
|
|
<span id="L419"><span class="lineNum"> 419</span> : // C4804: unsafe use of type 'bool' in operation</span>
|
|
<span id="L420"><span class="lineNum"> 420</span> : // It can be addressed by disabling the following warning.</span>
|
|
<span id="L421"><span class="lineNum"> 421</span> : #ifdef _MSC_VER</span>
|
|
<span id="L422"><span class="lineNum"> 422</span> : #pragma warning(push)</span>
|
|
<span id="L423"><span class="lineNum"> 423</span> : #pragma warning(disable : 4146)</span>
|
|
<span id="L424"><span class="lineNum"> 424</span> : #pragma warning(disable : 4804)</span>
|
|
<span id="L425"><span class="lineNum"> 425</span> : #pragma warning(disable : 4018)</span>
|
|
<span id="L426"><span class="lineNum"> 426</span> : #endif</span>
|
|
<span id="L427"><span class="lineNum"> 427</span> : </span>
|
|
<span id="L428"><span class="lineNum"> 428</span> : // The overflow checks may involve float to int conversion which may</span>
|
|
<span id="L429"><span class="lineNum"> 429</span> : // trigger precision loss warning. Re-enable the warning once the code</span>
|
|
<span id="L430"><span class="lineNum"> 430</span> : // is fixed. See T58053069.</span>
|
|
<span id="L431"><span class="lineNum"> 431</span> : C10_CLANG_DIAGNOSTIC_PUSH()</span>
|
|
<span id="L432"><span class="lineNum"> 432</span> : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")</span>
|
|
<span id="L433"><span class="lineNum"> 433</span> : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")</span>
|
|
<span id="L434"><span class="lineNum"> 434</span> : #endif</span>
|
|
<span id="L435"><span class="lineNum"> 435</span> : </span>
|
|
<span id="L436"><span class="lineNum"> 436</span> : // bool can be converted to any type.</span>
|
|
<span id="L437"><span class="lineNum"> 437</span> : // Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build:</span>
|
|
<span id="L438"><span class="lineNum"> 438</span> : // `error: comparison of constant '255' with boolean expression is always false`</span>
|
|
<span id="L439"><span class="lineNum"> 439</span> : // for `f > limit::max()` below</span>
|
|
<span id="L440"><span class="lineNum"> 440</span> : template <typename To, typename From></span>
|
|
<span id="L441"><span class="lineNum"> 441</span> <span class="tlaUNC tlaBgUNC"> 0 : typename std::enable_if<std::is_same<From, bool>::value, bool>::type overflows(</span></span>
|
|
<span id="L442"><span class="lineNum"> 442</span> : From /*f*/) {</span>
|
|
<span id="L443"><span class="lineNum"> 443</span> <span class="tlaUNC"> 0 : return false;</span></span>
|
|
<span id="L444"><span class="lineNum"> 444</span> : }</span>
|
|
<span id="L445"><span class="lineNum"> 445</span> : </span>
|
|
<span id="L446"><span class="lineNum"> 446</span> : // skip isnan and isinf check for integral types</span>
|
|
<span id="L447"><span class="lineNum"> 447</span> : template <typename To, typename From></span>
|
|
<span id="L448"><span class="lineNum"> 448</span> : typename std::enable_if<</span>
|
|
<span id="L449"><span class="lineNum"> 449</span> : std::is_integral<From>::value && !std::is_same<From, bool>::value,</span>
|
|
<span id="L450"><span class="lineNum"> 450</span> : bool>::type</span>
|
|
<span id="L451"><span class="lineNum"> 451</span> <span class="tlaGNC tlaBgGNC"> 7072 : overflows(From f) {</span></span>
|
|
<span id="L452"><span class="lineNum"> 452</span> : using limit = std::numeric_limits<typename scalar_value_type<To>::type>;</span>
|
|
<span id="L453"><span class="lineNum"> 453</span> : if (!limit::is_signed && std::numeric_limits<From>::is_signed) {</span>
|
|
<span id="L454"><span class="lineNum"> 454</span> : // allow for negative numbers to wrap using two's complement arithmetic.</span>
|
|
<span id="L455"><span class="lineNum"> 455</span> : // For example, with uint8, this allows for `a - b` to be treated as</span>
|
|
<span id="L456"><span class="lineNum"> 456</span> : // `a + 255 * b`.</span>
|
|
<span id="L457"><span class="lineNum"> 457</span> : return greater_than_max<To>(f) ||</span>
|
|
<span id="L458"><span class="lineNum"> 458</span> : (c10::is_negative(f) && -static_cast<uint64_t>(f) > limit::max());</span>
|
|
<span id="L459"><span class="lineNum"> 459</span> : } else {</span>
|
|
<span id="L460"><span class="lineNum"> 460</span> <span class="tlaGNC"> 7072 : return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);</span></span>
|
|
<span id="L461"><span class="lineNum"> 461</span> : }</span>
|
|
<span id="L462"><span class="lineNum"> 462</span> : }</span>
|
|
<span id="L463"><span class="lineNum"> 463</span> : </span>
|
|
<span id="L464"><span class="lineNum"> 464</span> : template <typename To, typename From></span>
|
|
<span id="L465"><span class="lineNum"> 465</span> : typename std::enable_if<std::is_floating_point<From>::value, bool>::type</span>
|
|
<span id="L466"><span class="lineNum"> 466</span> <span class="tlaUNC tlaBgUNC"> 0 : overflows(From f) {</span></span>
|
|
<span id="L467"><span class="lineNum"> 467</span> : using limit = std::numeric_limits<typename scalar_value_type<To>::type>;</span>
|
|
<span id="L468"><span class="lineNum"> 468</span> : if (limit::has_infinity && std::isinf(static_cast<double>(f))) {</span>
|
|
<span id="L469"><span class="lineNum"> 469</span> : return false;</span>
|
|
<span id="L470"><span class="lineNum"> 470</span> : }</span>
|
|
<span id="L471"><span class="lineNum"> 471</span> <span class="tlaUNC"> 0 : if (!limit::has_quiet_NaN && (f != f)) {</span></span>
|
|
<span id="L472"><span class="lineNum"> 472</span> <span class="tlaUNC"> 0 : return true;</span></span>
|
|
<span id="L473"><span class="lineNum"> 473</span> : }</span>
|
|
<span id="L474"><span class="lineNum"> 474</span> <span class="tlaUNC"> 0 : return f < limit::lowest() || f > limit::max();</span></span>
|
|
<span id="L475"><span class="lineNum"> 475</span> : }</span>
|
|
<span id="L476"><span class="lineNum"> 476</span> : </span>
|
|
<span id="L477"><span class="lineNum"> 477</span> : C10_CLANG_DIAGNOSTIC_POP()</span>
|
|
<span id="L478"><span class="lineNum"> 478</span> : </span>
|
|
<span id="L479"><span class="lineNum"> 479</span> : #ifdef _MSC_VER</span>
|
|
<span id="L480"><span class="lineNum"> 480</span> : #pragma warning(pop)</span>
|
|
<span id="L481"><span class="lineNum"> 481</span> : #endif</span>
|
|
<span id="L482"><span class="lineNum"> 482</span> : </span>
|
|
<span id="L483"><span class="lineNum"> 483</span> : template <typename To, typename From></span>
|
|
<span id="L484"><span class="lineNum"> 484</span> <span class="tlaUNC"> 0 : typename std::enable_if<is_complex<From>::value, bool>::type overflows(From f) {</span></span>
|
|
<span id="L485"><span class="lineNum"> 485</span> : // casts from complex to real are considered to overflow if the</span>
|
|
<span id="L486"><span class="lineNum"> 486</span> : // imaginary component is non-zero</span>
|
|
<span id="L487"><span class="lineNum"> 487</span> <span class="tlaUNC"> 0 : if (!is_complex<To>::value && f.imag() != 0) {</span></span>
|
|
<span id="L488"><span class="lineNum"> 488</span> <span class="tlaUNC"> 0 : return true;</span></span>
|
|
<span id="L489"><span class="lineNum"> 489</span> : }</span>
|
|
<span id="L490"><span class="lineNum"> 490</span> : // Check for overflow componentwise</span>
|
|
<span id="L491"><span class="lineNum"> 491</span> : // (Technically, the imag overflow check is guaranteed to be false</span>
|
|
<span id="L492"><span class="lineNum"> 492</span> : // when !is_complex<To>, but any optimizer worth its salt will be</span>
|
|
<span id="L493"><span class="lineNum"> 493</span> : // able to figure it out.)</span>
|
|
<span id="L494"><span class="lineNum"> 494</span> : return overflows<</span>
|
|
<span id="L495"><span class="lineNum"> 495</span> : typename scalar_value_type<To>::type,</span>
|
|
<span id="L496"><span class="lineNum"> 496</span> <span class="tlaUNC"> 0 : typename From::value_type>(f.real()) ||</span></span>
|
|
<span id="L497"><span class="lineNum"> 497</span> : overflows<</span>
|
|
<span id="L498"><span class="lineNum"> 498</span> : typename scalar_value_type<To>::type,</span>
|
|
<span id="L499"><span class="lineNum"> 499</span> <span class="tlaUNC"> 0 : typename From::value_type>(f.imag());</span></span>
|
|
<span id="L500"><span class="lineNum"> 500</span> : }</span>
|
|
<span id="L501"><span class="lineNum"> 501</span> : </span>
|
|
<span id="L502"><span class="lineNum"> 502</span> : C10_API std::ostream& operator<<(std::ostream& out, const Half& value);</span>
|
|
<span id="L503"><span class="lineNum"> 503</span> : </span>
|
|
<span id="L504"><span class="lineNum"> 504</span> : } // namespace c10</span>
|
|
<span id="L505"><span class="lineNum"> 505</span> : </span>
|
|
<span id="L506"><span class="lineNum"> 506</span> : #include <c10/util/Half-inl.h> // IWYU pragma: keep</span>
|
|
</pre>
|
|
</td>
|
|
</tr>
|
|
</table>
|
|
<br>
|
|
|
|
<table width="100%" border=0 cellspacing=0 cellpadding=0>
|
|
<tr><td class="ruler"><img src="../../../../glass.png" width=3 height=3 alt=""></td></tr>
|
|
<tr><td class="versionInfo">Generated by: <a href="https://github.com//linux-test-project/lcov" target="_parent">LCOV version 2.0-1</a></td></tr>
|
|
</table>
|
|
<br>
|
|
|
|
</body>
|
|
</html>
|