LCOV - code coverage report
Current view: top level - libtorch/include/c10/util - complex.h (source / functions) Coverage Total Hit
Test: coverage.info Lines: 0.0 % 4 0
Test Date: 2024-04-30 13:17:26 Functions: 0.0 % 2 0

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include <complex>
       4              : 
       5              : #include <c10/macros/Macros.h>
       6              : 
       7              : #if defined(__CUDACC__) || defined(__HIPCC__)
       8              : #include <thrust/complex.h>
       9              : #endif
      10              : 
      11              : C10_CLANG_DIAGNOSTIC_PUSH()
      12              : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
      13              : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
      14              : #endif
      15              : #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
      16              : C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
      17              : #endif
      18              : 
      19              : namespace c10 {
      20              : 
      21              : // c10::complex is an implementation of complex numbers that aims
      22              : // to work on all devices supported by PyTorch
      23              : //
      24              : // Most of the APIs duplicates std::complex
      25              : // Reference: https://en.cppreference.com/w/cpp/numeric/complex
      26              : //
      27              : // [NOTE: Complex Operator Unification]
      28              : // Operators currently use a mix of std::complex, thrust::complex, and
      29              : // c10::complex internally. The end state is that all operators will use
      30              : // c10::complex internally.  Until then, there may be some hacks to support all
      31              : // variants.
      32              : //
      33              : //
      34              : // [Note on Constructors]
      35              : //
      36              : // The APIs of constructors are mostly copied from C++ standard:
      37              : //   https://en.cppreference.com/w/cpp/numeric/complex/complex
      38              : //
      39              : // Since C++14, all constructors are constexpr in std::complex
      40              : //
      41              : // There are three types of constructors:
      42              : // - initializing from real and imag:
      43              : //     `constexpr complex( const T& re = T(), const T& im = T() );`
      44              : // - implicitly-declared copy constructor
      45              : // - converting constructors
      46              : //
      47              : // Converting constructors:
      48              : // - std::complex defines converting constructor between float/double/long
      49              : // double,
      50              : //   while we define converting constructor between float/double.
      51              : // - For these converting constructors, upcasting is implicit, downcasting is
      52              : //   explicit.
      53              : // - We also define explicit casting from std::complex/thrust::complex
      54              : //   - Note that the conversion from thrust is not constexpr, because
      55              : //     thrust does not define them as constexpr ????
      56              : //
      57              : //
      58              : // [Operator =]
      59              : //
      60              : // The APIs of operator = are mostly copied from C++ standard:
      61              : //   https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
      62              : //
      63              : // Since C++20, all operator= are constexpr. Although we are not building with
      64              : // C++20, we also obey this behavior.
      65              : //
      66              : // There are three types of assign operator:
      67              : // - Assign a real value from the same scalar type
      68              : //   - In std, this is templated as complex& operator=(const T& x)
      69              : //     with specialization `complex& operator=(T x)` for float/double/long
      70              : //     double Since we only support float and double, on will use `complex&
      71              : //     operator=(T x)`
      72              : // - Copy assignment operator and converting assignment operator
      73              : //   - There is no specialization of converting assignment operators, which type
      74              : //   is
      75              : //     convertible is solely dependent on whether the scalar type is convertible
      76              : //
      77              : // In addition to the standard assignment, we also provide assignment operators
      78              : // with std and thrust
      79              : //
      80              : //
      81              : // [Casting operators]
      82              : //
      83              : // std::complex does not have casting operators. We define casting operators
      84              : // casting to std::complex and thrust::complex
      85              : //
      86              : //
      87              : // [Operator ""]
      88              : //
      89              : // std::complex has custom literals `i`, `if` and `il` defined in namespace
      90              : // `std::literals::complex_literals`. We define our own custom literals in the
      91              : // namespace `c10::complex_literals`. Our custom literals does not follow the
      92              : // same behavior as in std::complex, instead, we define _if, _id to construct
      93              : // float/double complex literals.
      94              : //
      95              : //
      96              : // [real() and imag()]
      97              : //
      98              : // In C++20, there are two overload of these functions, one it to return the
      99              : // real/imag, another is to set real/imag, they are both constexpr. We follow
     100              : // this design.
     101              : //
     102              : //
     103              : // [Operator +=,-=,*=,/=]
     104              : //
     105              : // Since C++20, these operators become constexpr. In our implementation, they
     106              : // are also constexpr.
     107              : //
     108              : // There are two types of such operators: operating with a real number, or
     109              : // operating with another complex number. For the operating with a real number,
     110              : // the generic template form has argument type `const T &`, while the overload
     111              : // for float/double/long double has `T`. We will follow the same type as
     112              : // float/double/long double in std.
     113              : //
     114              : // [Unary operator +-]
     115              : //
     116              : // Since C++20, they are constexpr. We also make them expr
     117              : //
     118              : // [Binary operators +-*/]
     119              : //
     120              : // Each operator has three versions (taking + as example):
     121              : // - complex + complex
     122              : // - complex + real
     123              : // - real + complex
     124              : //
     125              : // [Operator ==, !=]
     126              : //
     127              : // Each operator has three versions (taking == as example):
     128              : // - complex == complex
     129              : // - complex == real
     130              : // - real == complex
     131              : //
     132              : // Some of them are removed on C++20, but we decide to keep them
     133              : //
     134              : // [Operator <<, >>]
     135              : //
     136              : // These are implemented by casting to std::complex
     137              : //
     138              : //
     139              : //
     140              : // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
     141              : // because:
     142              : //  - lots of members and functions of c10::Half are not constexpr
     143              : //  - thrust::complex only support float and double
     144              : 
     145              : template <typename T>
     146              : struct alignas(sizeof(T) * 2) complex {
     147              :   using value_type = T;
     148              : 
     149              :   T real_ = T(0);
     150              :   T imag_ = T(0);
     151              : 
     152              :   constexpr complex() = default;
     153              :   C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
     154              :       : real_(re), imag_(im) {}
     155              :   template <typename U>
     156              :   explicit constexpr complex(const std::complex<U>& other)
     157              :       : complex(other.real(), other.imag()) {}
     158              : #if defined(__CUDACC__) || defined(__HIPCC__)
     159              :   template <typename U>
     160              :   explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
     161              :       : real_(other.real()), imag_(other.imag()) {}
     162              : // NOTE can not be implemented as follow due to ROCm bug:
     163              : //   explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
     164              : //   complex(other.real(), other.imag()) {}
     165              : #endif
     166              : 
     167              :   // Use SFINAE to specialize casting constructor for c10::complex<float> and
     168              :   // c10::complex<double>
     169              :   template <typename U = T>
     170              :   C10_HOST_DEVICE explicit constexpr complex(
     171              :       const std::enable_if_t<std::is_same<U, float>::value, complex<double>>&
     172              :           other)
     173              :       : real_(other.real_), imag_(other.imag_) {}
     174              :   template <typename U = T>
     175              :   C10_HOST_DEVICE constexpr complex(
     176              :       const std::enable_if_t<std::is_same<U, double>::value, complex<float>>&
     177              :           other)
     178              :       : real_(other.real_), imag_(other.imag_) {}
     179              : 
     180              :   constexpr complex<T>& operator=(T re) {
     181              :     real_ = re;
     182              :     imag_ = 0;
     183              :     return *this;
     184              :   }
     185              : 
     186              :   constexpr complex<T>& operator+=(T re) {
     187              :     real_ += re;
     188              :     return *this;
     189              :   }
     190              : 
     191              :   constexpr complex<T>& operator-=(T re) {
     192              :     real_ -= re;
     193              :     return *this;
     194              :   }
     195              : 
     196              :   constexpr complex<T>& operator*=(T re) {
     197              :     real_ *= re;
     198              :     imag_ *= re;
     199              :     return *this;
     200              :   }
     201              : 
     202              :   constexpr complex<T>& operator/=(T re) {
     203              :     real_ /= re;
     204              :     imag_ /= re;
     205              :     return *this;
     206              :   }
     207              : 
     208              :   template <typename U>
     209              :   constexpr complex<T>& operator=(const complex<U>& rhs) {
     210              :     real_ = rhs.real();
     211              :     imag_ = rhs.imag();
     212              :     return *this;
     213              :   }
     214              : 
     215              :   template <typename U>
     216              :   constexpr complex<T>& operator+=(const complex<U>& rhs) {
     217              :     real_ += rhs.real();
     218              :     imag_ += rhs.imag();
     219              :     return *this;
     220              :   }
     221              : 
     222              :   template <typename U>
     223              :   constexpr complex<T>& operator-=(const complex<U>& rhs) {
     224              :     real_ -= rhs.real();
     225              :     imag_ -= rhs.imag();
     226              :     return *this;
     227              :   }
     228              : 
     229              :   template <typename U>
     230              :   constexpr complex<T>& operator*=(const complex<U>& rhs) {
     231              :     // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
     232              :     T a = real_;
     233              :     T b = imag_;
     234              :     U c = rhs.real();
     235              :     U d = rhs.imag();
     236              :     real_ = a * c - b * d;
     237              :     imag_ = a * d + b * c;
     238              :     return *this;
     239              :   }
     240              : 
     241              : #ifdef __APPLE__
     242              : #define FORCE_INLINE_APPLE __attribute__((always_inline))
     243              : #else
     244              : #define FORCE_INLINE_APPLE
     245              : #endif
     246              :   template <typename U>
     247              :   constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
     248              :       __ubsan_ignore_float_divide_by_zero__ {
     249              :     // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
     250              :     // the calculation below follows numpy's complex division
     251              :     T a = real_;
     252              :     T b = imag_;
     253              :     U c = rhs.real();
     254              :     U d = rhs.imag();
     255              : 
     256              : #if defined(__GNUC__) && !defined(__clang__)
     257              :     // std::abs is already constexpr by gcc
     258              :     auto abs_c = std::abs(c);
     259              :     auto abs_d = std::abs(d);
     260              : #else
     261              :     auto abs_c = c < 0 ? -c : c;
     262              :     auto abs_d = d < 0 ? -d : d;
     263              : #endif
     264              : 
     265              :     if (abs_c >= abs_d) {
     266              :       if (abs_c == 0 && abs_d == 0) {
     267              :         /* divide by zeros should yield a complex inf or nan */
     268              :         real_ = a / abs_c;
     269              :         imag_ = b / abs_d;
     270              :       } else {
     271              :         auto rat = d / c;
     272              :         auto scl = 1.0 / (c + d * rat);
     273              :         real_ = (a + b * rat) * scl;
     274              :         imag_ = (b - a * rat) * scl;
     275              :       }
     276              :     } else {
     277              :       auto rat = c / d;
     278              :       auto scl = 1.0 / (d + c * rat);
     279              :       real_ = (a * rat + b) * scl;
     280              :       imag_ = (b * rat - a) * scl;
     281              :     }
     282              :     return *this;
     283              :   }
     284              : #undef FORCE_INLINE_APPLE
     285              : 
     286              :   template <typename U>
     287              :   constexpr complex<T>& operator=(const std::complex<U>& rhs) {
     288              :     real_ = rhs.real();
     289              :     imag_ = rhs.imag();
     290              :     return *this;
     291              :   }
     292              : 
     293              : #if defined(__CUDACC__) || defined(__HIPCC__)
     294              :   template <typename U>
     295              :   C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
     296              :     real_ = rhs.real();
     297              :     imag_ = rhs.imag();
     298              :     return *this;
     299              :   }
     300              : #endif
     301              : 
     302              :   template <typename U>
     303              :   explicit constexpr operator std::complex<U>() const {
     304              :     return std::complex<U>(std::complex<T>(real(), imag()));
     305              :   }
     306              : 
     307              : #if defined(__CUDACC__) || defined(__HIPCC__)
     308              :   template <typename U>
     309              :   C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
     310              :     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
     311              :   }
     312              : #endif
     313              : 
     314              :   // consistent with NumPy behavior
     315              :   explicit constexpr operator bool() const {
     316              :     return real() || imag();
     317              :   }
     318              : 
     319            0 :   C10_HOST_DEVICE constexpr T real() const {
     320            0 :     return real_;
     321              :   }
     322              :   constexpr void real(T value) {
     323              :     real_ = value;
     324              :   }
     325            0 :   constexpr T imag() const {
     326            0 :     return imag_;
     327              :   }
     328              :   constexpr void imag(T value) {
     329              :     imag_ = value;
     330              :   }
     331              : };
     332              : 
     333              : namespace complex_literals {
     334              : 
     335              : constexpr complex<float> operator"" _if(long double imag) {
     336              :   return complex<float>(0.0f, static_cast<float>(imag));
     337              : }
     338              : 
     339              : constexpr complex<double> operator"" _id(long double imag) {
     340              :   return complex<double>(0.0, static_cast<double>(imag));
     341              : }
     342              : 
     343              : constexpr complex<float> operator"" _if(unsigned long long imag) {
     344              :   return complex<float>(0.0f, static_cast<float>(imag));
     345              : }
     346              : 
     347              : constexpr complex<double> operator"" _id(unsigned long long imag) {
     348              :   return complex<double>(0.0, static_cast<double>(imag));
     349              : }
     350              : 
     351              : } // namespace complex_literals
     352              : 
     353              : template <typename T>
     354              : constexpr complex<T> operator+(const complex<T>& val) {
     355              :   return val;
     356              : }
     357              : 
     358              : template <typename T>
     359              : constexpr complex<T> operator-(const complex<T>& val) {
     360              :   return complex<T>(-val.real(), -val.imag());
     361              : }
     362              : 
     363              : template <typename T>
     364              : constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
     365              :   complex<T> result = lhs;
     366              :   return result += rhs;
     367              : }
     368              : 
     369              : template <typename T>
     370              : constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
     371              :   complex<T> result = lhs;
     372              :   return result += rhs;
     373              : }
     374              : 
     375              : template <typename T>
     376              : constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
     377              :   return complex<T>(lhs + rhs.real(), rhs.imag());
     378              : }
     379              : 
     380              : template <typename T>
     381              : constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
     382              :   complex<T> result = lhs;
     383              :   return result -= rhs;
     384              : }
     385              : 
     386              : template <typename T>
     387              : constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
     388              :   complex<T> result = lhs;
     389              :   return result -= rhs;
     390              : }
     391              : 
     392              : template <typename T>
     393              : constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
     394              :   complex<T> result = -rhs;
     395              :   return result += lhs;
     396              : }
     397              : 
     398              : template <typename T>
     399              : constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
     400              :   complex<T> result = lhs;
     401              :   return result *= rhs;
     402              : }
     403              : 
     404              : template <typename T>
     405              : constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
     406              :   complex<T> result = lhs;
     407              :   return result *= rhs;
     408              : }
     409              : 
     410              : template <typename T>
     411              : constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
     412              :   complex<T> result = rhs;
     413              :   return result *= lhs;
     414              : }
     415              : 
     416              : template <typename T>
     417              : constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
     418              :   complex<T> result = lhs;
     419              :   return result /= rhs;
     420              : }
     421              : 
     422              : template <typename T>
     423              : constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
     424              :   complex<T> result = lhs;
     425              :   return result /= rhs;
     426              : }
     427              : 
     428              : template <typename T>
     429              : constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
     430              :   complex<T> result(lhs, T());
     431              :   return result /= rhs;
     432              : }
     433              : 
     434              : // Define operators between integral scalars and c10::complex. std::complex does
     435              : // not support this when T is a floating-point number. This is useful because it
     436              : // saves a lot of "static_cast" when operate a complex and an integer. This
     437              : // makes the code both less verbose and potentially more efficient.
     438              : #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION                           \
     439              :   typename std::enable_if_t<                                            \
     440              :       std::is_floating_point<fT>::value && std::is_integral<iT>::value, \
     441              :       int> = 0
     442              : 
     443              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     444              : constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
     445              :   return a + static_cast<fT>(b);
     446              : }
     447              : 
     448              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     449              : constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
     450              :   return static_cast<fT>(a) + b;
     451              : }
     452              : 
     453              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     454              : constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
     455              :   return a - static_cast<fT>(b);
     456              : }
     457              : 
     458              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     459              : constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
     460              :   return static_cast<fT>(a) - b;
     461              : }
     462              : 
     463              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     464              : constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
     465              :   return a * static_cast<fT>(b);
     466              : }
     467              : 
     468              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     469              : constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
     470              :   return static_cast<fT>(a) * b;
     471              : }
     472              : 
     473              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     474              : constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
     475              :   return a / static_cast<fT>(b);
     476              : }
     477              : 
     478              : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     479              : constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
     480              :   return static_cast<fT>(a) / b;
     481              : }
     482              : 
     483              : #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
     484              : 
     485              : template <typename T>
     486              : constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
     487              :   return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
     488              : }
     489              : 
     490              : template <typename T>
     491              : constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
     492              :   return (lhs.real() == rhs) && (lhs.imag() == T());
     493              : }
     494              : 
     495              : template <typename T>
     496              : constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
     497              :   return (lhs == rhs.real()) && (T() == rhs.imag());
     498              : }
     499              : 
     500              : template <typename T>
     501              : constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
     502              :   return !(lhs == rhs);
     503              : }
     504              : 
     505              : template <typename T>
     506              : constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
     507              :   return !(lhs == rhs);
     508              : }
     509              : 
     510              : template <typename T>
     511              : constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
     512              :   return !(lhs == rhs);
     513              : }
     514              : 
     515              : template <typename T, typename CharT, typename Traits>
     516              : std::basic_ostream<CharT, Traits>& operator<<(
     517              :     std::basic_ostream<CharT, Traits>& os,
     518              :     const complex<T>& x) {
     519              :   return (os << static_cast<std::complex<T>>(x));
     520              : }
     521              : 
     522              : template <typename T, typename CharT, typename Traits>
     523              : std::basic_istream<CharT, Traits>& operator>>(
     524              :     std::basic_istream<CharT, Traits>& is,
     525              :     complex<T>& x) {
     526              :   std::complex<T> tmp;
     527              :   is >> tmp;
     528              :   x = tmp;
     529              :   return is;
     530              : }
     531              : 
     532              : } // namespace c10
     533              : 
     534              : // std functions
     535              : //
     536              : // The implementation of these functions also follow the design of C++20
     537              : 
     538              : namespace std {
     539              : 
     540              : template <typename T>
     541              : constexpr T real(const c10::complex<T>& z) {
     542              :   return z.real();
     543              : }
     544              : 
     545              : template <typename T>
     546              : constexpr T imag(const c10::complex<T>& z) {
     547              :   return z.imag();
     548              : }
     549              : 
     550              : template <typename T>
     551              : C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
     552              : #if defined(__CUDACC__) || defined(__HIPCC__)
     553              :   return thrust::abs(static_cast<thrust::complex<T>>(z));
     554              : #else
     555              :   return std::abs(static_cast<std::complex<T>>(z));
     556              : #endif
     557              : }
     558              : 
     559              : #if defined(USE_ROCM)
     560              : #define ROCm_Bug(x)
     561              : #else
     562              : #define ROCm_Bug(x) x
     563              : #endif
     564              : 
     565              : template <typename T>
     566              : C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
     567              :   return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
     568              : }
     569              : 
     570              : #undef ROCm_Bug
     571              : 
     572              : template <typename T>
     573              : constexpr T norm(const c10::complex<T>& z) {
     574              :   return z.real() * z.real() + z.imag() * z.imag();
     575              : }
     576              : 
     577              : // For std::conj, there are other versions of it:
     578              : //   constexpr std::complex<float> conj( float z );
     579              : //   template< class DoubleOrInteger >
     580              : //   constexpr std::complex<double> conj( DoubleOrInteger z );
     581              : //   constexpr std::complex<long double> conj( long double z );
     582              : // These are not implemented
     583              : // TODO(@zasdfgbnm): implement them as c10::conj
     584              : template <typename T>
     585              : constexpr c10::complex<T> conj(const c10::complex<T>& z) {
     586              :   return c10::complex<T>(z.real(), -z.imag());
     587              : }
     588              : 
     589              : // Thrust does not have complex --> complex version of thrust::proj,
     590              : // so this function is not implemented at c10 right now.
     591              : // TODO(@zasdfgbnm): implement it by ourselves
     592              : 
     593              : // There is no c10 version of std::polar, because std::polar always
     594              : // returns std::complex. Use c10::polar instead;
     595              : 
     596              : } // namespace std
     597              : 
     598              : namespace c10 {
     599              : 
     600              : template <typename T>
     601              : C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
     602              : #if defined(__CUDACC__) || defined(__HIPCC__)
     603              :   return static_cast<complex<T>>(thrust::polar(r, theta));
     604              : #else
     605              :   // std::polar() requires r >= 0, so spell out the explicit implementation to
     606              :   // avoid a branch.
     607              :   return complex<T>(r * std::cos(theta), r * std::sin(theta));
     608              : #endif
     609              : }
     610              : 
     611              : } // namespace c10
     612              : 
     613              : C10_CLANG_DIAGNOSTIC_POP()
     614              : 
     615              : #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
     616              : // math functions are included in a separate file
     617              : #include <c10/util/complex_math.h> // IWYU pragma: keep
     618              : // utilities for complex types
     619              : #include <c10/util/complex_utils.h> // IWYU pragma: keep
     620              : #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
        

Generated by: LCOV version 2.0-1