61 KiB
61 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : #pragma once 2 : 3 : /// Defines the Half type (half-precision floating-point) including conversions 4 : /// to standard C types and basic arithmetic operations. Note that arithmetic 5 : /// operations are implemented by converting to floating point and 6 : /// performing the operation in float32, instead of using CUDA half intrinsics. 7 : /// Most uses of this type within ATen are memory bound, including the 8 : /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. 9 : /// If you are writing a compute bound kernel, you can use the CUDA half 10 : /// intrinsics directly on the Half type from device code. 11 : 12 : #include <c10/macros/Macros.h> 13 : #include <c10/util/C++17.h> 14 : #include <c10/util/TypeSafeSignMath.h> 15 : #include <c10/util/complex.h> 16 : #include <c10/util/floating_point_utils.h> 17 : #include <type_traits> 18 : 19 : #if defined(__cplusplus) && (__cplusplus >= 201103L) 20 : #include <cmath> 21 : #include <cstdint> 22 : #elif !defined(__OPENCL_VERSION__) 23 : #include <math.h> 24 : #include <stdint.h> 25 : #endif 26 : 27 : #ifdef _MSC_VER 28 : #include <intrin.h> 29 : #endif 30 : 31 : #include <complex> 32 : #include <cstdint> 33 : #include <cstring> 34 : #include <iosfwd> 35 : #include <limits> 36 : #include <sstream> 37 : #include <stdexcept> 38 : #include <string> 39 : #include <utility> 40 : 41 : #ifdef __CUDACC__ 42 : #include <cuda_fp16.h> 43 : #endif 44 : 45 : #ifdef __HIPCC__ 46 : #include <hip/hip_fp16.h> 47 : #endif 48 : 49 : #if defined(CL_SYCL_LANGUAGE_VERSION) 50 : #include <CL/sycl.hpp> // for SYCL 1.2.1 51 : #elif defined(SYCL_LANGUAGE_VERSION) 52 : #include <sycl/sycl.hpp> // for SYCL 2020 53 : #endif 54 : 55 : #include <typeinfo> // operator typeid 56 : 57 : namespace c10 { 58 : 59 : namespace detail { 60 : 61 : /* 62 : * Convert a 16-bit floating-point number in IEEE half-precision format, in bit 63 : * representation, to a 32-bit floating-point number in IEEE single-precision 64 : * format, in bit representation. 65 : * 66 : * @note The implementation doesn't use any floating-point operations. 67 : */ 68 : inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { 69 : /* 70 : * Extend the half-precision floating-point number to 32 bits and shift to the 71 : * upper part of the 32-bit word: 72 : * +---+-----+------------+-------------------+ 73 : * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| 74 : * +---+-----+------------+-------------------+ 75 : * Bits 31 26-30 16-25 0-15 76 : * 77 : * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 78 : * - zero bits. 79 : */ 80 : const uint32_t w = (uint32_t)h << 16; 81 : /* 82 : * Extract the sign of the input number into the high bit of the 32-bit word: 83 : * 84 : * +---+----------------------------------+ 85 : * | S |0000000 00000000 00000000 00000000| 86 : * +---+----------------------------------+ 87 : * Bits 31 0-31 88 : */ 89 : const uint32_t sign = w & UINT32_C(0x80000000); 90 : /* 91 : * Extract mantissa and biased exponent of the input number into the bits 0-30 92 : * of the 32-bit word: 93 : * 94 : * +---+-----+------------+-------------------+ 95 : * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| 96 : * +---+-----+------------+-------------------+ 97 : * Bits 30 27-31 17-26 0-16 98 : */ 99 : const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); 100 : /* 101 : * Renorm shift is the number of bits to shift mantissa left to make the 102 : * half-precision number normalized. If the initial number is normalized, some 103 : * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case 104 : * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note 105 : * that if we shift denormalized nonsign by renorm_shift, the unit bit of 106 : * mantissa will shift into exponent, turning the biased exponent into 1, and 107 : * making mantissa normalized (i.e. without leading 1). 108 : */ 109 : #ifdef _MSC_VER 110 : unsigned long nonsign_bsr; 111 : _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); 112 : uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; 113 : #else 114 : uint32_t renorm_shift = __builtin_clz(nonsign); 115 : #endif 116 : renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; 117 : /* 118 : * Iff half-precision number has exponent of 15, the addition overflows 119 : * it into bit 31, and the subsequent shift turns the high 9 bits 120 : * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number 121 : * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise 122 : */ 123 : const int32_t inf_nan_mask = 124 : ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); 125 : /* 126 : * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 127 : * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 128 : * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == 129 : * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) 130 : * 0x00000000 otherwise 131 : */ 132 : const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; 133 : /* 134 : * 1. Shift nonsign left by renorm_shift to normalize it (if the input 135 : * was denormal) 136 : * 2. Shift nonsign right by 3 so the exponent (5 bits originally) 137 : * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high 138 : * bits of the 23-bit mantissa of IEEE single-precision number. 139 : * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the 140 : * different in exponent bias (0x7F for single-precision number less 0xF 141 : * for half-precision number). 142 : * 4. Subtract renorm_shift from the exponent (starting at bit 23) to 143 : * account for renormalization. As renorm_shift is less than 0x70, this 144 : * can be combined with step 3. 145 : * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the 146 : * input was NaN or infinity. 147 : * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent 148 : * into zero if the input was zero. 149 : * 7. Combine with the sign of the input number. 150 : */ 151 : return sign | 152 : ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | 153 : inf_nan_mask) & 154 : ~zero_mask); 155 : } 156 : 157 : /* 158 : * Convert a 16-bit floating-point number in IEEE half-precision format, in bit 159 : * representation, to a 32-bit floating-point number in IEEE single-precision 160 : * format. 161 : * 162 : * @note The implementation relies on IEEE-like (no assumption about rounding 163 : * mode and no operations on denormals) floating-point operations and bitcasts 164 : * between integer and floating-point variables. 165 : */ 166 : C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) { 167 : /* 168 : * Extend the half-precision floating-point number to 32 bits and shift to the 169 : * upper part of the 32-bit word: 170 : * +---+-----+------------+-------------------+ 171 : * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| 172 : * +---+-----+------------+-------------------+ 173 : * Bits 31 26-30 16-25 0-15 174 : * 175 : * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 176 : * - zero bits. 177 : */ 178 : const uint32_t w = (uint32_t)h << 16; 179 : /* 180 : * Extract the sign of the input number into the high bit of the 32-bit word: 181 : * 182 : * +---+----------------------------------+ 183 : * | S |0000000 00000000 00000000 00000000| 184 : * +---+----------------------------------+ 185 : * Bits 31 0-31 186 : */ 187 : const uint32_t sign = w & UINT32_C(0x80000000); 188 : /* 189 : * Extract mantissa and biased exponent of the input number into the high bits 190 : * of the 32-bit word: 191 : * 192 : * +-----+------------+---------------------+ 193 : * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| 194 : * +-----+------------+---------------------+ 195 : * Bits 27-31 17-26 0-16 196 : */ 197 : const uint32_t two_w = w + w; 198 : 199 : /* 200 : * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become 201 : * mantissa and exponent of a single-precision floating-point number: 202 : * 203 : * S|Exponent | Mantissa 204 : * +-+---+-----+------------+----------------+ 205 : * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| 206 : * +-+---+-----+------------+----------------+ 207 : * Bits | 23-31 | 0-22 208 : * 209 : * Next, there are some adjustments to the exponent: 210 : * - The exponent needs to be corrected by the difference in exponent bias 211 : * between single-precision and half-precision formats (0x7F - 0xF = 0x70) 212 : * - Inf and NaN values in the inputs should become Inf and NaN values after 213 : * conversion to the single-precision number. Therefore, if the biased 214 : * exponent of the half-precision input was 0x1F (max possible value), the 215 : * biased exponent of the single-precision output must be 0xFF (max possible 216 : * value). We do this correction in two steps: 217 : * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset 218 : * below) rather than by 0x70 suggested by the difference in the exponent bias 219 : * (see above). 220 : * - Then we multiply the single-precision result of exponent adjustment by 221 : * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the 222 : * necessary exponent adjustment by 0x70 due to difference in exponent bias. 223 : * The floating-point multiplication hardware would ensure than Inf and 224 : * NaN would retain their value on at least partially IEEE754-compliant 225 : * implementations. 226 : * 227 : * Note that the above operations do not handle denormal inputs (where biased 228 : * exponent == 0). However, they also do not operate on denormal inputs, and 229 : * do not produce denormal results. 230 : */ 231 : constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; 232 : // const float exp_scale = 0x1.0p-112f; 233 : constexpr uint32_t scale_bits = (uint32_t)15 << 23; 234 : float exp_scale_val; 235 : std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); 236 : const float exp_scale = exp_scale_val; 237 : const float normalized_value = 238 : fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; 239 : 240 : /* 241 : * Convert denormalized half-precision inputs into single-precision results 242 : * (always normalized). Zero inputs are also handled here. 243 : * 244 : * In a denormalized number the biased exponent is zero, and mantissa has 245 : * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. 246 : * 247 : * zeros | mantissa 248 : * +---------------------------+------------+ 249 : * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| 250 : * +---------------------------+------------+ 251 : * Bits 10-31 0-9 252 : * 253 : * Now, remember that denormalized half-precision numbers are represented as: 254 : * FP16 = mantissa * 2**(-24). 255 : * The trick is to construct a normalized single-precision number with the 256 : * same mantissa and thehalf-precision input and with an exponent which would 257 : * scale the corresponding mantissa bits to 2**(-24). A normalized 258 : * single-precision floating-point number is represented as: FP32 = (1 + 259 : * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased 260 : * exponent is 126, a unit change in the mantissa of the input denormalized 261 : * half-precision number causes a change of the constructed single-precision 262 : * number by 2**(-24), i.e. the same amount. 263 : * 264 : * The last step is to adjust the bias of the constructed single-precision 265 : * number. When the input half-precision number is zero, the constructed 266 : * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = 267 : * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed 268 : * single-precision number to get the numerical equivalent of the input 269 : * half-precision number. 270 : */ 271 : constexpr uint32_t magic_mask = UINT32_C(126) << 23; 272 : constexpr float magic_bias = 0.5f; 273 : const float denormalized_value = 274 : fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; 275 : 276 : /* 277 : * - Choose either results of conversion of input as a normalized number, or 278 : * as a denormalized number, depending on the input exponent. The variable 279 : * two_w contains input exponent in bits 27-31, therefore if its smaller than 280 : * 2**27, the input is either a denormal number, or zero. 281 : * - Combine the result of conversion of exponent and mantissa with the sign 282 : * of the input number. 283 : */ 284 : constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; 285 : const uint32_t result = sign | 286 : (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) 287 : : fp32_to_bits(normalized_value)); 288 : return fp32_from_bits(result); 289 : } 290 : 291 : /* 292 : * Convert a 32-bit floating-point number in IEEE single-precision format to a 293 : * 16-bit floating-point number in IEEE half-precision format, in bit 294 : * representation. 295 : * 296 : * @note The implementation relies on IEEE-like (no assumption about rounding 297 : * mode and no operations on denormals) floating-point operations and bitcasts 298 : * between integer and floating-point variables. 299 : */ 300 : inline uint16_t fp16_ieee_from_fp32_value(float f) { 301 : // const float scale_to_inf = 0x1.0p+112f; 302 : // const float scale_to_zero = 0x1.0p-110f; 303 : constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; 304 : constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; 305 : float scale_to_inf_val, scale_to_zero_val; 306 : std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); 307 : std::memcpy( 308 : &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); 309 : const float scale_to_inf = scale_to_inf_val; 310 : const float scale_to_zero = scale_to_zero_val; 311 : 312 : #if defined(_MSC_VER) && _MSC_VER == 1916 313 : float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; 314 : #else 315 : float base = (fabsf(f) * scale_to_inf) * scale_to_zero; 316 : #endif 317 : 318 : const uint32_t w = fp32_to_bits(f); 319 : const uint32_t shl1_w = w + w; 320 : const uint32_t sign = w & UINT32_C(0x80000000); 321 : uint32_t bias = shl1_w & UINT32_C(0xFF000000); 322 : if (bias < UINT32_C(0x71000000)) { 323 : bias = UINT32_C(0x71000000); 324 : } 325 : 326 : base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; 327 : const uint32_t bits = fp32_to_bits(base); 328 : const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); 329 : const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); 330 : const uint32_t nonsign = exp_bits + mantissa_bits; 331 : return static_cast<uint16_t>( 332 : (sign >> 16) | 333 : (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); 334 : } 335 : 336 : } // namespace detail 337 : 338 : struct alignas(2) Half { 339 : unsigned short x; 340 : 341 : struct from_bits_t {}; 342 : C10_HOST_DEVICE static constexpr from_bits_t from_bits() { 343 : return from_bits_t(); 344 : } 345 : 346 : // HIP wants __host__ __device__ tag, CUDA does not 347 : #if defined(USE_ROCM) 348 : C10_HOST_DEVICE Half() = default; 349 : #else 350 : Half() = default; 351 : #endif 352 : 353 : constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){}; 354 : inline C10_HOST_DEVICE Half(float value); 355 : inline C10_HOST_DEVICE operator float() const; 356 : 357 : #if defined(__CUDACC__) || defined(__HIPCC__) 358 : inline C10_HOST_DEVICE Half(const __half& value); 359 : inline C10_HOST_DEVICE operator __half() const; 360 : #endif 361 : #ifdef SYCL_LANGUAGE_VERSION 362 : inline C10_HOST_DEVICE Half(const sycl::half& value); 363 : inline C10_HOST_DEVICE operator sycl::half() const; 364 : #endif 365 : }; 366 : 367 : // TODO : move to complex.h 368 : template <> 369 : struct alignas(4) complex<Half> { 370 : Half real_; 371 : Half imag_; 372 : 373 : // Constructors 374 : complex() = default; 375 : // Half constructor is not constexpr so the following constructor can't 376 : // be constexpr 377 : C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) 378 : : real_(real), imag_(imag) {} 379 : C10_HOST_DEVICE inline complex(const c10::complex<float>& value) 380 : : real_(value.real()), imag_(value.imag()) {} 381 : 382 : // Conversion operator 383 : inline C10_HOST_DEVICE operator c10::complex<float>() const { 384 : return {real_, imag_}; 385 : } 386 : 387 : constexpr C10_HOST_DEVICE Half real() const { 388 : return real_; 389 : } 390 : constexpr C10_HOST_DEVICE Half imag() const { 391 : return imag_; 392 : } 393 : 394 : C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) { 395 : real_ = static_cast<float>(real_) + static_cast<float>(other.real_); 396 : imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_); 397 : return *this; 398 : } 399 : 400 : C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) { 401 : real_ = static_cast<float>(real_) - static_cast<float>(other.real_); 402 : imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_); 403 : return *this; 404 : } 405 : 406 : C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) { 407 : auto a = static_cast<float>(real_); 408 : auto b = static_cast<float>(imag_); 409 : auto c = static_cast<float>(other.real()); 410 : auto d = static_cast<float>(other.imag()); 411 : real_ = a * c - b * d; 412 : imag_ = a * d + b * c; 413 : return *this; 414 : } 415 : }; 416 : 417 : // In some versions of MSVC, there will be a compiler error when building. 418 : // C4146: unary minus operator applied to unsigned type, result still unsigned 419 : // C4804: unsafe use of type 'bool' in operation 420 : // It can be addressed by disabling the following warning. 421 : #ifdef _MSC_VER 422 : #pragma warning(push) 423 : #pragma warning(disable : 4146) 424 : #pragma warning(disable : 4804) 425 : #pragma warning(disable : 4018) 426 : #endif 427 : 428 : // The overflow checks may involve float to int conversion which may 429 : // trigger precision loss warning. Re-enable the warning once the code 430 : // is fixed. See T58053069. 431 : C10_CLANG_DIAGNOSTIC_PUSH() 432 : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") 433 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") 434 : #endif 435 : 436 : // bool can be converted to any type. 437 : // Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: 438 : // `error: comparison of constant '255' with boolean expression is always false` 439 : // for `f > limit::max()` below 440 : template <typename To, typename From> 441 0 : typename std::enable_if<std::is_same<From, bool>::value, bool>::type overflows( 442 : From /*f*/) { 443 0 : return false; 444 : } 445 : 446 : // skip isnan and isinf check for integral types 447 : template <typename To, typename From> 448 : typename std::enable_if< 449 : std::is_integral<From>::value && !std::is_same<From, bool>::value, 450 : bool>::type 451 7072 : overflows(From f) { 452 : using limit = std::numeric_limits<typename scalar_value_type<To>::type>; 453 : if (!limit::is_signed && std::numeric_limits<From>::is_signed) { 454 : // allow for negative numbers to wrap using two's complement arithmetic. 455 : // For example, with uint8, this allows for `a - b` to be treated as 456 : // `a + 255 * b`. 457 : return greater_than_max<To>(f) || 458 : (c10::is_negative(f) && -static_cast<uint64_t>(f) > limit::max()); 459 : } else { 460 7072 : return c10::less_than_lowest<To>(f) || greater_than_max<To>(f); 461 : } 462 : } 463 : 464 : template <typename To, typename From> 465 : typename std::enable_if<std::is_floating_point<From>::value, bool>::type 466 0 : overflows(From f) { 467 : using limit = std::numeric_limits<typename scalar_value_type<To>::type>; 468 : if (limit::has_infinity && std::isinf(static_cast<double>(f))) { 469 : return false; 470 : } 471 0 : if (!limit::has_quiet_NaN && (f != f)) { 472 0 : return true; 473 : } 474 0 : return f < limit::lowest() || f > limit::max(); 475 : } 476 : 477 : C10_CLANG_DIAGNOSTIC_POP() 478 : 479 : #ifdef _MSC_VER 480 : #pragma warning(pop) 481 : #endif 482 : 483 : template <typename To, typename From> 484 0 : typename std::enable_if<is_complex<From>::value, bool>::type overflows(From f) { 485 : // casts from complex to real are considered to overflow if the 486 : // imaginary component is non-zero 487 0 : if (!is_complex<To>::value && f.imag() != 0) { 488 0 : return true; 489 : } 490 : // Check for overflow componentwise 491 : // (Technically, the imag overflow check is guaranteed to be false 492 : // when !is_complex<To>, but any optimizer worth its salt will be 493 : // able to figure it out.) 494 : return overflows< 495 : typename scalar_value_type<To>::type, 496 0 : typename From::value_type>(f.real()) || 497 : overflows< 498 : typename scalar_value_type<To>::type, 499 0 : typename From::value_type>(f.imag()); 500 : } 501 : 502 : C10_API std::ostream& operator<<(std::ostream& out, const Half& value); 503 : 504 : } // namespace c10 505 : 506 : #include <c10/util/Half-inl.h> // IWYU pragma: keep |
![]() |
Generated by: LCOV version 2.0-1 |
</html>