Line data Source code
1 : #pragma once
2 :
3 : /// Defines the Half type (half-precision floating-point) including conversions
4 : /// to standard C types and basic arithmetic operations. Note that arithmetic
5 : /// operations are implemented by converting to floating point and
6 : /// performing the operation in float32, instead of using CUDA half intrinsics.
7 : /// Most uses of this type within ATen are memory bound, including the
8 : /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
9 : /// If you are writing a compute bound kernel, you can use the CUDA half
10 : /// intrinsics directly on the Half type from device code.
11 :
12 : #include <c10/macros/Macros.h>
13 : #include <c10/util/C++17.h>
14 : #include <c10/util/TypeSafeSignMath.h>
15 : #include <c10/util/complex.h>
16 : #include <c10/util/floating_point_utils.h>
17 : #include <type_traits>
18 :
19 : #if defined(__cplusplus) && (__cplusplus >= 201103L)
20 : #include <cmath>
21 : #include <cstdint>
22 : #elif !defined(__OPENCL_VERSION__)
23 : #include <math.h>
24 : #include <stdint.h>
25 : #endif
26 :
27 : #ifdef _MSC_VER
28 : #include <intrin.h>
29 : #endif
30 :
31 : #include <complex>
32 : #include <cstdint>
33 : #include <cstring>
34 : #include <iosfwd>
35 : #include <limits>
36 : #include <sstream>
37 : #include <stdexcept>
38 : #include <string>
39 : #include <utility>
40 :
41 : #ifdef __CUDACC__
42 : #include <cuda_fp16.h>
43 : #endif
44 :
45 : #ifdef __HIPCC__
46 : #include <hip/hip_fp16.h>
47 : #endif
48 :
49 : #if defined(CL_SYCL_LANGUAGE_VERSION)
50 : #include <CL/sycl.hpp> // for SYCL 1.2.1
51 : #elif defined(SYCL_LANGUAGE_VERSION)
52 : #include <sycl/sycl.hpp> // for SYCL 2020
53 : #endif
54 :
55 : #include <typeinfo> // operator typeid
56 :
57 : namespace c10 {
58 :
59 : namespace detail {
60 :
/*
 * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format, in bit representation.
 *
 * @param h Bit pattern of the half-precision input.
 * @return Bit pattern of the equivalent single-precision number.
 *
 * @note The implementation doesn't use any floating-point operations.
 * @note The masks below rely on right shift of a negative int32_t being an
 * arithmetic (sign-replicating) shift.
 */
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
  /*
   * Extend the half-precision floating-point number to 32 bits and shift to
   * the upper part of the 32-bit word:
   *      +---+-----+------------+-------------------+
   *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa,
   * 0 - zero bits.
   */
  const uint32_t w = static_cast<uint32_t>(h) << 16;
  /*
   * Extract the sign of the input number into the high bit of the 32-bit
   * word:
   *      +---+----------------------------------+
   *      | S |0000000 00000000 00000000 00000000|
   *      +---+----------------------------------+
   * Bits  31                 0-30
   */
  const uint32_t sign = w & UINT32_C(0x80000000);
  /*
   * Extract mantissa and biased exponent of the input number into the bits
   * 0-30 of the 32-bit word (same positions as in w, sign bit cleared):
   *      +---+-----+------------+-------------------+
   *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   */
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  /*
   * Renorm shift is the number of bits to shift mantissa left to make the
   * half-precision number normalized. If the initial number is normalized,
   * some of its high 6 bits (sign == 0 and 5-bit exponent) equal one. In this
   * case renorm_shift == 0. If the number is denormalized, renorm_shift > 0.
   * Note that if we shift denormalized nonsign by renorm_shift, the unit bit
   * of mantissa will shift into exponent, turning the biased exponent into 1,
   * and making mantissa normalized (i.e. without leading 1).
   *
   * Both __builtin_clz and _BitScanReverse have an undefined result for a
   * zero argument, so the zero input is handled explicitly. The value chosen
   * (32 leading zeros) does not influence the result: zero_mask below clears
   * the exponent and mantissa for zero inputs.
   */
  uint32_t renorm_shift = 32;
  if (nonsign != 0) {
#ifdef _MSC_VER
    unsigned long nonsign_bsr;
    _BitScanReverse(&nonsign_bsr, static_cast<unsigned long>(nonsign));
    renorm_shift = static_cast<uint32_t>(nonsign_bsr) ^ 31;
#else
    renorm_shift = static_cast<uint32_t>(__builtin_clz(nonsign));
#endif
  }
  renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
  /*
   * Iff the half-precision number has a biased exponent of 0x1F (all five
   * exponent bits set), adding 0x04000000 (one unit of the exponent field,
   * bit 26) carries into bit 31, and the subsequent arithmetic shift turns
   * the high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the
   * half-precision number was NaN or infinity, and 0x00000000 otherwise.
   */
  const int32_t inf_nan_mask =
      (static_cast<int32_t>(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
  /*
   * Iff nonsign is 0, the subtraction wraps to 0xFFFFFFFF, turning bit 31
   * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
   * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
   * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h), and
   * 0x00000000 otherwise.
   */
  const int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
  /*
   * 1. Shift nonsign left by renorm_shift to normalize it (if the input
   *    was denormal).
   * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
   *    becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
   *    bits of the 23-bit mantissa of IEEE single-precision number.
   * 3. Add 0x70 to the exponent (starting at bit 23) to compensate for the
   *    difference in exponent bias (0x7F for single-precision number less
   *    0xF for half-precision number).
   * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
   *    account for renormalization. As renorm_shift is less than 0x70, this
   *    can be combined with step 3.
   * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
   *    input was NaN or infinity.
   * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
   *    into zero if the input was zero.
   * 7. Combine with the sign of the input number.
   */
  return sign |
      ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
        inf_nan_mask) &
       ~zero_mask);
}
156 :
157 : /*
158 : * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
159 : * representation, to a 32-bit floating-point number in IEEE single-precision
160 : * format.
161 : *
162 : * @note The implementation relies on IEEE-like (no assumption about rounding
163 : * mode and no operations on denormals) floating-point operations and bitcasts
164 : * between integer and floating-point variables.
165 : */
166 : C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
167 : /*
168 : * Extend the half-precision floating-point number to 32 bits and shift to the
169 : * upper part of the 32-bit word:
170 : * +---+-----+------------+-------------------+
171 : * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
172 : * +---+-----+------------+-------------------+
173 : * Bits 31 26-30 16-25 0-15
174 : *
175 : * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
176 : * - zero bits.
177 : */
178 : const uint32_t w = (uint32_t)h << 16;
179 : /*
180 : * Extract the sign of the input number into the high bit of the 32-bit word:
181 : *
182 : * +---+----------------------------------+
183 : * | S |0000000 00000000 00000000 00000000|
184 : * +---+----------------------------------+
185 : * Bits 31 0-31
186 : */
187 : const uint32_t sign = w & UINT32_C(0x80000000);
188 : /*
189 : * Extract mantissa and biased exponent of the input number into the high bits
190 : * of the 32-bit word:
191 : *
192 : * +-----+------------+---------------------+
193 : * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
194 : * +-----+------------+---------------------+
195 : * Bits 27-31 17-26 0-16
196 : */
197 : const uint32_t two_w = w + w;
198 :
199 : /*
200 : * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
201 : * mantissa and exponent of a single-precision floating-point number:
202 : *
203 : * S|Exponent | Mantissa
204 : * +-+---+-----+------------+----------------+
205 : * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
206 : * +-+---+-----+------------+----------------+
207 : * Bits | 23-31 | 0-22
208 : *
209 : * Next, there are some adjustments to the exponent:
210 : * - The exponent needs to be corrected by the difference in exponent bias
211 : * between single-precision and half-precision formats (0x7F - 0xF = 0x70)
212 : * - Inf and NaN values in the inputs should become Inf and NaN values after
213 : * conversion to the single-precision number. Therefore, if the biased
214 : * exponent of the half-precision input was 0x1F (max possible value), the
215 : * biased exponent of the single-precision output must be 0xFF (max possible
216 : * value). We do this correction in two steps:
217 : * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
218 : * below) rather than by 0x70 suggested by the difference in the exponent bias
219 : * (see above).
220 : * - Then we multiply the single-precision result of exponent adjustment by
221 : * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
222 : * necessary exponent adjustment by 0x70 due to difference in exponent bias.
223 : * The floating-point multiplication hardware would ensure than Inf and
224 : * NaN would retain their value on at least partially IEEE754-compliant
225 : * implementations.
226 : *
227 : * Note that the above operations do not handle denormal inputs (where biased
228 : * exponent == 0). However, they also do not operate on denormal inputs, and
229 : * do not produce denormal results.
230 : */
231 : constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
232 : // const float exp_scale = 0x1.0p-112f;
233 : constexpr uint32_t scale_bits = (uint32_t)15 << 23;
234 : float exp_scale_val;
235 : std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
236 : const float exp_scale = exp_scale_val;
237 : const float normalized_value =
238 : fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
239 :
240 : /*
241 : * Convert denormalized half-precision inputs into single-precision results
242 : * (always normalized). Zero inputs are also handled here.
243 : *
244 : * In a denormalized number the biased exponent is zero, and mantissa has
245 : * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
246 : *
247 : * zeros | mantissa
248 : * +---------------------------+------------+
249 : * |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
250 : * +---------------------------+------------+
251 : * Bits 10-31 0-9
252 : *
253 : * Now, remember that denormalized half-precision numbers are represented as:
254 : * FP16 = mantissa * 2**(-24).
255 : * The trick is to construct a normalized single-precision number with the
256 : * same mantissa and thehalf-precision input and with an exponent which would
257 : * scale the corresponding mantissa bits to 2**(-24). A normalized
258 : * single-precision floating-point number is represented as: FP32 = (1 +
259 : * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
260 : * exponent is 126, a unit change in the mantissa of the input denormalized
261 : * half-precision number causes a change of the constructed single-precision
262 : * number by 2**(-24), i.e. the same amount.
263 : *
264 : * The last step is to adjust the bias of the constructed single-precision
265 : * number. When the input half-precision number is zero, the constructed
266 : * single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
267 : * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
268 : * single-precision number to get the numerical equivalent of the input
269 : * half-precision number.
270 : */
271 : constexpr uint32_t magic_mask = UINT32_C(126) << 23;
272 : constexpr float magic_bias = 0.5f;
273 : const float denormalized_value =
274 : fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
275 :
276 : /*
277 : * - Choose either results of conversion of input as a normalized number, or
278 : * as a denormalized number, depending on the input exponent. The variable
279 : * two_w contains input exponent in bits 27-31, therefore if its smaller than
280 : * 2**27, the input is either a denormal number, or zero.
281 : * - Combine the result of conversion of exponent and mantissa with the sign
282 : * of the input number.
283 : */
284 : constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
285 : const uint32_t result = sign |
286 : (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
287 : : fp32_to_bits(normalized_value));
288 : return fp32_from_bits(result);
289 : }
290 :
291 : /*
292 : * Convert a 32-bit floating-point number in IEEE single-precision format to a
293 : * 16-bit floating-point number in IEEE half-precision format, in bit
294 : * representation.
295 : *
296 : * @note The implementation relies on IEEE-like (no assumption about rounding
297 : * mode and no operations on denormals) floating-point operations and bitcasts
298 : * between integer and floating-point variables.
299 : */
300 : inline uint16_t fp16_ieee_from_fp32_value(float f) {
301 : // const float scale_to_inf = 0x1.0p+112f;
302 : // const float scale_to_zero = 0x1.0p-110f;
303 : constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
304 : constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
305 : float scale_to_inf_val, scale_to_zero_val;
306 : std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
307 : std::memcpy(
308 : &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
309 : const float scale_to_inf = scale_to_inf_val;
310 : const float scale_to_zero = scale_to_zero_val;
311 :
312 : #if defined(_MSC_VER) && _MSC_VER == 1916
313 : float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
314 : #else
315 : float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
316 : #endif
317 :
318 : const uint32_t w = fp32_to_bits(f);
319 : const uint32_t shl1_w = w + w;
320 : const uint32_t sign = w & UINT32_C(0x80000000);
321 : uint32_t bias = shl1_w & UINT32_C(0xFF000000);
322 : if (bias < UINT32_C(0x71000000)) {
323 : bias = UINT32_C(0x71000000);
324 : }
325 :
326 : base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
327 : const uint32_t bits = fp32_to_bits(base);
328 : const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
329 : const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
330 : const uint32_t nonsign = exp_bits + mantissa_bits;
331 : return static_cast<uint16_t>(
332 : (sign >> 16) |
333 : (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
334 : }
335 :
336 : } // namespace detail
337 :
338 : struct alignas(2) Half {
339 : unsigned short x;
340 :
341 : struct from_bits_t {};
342 : C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
343 : return from_bits_t();
344 : }
345 :
346 : // HIP wants __host__ __device__ tag, CUDA does not
347 : #if defined(USE_ROCM)
348 : C10_HOST_DEVICE Half() = default;
349 : #else
350 : Half() = default;
351 : #endif
352 :
353 : constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){};
354 : inline C10_HOST_DEVICE Half(float value);
355 : inline C10_HOST_DEVICE operator float() const;
356 :
357 : #if defined(__CUDACC__) || defined(__HIPCC__)
358 : inline C10_HOST_DEVICE Half(const __half& value);
359 : inline C10_HOST_DEVICE operator __half() const;
360 : #endif
361 : #ifdef SYCL_LANGUAGE_VERSION
362 : inline C10_HOST_DEVICE Half(const sycl::half& value);
363 : inline C10_HOST_DEVICE operator sycl::half() const;
364 : #endif
365 : };
366 :
367 : // TODO : move to complex.h
368 : template <>
369 : struct alignas(4) complex<Half> {
370 : Half real_;
371 : Half imag_;
372 :
373 : // Constructors
374 : complex() = default;
375 : // Half constructor is not constexpr so the following constructor can't
376 : // be constexpr
377 : C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
378 : : real_(real), imag_(imag) {}
379 : C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
380 : : real_(value.real()), imag_(value.imag()) {}
381 :
382 : // Conversion operator
383 : inline C10_HOST_DEVICE operator c10::complex<float>() const {
384 : return {real_, imag_};
385 : }
386 :
387 : constexpr C10_HOST_DEVICE Half real() const {
388 : return real_;
389 : }
390 : constexpr C10_HOST_DEVICE Half imag() const {
391 : return imag_;
392 : }
393 :
394 : C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
395 : real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
396 : imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
397 : return *this;
398 : }
399 :
400 : C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
401 : real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
402 : imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
403 : return *this;
404 : }
405 :
406 : C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
407 : auto a = static_cast<float>(real_);
408 : auto b = static_cast<float>(imag_);
409 : auto c = static_cast<float>(other.real());
410 : auto d = static_cast<float>(other.imag());
411 : real_ = a * c - b * d;
412 : imag_ = a * d + b * c;
413 : return *this;
414 : }
415 : };
416 :
417 : // In some versions of MSVC, there will be a compiler error when building.
418 : // It can be addressed by disabling the following warnings:
419 : // C4146 (unary minus operator applied to unsigned type, result still unsigned),
420 : // C4804 (unsafe use of type 'bool' in operation) and C4018 (signed/unsigned mismatch).
421 : #ifdef _MSC_VER
422 : #pragma warning(push)
423 : #pragma warning(disable : 4146)
424 : #pragma warning(disable : 4804)
425 : #pragma warning(disable : 4018)
426 : #endif
427 :
428 : // The overflow checks may involve float to int conversion which may
429 : // trigger precision loss warning. Re-enable the warning once the code
430 : // is fixed. See T58053069.
431 : C10_CLANG_DIAGNOSTIC_PUSH()
432 : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
433 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
434 : #endif
435 :
// Overload for a bool source value: a bool is representable in every target
// type, so the conversion never overflows.
// The dedicated overload also keeps the generic range comparisons from being
// instantiated with bool, which some toolchains reject with
// `error: comparison of constant '255' with boolean expression is always false`
// (seen in pytorch_linux_trusty_py2_7_9_build for `f > limit::max()`).
template <typename To, typename From>
std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(From /*f*/) {
  return false;
}
445 :
446 : // skip isnan and isinf check for integral types
447 : template <typename To, typename From>
448 : typename std::enable_if<
449 : std::is_integral<From>::value && !std::is_same<From, bool>::value,
450 : bool>::type
451 7072 : overflows(From f) {
452 : using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
453 : if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
454 : // allow for negative numbers to wrap using two's complement arithmetic.
455 : // For example, with uint8, this allows for `a - b` to be treated as
456 : // `a + 255 * b`.
457 : return greater_than_max<To>(f) ||
458 : (c10::is_negative(f) && -static_cast<uint64_t>(f) > limit::max());
459 : } else {
460 7072 : return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
461 : }
462 : }
463 :
464 : template <typename To, typename From>
465 : typename std::enable_if<std::is_floating_point<From>::value, bool>::type
466 0 : overflows(From f) {
467 : using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
468 : if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
469 : return false;
470 : }
471 0 : if (!limit::has_quiet_NaN && (f != f)) {
472 0 : return true;
473 : }
474 0 : return f < limit::lowest() || f > limit::max();
475 : }
476 :
477 : C10_CLANG_DIAGNOSTIC_POP()
478 :
479 : #ifdef _MSC_VER
480 : #pragma warning(pop)
481 : #endif
482 :
483 : template <typename To, typename From>
484 0 : typename std::enable_if<is_complex<From>::value, bool>::type overflows(From f) {
485 : // casts from complex to real are considered to overflow if the
486 : // imaginary component is non-zero
487 0 : if (!is_complex<To>::value && f.imag() != 0) {
488 0 : return true;
489 : }
490 : // Check for overflow componentwise
491 : // (Technically, the imag overflow check is guaranteed to be false
492 : // when !is_complex<To>, but any optimizer worth its salt will be
493 : // able to figure it out.)
494 : return overflows<
495 : typename scalar_value_type<To>::type,
496 0 : typename From::value_type>(f.real()) ||
497 : overflows<
498 : typename scalar_value_type<To>::type,
499 0 : typename From::value_type>(f.imag());
500 : }
501 :
502 : C10_API std::ostream& operator<<(std::ostream& out, const Half& value);
503 :
504 : } // namespace c10
505 :
506 : #include <c10/util/Half-inl.h> // IWYU pragma: keep
|