68 KiB
68 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : #pragma once 2 : 3 : #include <complex> 4 : 5 : #include <c10/macros/Macros.h> 6 : 7 : #if defined(__CUDACC__) || defined(__HIPCC__) 8 : #include <thrust/complex.h> 9 : #endif 10 : 11 : C10_CLANG_DIAGNOSTIC_PUSH() 12 : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") 13 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") 14 : #endif 15 : #if C10_CLANG_HAS_WARNING("-Wfloat-conversion") 16 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") 17 : #endif 18 : 19 : namespace c10 { 20 : 21 : // c10::complex is an implementation of complex numbers that aims 22 : // to work on all devices supported by PyTorch 23 : // 24 : // Most of the APIs duplicates std::complex 25 : // Reference: https://en.cppreference.com/w/cpp/numeric/complex 26 : // 27 : // [NOTE: Complex Operator Unification] 28 : // Operators currently use a mix of std::complex, thrust::complex, and 29 : // c10::complex internally. The end state is that all operators will use 30 : // c10::complex internally. Until then, there may be some hacks to support all 31 : // variants. 32 : // 33 : // 34 : // [Note on Constructors] 35 : // 36 : // The APIs of constructors are mostly copied from C++ standard: 37 : // https://en.cppreference.com/w/cpp/numeric/complex/complex 38 : // 39 : // Since C++14, all constructors are constexpr in std::complex 40 : // 41 : // There are three types of constructors: 42 : // - initializing from real and imag: 43 : // `constexpr complex( const T& re = T(), const T& im = T() );` 44 : // - implicitly-declared copy constructor 45 : // - converting constructors 46 : // 47 : // Converting constructors: 48 : // - std::complex defines converting constructor between float/double/long 49 : // double, 50 : // while we define converting constructor between float/double. 51 : // - For these converting constructors, upcasting is implicit, downcasting is 52 : // explicit. 53 : // - We also define explicit casting from std::complex/thrust::complex 54 : // - Note that the conversion from thrust is not constexpr, because 55 : // thrust does not define them as constexpr ???? 56 : // 57 : // 58 : // [Operator =] 59 : // 60 : // The APIs of operator = are mostly copied from C++ standard: 61 : // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D 62 : // 63 : // Since C++20, all operator= are constexpr. Although we are not building with 64 : // C++20, we also obey this behavior. 65 : // 66 : // There are three types of assign operator: 67 : // - Assign a real value from the same scalar type 68 : // - In std, this is templated as complex& operator=(const T& x) 69 : // with specialization `complex& operator=(T x)` for float/double/long 70 : // double Since we only support float and double, on will use `complex& 71 : // operator=(T x)` 72 : // - Copy assignment operator and converting assignment operator 73 : // - There is no specialization of converting assignment operators, which type 74 : // is 75 : // convertible is solely dependent on whether the scalar type is convertible 76 : // 77 : // In addition to the standard assignment, we also provide assignment operators 78 : // with std and thrust 79 : // 80 : // 81 : // [Casting operators] 82 : // 83 : // std::complex does not have casting operators. We define casting operators 84 : // casting to std::complex and thrust::complex 85 : // 86 : // 87 : // [Operator ""] 88 : // 89 : // std::complex has custom literals `i`, `if` and `il` defined in namespace 90 : // `std::literals::complex_literals`. We define our own custom literals in the 91 : // namespace `c10::complex_literals`. Our custom literals does not follow the 92 : // same behavior as in std::complex, instead, we define _if, _id to construct 93 : // float/double complex literals. 94 : // 95 : // 96 : // [real() and imag()] 97 : // 98 : // In C++20, there are two overload of these functions, one it to return the 99 : // real/imag, another is to set real/imag, they are both constexpr. We follow 100 : // this design. 101 : // 102 : // 103 : // [Operator +=,-=,*=,/=] 104 : // 105 : // Since C++20, these operators become constexpr. In our implementation, they 106 : // are also constexpr. 107 : // 108 : // There are two types of such operators: operating with a real number, or 109 : // operating with another complex number. For the operating with a real number, 110 : // the generic template form has argument type `const T &`, while the overload 111 : // for float/double/long double has `T`. We will follow the same type as 112 : // float/double/long double in std. 113 : // 114 : // [Unary operator +-] 115 : // 116 : // Since C++20, they are constexpr. We also make them expr 117 : // 118 : // [Binary operators +-*/] 119 : // 120 : // Each operator has three versions (taking + as example): 121 : // - complex + complex 122 : // - complex + real 123 : // - real + complex 124 : // 125 : // [Operator ==, !=] 126 : // 127 : // Each operator has three versions (taking == as example): 128 : // - complex == complex 129 : // - complex == real 130 : // - real == complex 131 : // 132 : // Some of them are removed on C++20, but we decide to keep them 133 : // 134 : // [Operator <<, >>] 135 : // 136 : // These are implemented by casting to std::complex 137 : // 138 : // 139 : // 140 : // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported, 141 : // because: 142 : // - lots of members and functions of c10::Half are not constexpr 143 : // - thrust::complex only support float and double 144 : 145 : template <typename T> 146 : struct alignas(sizeof(T) * 2) complex { 147 : using value_type = T; 148 : 149 : T real_ = T(0); 150 : T imag_ = T(0); 151 : 152 : constexpr complex() = default; 153 : C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) 154 : : real_(re), imag_(im) {} 155 : template <typename U> 156 : explicit constexpr complex(const std::complex<U>& other) 157 : : complex(other.real(), other.imag()) {} 158 : #if defined(__CUDACC__) || defined(__HIPCC__) 159 : template <typename U> 160 : explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other) 161 : : real_(other.real()), imag_(other.imag()) {} 162 : // NOTE can not be implemented as follow due to ROCm bug: 163 : // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other): 164 : // complex(other.real(), other.imag()) {} 165 : #endif 166 : 167 : // Use SFINAE to specialize casting constructor for c10::complex<float> and 168 : // c10::complex<double> 169 : template <typename U = T> 170 : C10_HOST_DEVICE explicit constexpr complex( 171 : const std::enable_if_t<std::is_same<U, float>::value, complex<double>>& 172 : other) 173 : : real_(other.real_), imag_(other.imag_) {} 174 : template <typename U = T> 175 : C10_HOST_DEVICE constexpr complex( 176 : const std::enable_if_t<std::is_same<U, double>::value, complex<float>>& 177 : other) 178 : : real_(other.real_), imag_(other.imag_) {} 179 : 180 : constexpr complex<T>& operator=(T re) { 181 : real_ = re; 182 : imag_ = 0; 183 : return *this; 184 : } 185 : 186 : constexpr complex<T>& operator+=(T re) { 187 : real_ += re; 188 : return *this; 189 : } 190 : 191 : constexpr complex<T>& operator-=(T re) { 192 : real_ -= re; 193 : return *this; 194 : } 195 : 196 : constexpr complex<T>& operator*=(T re) { 197 : real_ *= re; 198 : imag_ *= re; 199 : return *this; 200 : } 201 : 202 : constexpr complex<T>& operator/=(T re) { 203 : real_ /= re; 204 : imag_ /= re; 205 : return *this; 206 : } 207 : 208 : template <typename U> 209 : constexpr complex<T>& operator=(const complex<U>& rhs) { 210 : real_ = rhs.real(); 211 : imag_ = rhs.imag(); 212 : return *this; 213 : } 214 : 215 : template <typename U> 216 : constexpr complex<T>& operator+=(const complex<U>& rhs) { 217 : real_ += rhs.real(); 218 : imag_ += rhs.imag(); 219 : return *this; 220 : } 221 : 222 : template <typename U> 223 : constexpr complex<T>& operator-=(const complex<U>& rhs) { 224 : real_ -= rhs.real(); 225 : imag_ -= rhs.imag(); 226 : return *this; 227 : } 228 : 229 : template <typename U> 230 : constexpr complex<T>& operator*=(const complex<U>& rhs) { 231 : // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i 232 : T a = real_; 233 : T b = imag_; 234 : U c = rhs.real(); 235 : U d = rhs.imag(); 236 : real_ = a * c - b * d; 237 : imag_ = a * d + b * c; 238 : return *this; 239 : } 240 : 241 : #ifdef __APPLE__ 242 : #define FORCE_INLINE_APPLE __attribute__((always_inline)) 243 : #else 244 : #define FORCE_INLINE_APPLE 245 : #endif 246 : template <typename U> 247 : constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs) 248 : __ubsan_ignore_float_divide_by_zero__ { 249 : // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i 250 : // the calculation below follows numpy's complex division 251 : T a = real_; 252 : T b = imag_; 253 : U c = rhs.real(); 254 : U d = rhs.imag(); 255 : 256 : #if defined(__GNUC__) && !defined(__clang__) 257 : // std::abs is already constexpr by gcc 258 : auto abs_c = std::abs(c); 259 : auto abs_d = std::abs(d); 260 : #else 261 : auto abs_c = c < 0 ? -c : c; 262 : auto abs_d = d < 0 ? -d : d; 263 : #endif 264 : 265 : if (abs_c >= abs_d) { 266 : if (abs_c == 0 && abs_d == 0) { 267 : /* divide by zeros should yield a complex inf or nan */ 268 : real_ = a / abs_c; 269 : imag_ = b / abs_d; 270 : } else { 271 : auto rat = d / c; 272 : auto scl = 1.0 / (c + d * rat); 273 : real_ = (a + b * rat) * scl; 274 : imag_ = (b - a * rat) * scl; 275 : } 276 : } else { 277 : auto rat = c / d; 278 : auto scl = 1.0 / (d + c * rat); 279 : real_ = (a * rat + b) * scl; 280 : imag_ = (b * rat - a) * scl; 281 : } 282 : return *this; 283 : } 284 : #undef FORCE_INLINE_APPLE 285 : 286 : template <typename U> 287 : constexpr complex<T>& operator=(const std::complex<U>& rhs) { 288 : real_ = rhs.real(); 289 : imag_ = rhs.imag(); 290 : return *this; 291 : } 292 : 293 : #if defined(__CUDACC__) || defined(__HIPCC__) 294 : template <typename U> 295 : C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) { 296 : real_ = rhs.real(); 297 : imag_ = rhs.imag(); 298 : return *this; 299 : } 300 : #endif 301 : 302 : template <typename U> 303 : explicit constexpr operator std::complex<U>() const { 304 : return std::complex<U>(std::complex<T>(real(), imag())); 305 : } 306 : 307 : #if defined(__CUDACC__) || defined(__HIPCC__) 308 : template <typename U> 309 : C10_HOST_DEVICE explicit operator thrust::complex<U>() const { 310 : return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag())); 311 : } 312 : #endif 313 : 314 : // consistent with NumPy behavior 315 : explicit constexpr operator bool() const { 316 : return real() || imag(); 317 : } 318 : 319 0 : C10_HOST_DEVICE constexpr T real() const { 320 0 : return real_; 321 : } 322 : constexpr void real(T value) { 323 : real_ = value; 324 : } 325 0 : constexpr T imag() const { 326 0 : return imag_; 327 : } 328 : constexpr void imag(T value) { 329 : imag_ = value; 330 : } 331 : }; 332 : 333 : namespace complex_literals { 334 : 335 : constexpr complex<float> operator"" _if(long double imag) { 336 : return complex<float>(0.0f, static_cast<float>(imag)); 337 : } 338 : 339 : constexpr complex<double> operator"" _id(long double imag) { 340 : return complex<double>(0.0, static_cast<double>(imag)); 341 : } 342 : 343 : constexpr complex<float> operator"" _if(unsigned long long imag) { 344 : return complex<float>(0.0f, static_cast<float>(imag)); 345 : } 346 : 347 : constexpr complex<double> operator"" _id(unsigned long long imag) { 348 : return complex<double>(0.0, static_cast<double>(imag)); 349 : } 350 : 351 : } // namespace complex_literals 352 : 353 : template <typename T> 354 : constexpr complex<T> operator+(const complex<T>& val) { 355 : return val; 356 : } 357 : 358 : template <typename T> 359 : constexpr complex<T> operator-(const complex<T>& val) { 360 : return complex<T>(-val.real(), -val.imag()); 361 : } 362 : 363 : template <typename T> 364 : constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) { 365 : complex<T> result = lhs; 366 : return result += rhs; 367 : } 368 : 369 : template <typename T> 370 : constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) { 371 : complex<T> result = lhs; 372 : return result += rhs; 373 : } 374 : 375 : template <typename T> 376 : constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) { 377 : return complex<T>(lhs + rhs.real(), rhs.imag()); 378 : } 379 : 380 : template <typename T> 381 : constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) { 382 : complex<T> result = lhs; 383 : return result -= rhs; 384 : } 385 : 386 : template <typename T> 387 : constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) { 388 : complex<T> result = lhs; 389 : return result -= rhs; 390 : } 391 : 392 : template <typename T> 393 : constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) { 394 : complex<T> result = -rhs; 395 : return result += lhs; 396 : } 397 : 398 : template <typename T> 399 : constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) { 400 : complex<T> result = lhs; 401 : return result *= rhs; 402 : } 403 : 404 : template <typename T> 405 : constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) { 406 : complex<T> result = lhs; 407 : return result *= rhs; 408 : } 409 : 410 : template <typename T> 411 : constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) { 412 : complex<T> result = rhs; 413 : return result *= lhs; 414 : } 415 : 416 : template <typename T> 417 : constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) { 418 : complex<T> result = lhs; 419 : return result /= rhs; 420 : } 421 : 422 : template <typename T> 423 : constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) { 424 : complex<T> result = lhs; 425 : return result /= rhs; 426 : } 427 : 428 : template <typename T> 429 : constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) { 430 : complex<T> result(lhs, T()); 431 : return result /= rhs; 432 : } 433 : 434 : // Define operators between integral scalars and c10::complex. std::complex does 435 : // not support this when T is a floating-point number. This is useful because it 436 : // saves a lot of "static_cast" when operate a complex and an integer. This 437 : // makes the code both less verbose and potentially more efficient. 438 : #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ 439 : typename std::enable_if_t< \ 440 : std::is_floating_point<fT>::value && std::is_integral<iT>::value, \ 441 : int> = 0 442 : 443 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 444 : constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) { 445 : return a + static_cast<fT>(b); 446 : } 447 : 448 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 449 : constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) { 450 : return static_cast<fT>(a) + b; 451 : } 452 : 453 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 454 : constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) { 455 : return a - static_cast<fT>(b); 456 : } 457 : 458 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 459 : constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) { 460 : return static_cast<fT>(a) - b; 461 : } 462 : 463 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 464 : constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) { 465 : return a * static_cast<fT>(b); 466 : } 467 : 468 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 469 : constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) { 470 : return static_cast<fT>(a) * b; 471 : } 472 : 473 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 474 : constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) { 475 : return a / static_cast<fT>(b); 476 : } 477 : 478 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> 479 : constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) { 480 : return static_cast<fT>(a) / b; 481 : } 482 : 483 : #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION 484 : 485 : template <typename T> 486 : constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) { 487 : return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); 488 : } 489 : 490 : template <typename T> 491 : constexpr bool operator==(const complex<T>& lhs, const T& rhs) { 492 : return (lhs.real() == rhs) && (lhs.imag() == T()); 493 : } 494 : 495 : template <typename T> 496 : constexpr bool operator==(const T& lhs, const complex<T>& rhs) { 497 : return (lhs == rhs.real()) && (T() == rhs.imag()); 498 : } 499 : 500 : template <typename T> 501 : constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) { 502 : return !(lhs == rhs); 503 : } 504 : 505 : template <typename T> 506 : constexpr bool operator!=(const complex<T>& lhs, const T& rhs) { 507 : return !(lhs == rhs); 508 : } 509 : 510 : template <typename T> 511 : constexpr bool operator!=(const T& lhs, const complex<T>& rhs) { 512 : return !(lhs == rhs); 513 : } 514 : 515 : template <typename T, typename CharT, typename Traits> 516 : std::basic_ostream<CharT, Traits>& operator<<( 517 : std::basic_ostream<CharT, Traits>& os, 518 : const complex<T>& x) { 519 : return (os << static_cast<std::complex<T>>(x)); 520 : } 521 : 522 : template <typename T, typename CharT, typename Traits> 523 : std::basic_istream<CharT, Traits>& operator>>( 524 : std::basic_istream<CharT, Traits>& is, 525 : complex<T>& x) { 526 : std::complex<T> tmp; 527 : is >> tmp; 528 : x = tmp; 529 : return is; 530 : } 531 : 532 : } // namespace c10 533 : 534 : // std functions 535 : // 536 : // The implementation of these functions also follow the design of C++20 537 : 538 : namespace std { 539 : 540 : template <typename T> 541 : constexpr T real(const c10::complex<T>& z) { 542 : return z.real(); 543 : } 544 : 545 : template <typename T> 546 : constexpr T imag(const c10::complex<T>& z) { 547 : return z.imag(); 548 : } 549 : 550 : template <typename T> 551 : C10_HOST_DEVICE T abs(const c10::complex<T>& z) { 552 : #if defined(__CUDACC__) || defined(__HIPCC__) 553 : return thrust::abs(static_cast<thrust::complex<T>>(z)); 554 : #else 555 : return std::abs(static_cast<std::complex<T>>(z)); 556 : #endif 557 : } 558 : 559 : #if defined(USE_ROCM) 560 : #define ROCm_Bug(x) 561 : #else 562 : #define ROCm_Bug(x) x 563 : #endif 564 : 565 : template <typename T> 566 : C10_HOST_DEVICE T arg(const c10::complex<T>& z) { 567 : return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); 568 : } 569 : 570 : #undef ROCm_Bug 571 : 572 : template <typename T> 573 : constexpr T norm(const c10::complex<T>& z) { 574 : return z.real() * z.real() + z.imag() * z.imag(); 575 : } 576 : 577 : // For std::conj, there are other versions of it: 578 : // constexpr std::complex<float> conj( float z ); 579 : // template< class DoubleOrInteger > 580 : // constexpr std::complex<double> conj( DoubleOrInteger z ); 581 : // constexpr std::complex<long double> conj( long double z ); 582 : // These are not implemented 583 : // TODO(@zasdfgbnm): implement them as c10::conj 584 : template <typename T> 585 : constexpr c10::complex<T> conj(const c10::complex<T>& z) { 586 : return c10::complex<T>(z.real(), -z.imag()); 587 : } 588 : 589 : // Thrust does not have complex --> complex version of thrust::proj, 590 : // so this function is not implemented at c10 right now. 591 : // TODO(@zasdfgbnm): implement it by ourselves 592 : 593 : // There is no c10 version of std::polar, because std::polar always 594 : // returns std::complex. Use c10::polar instead; 595 : 596 : } // namespace std 597 : 598 : namespace c10 { 599 : 600 : template <typename T> 601 : C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) { 602 : #if defined(__CUDACC__) || defined(__HIPCC__) 603 : return static_cast<complex<T>>(thrust::polar(r, theta)); 604 : #else 605 : // std::polar() requires r >= 0, so spell out the explicit implementation to 606 : // avoid a branch. 607 : return complex<T>(r * std::cos(theta), r * std::sin(theta)); 608 : #endif 609 : } 610 : 611 : } // namespace c10 612 : 613 : C10_CLANG_DIAGNOSTIC_POP() 614 : 615 : #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H 616 : // math functions are included in a separate file 617 : #include <c10/util/complex_math.h> // IWYU pragma: keep 618 : // utilities for complex types 619 : #include <c10/util/complex_utils.h> // IWYU pragma: keep 620 : #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H |
![]() |
Generated by: LCOV version 2.0-1 |
</html>