Line data Source code
1 : #pragma once
2 :
3 : #include <complex>
4 :
5 : #include <c10/macros/Macros.h>
6 :
7 : #if defined(__CUDACC__) || defined(__HIPCC__)
8 : #include <thrust/complex.h>
9 : #endif
10 :
11 : C10_CLANG_DIAGNOSTIC_PUSH()
12 : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
13 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
14 : #endif
15 : #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
16 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
17 : #endif
18 :
19 : namespace c10 {
20 :
21 : // c10::complex is an implementation of complex numbers that aims
22 : // to work on all devices supported by PyTorch
23 : //
// Most of the APIs duplicate std::complex
25 : // Reference: https://en.cppreference.com/w/cpp/numeric/complex
26 : //
27 : // [NOTE: Complex Operator Unification]
28 : // Operators currently use a mix of std::complex, thrust::complex, and
29 : // c10::complex internally. The end state is that all operators will use
30 : // c10::complex internally. Until then, there may be some hacks to support all
31 : // variants.
32 : //
33 : //
34 : // [Note on Constructors]
35 : //
36 : // The APIs of constructors are mostly copied from C++ standard:
37 : // https://en.cppreference.com/w/cpp/numeric/complex/complex
38 : //
39 : // Since C++14, all constructors are constexpr in std::complex
40 : //
41 : // There are three types of constructors:
42 : // - initializing from real and imag:
43 : // `constexpr complex( const T& re = T(), const T& im = T() );`
44 : // - implicitly-declared copy constructor
45 : // - converting constructors
46 : //
47 : // Converting constructors:
// - std::complex defines converting constructors between float/double/long
//   double, while we define converting constructors between float/double.
51 : // - For these converting constructors, upcasting is implicit, downcasting is
52 : // explicit.
53 : // - We also define explicit casting from std::complex/thrust::complex
//   - Note that the conversion from thrust is not constexpr, presumably
//     because thrust does not define the relevant members as constexpr
//     (TODO: confirm).
56 : //
57 : //
58 : // [Operator =]
59 : //
60 : // The APIs of operator = are mostly copied from C++ standard:
61 : // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
62 : //
63 : // Since C++20, all operator= are constexpr. Although we are not building with
64 : // C++20, we also obey this behavior.
65 : //
// There are three types of assignment operator:
// - Assign a real value from the same scalar type
//   - In std, this is templated as complex& operator=(const T& x)
//     with specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, we use
//     `complex& operator=(T x)`.
// - Copy assignment operator and converting assignment operator
//   - There is no specialization of converting assignment operators; whether
//     a type is convertible depends solely on whether its scalar type is
//     convertible.
76 : //
77 : // In addition to the standard assignment, we also provide assignment operators
78 : // with std and thrust
79 : //
80 : //
81 : // [Casting operators]
82 : //
83 : // std::complex does not have casting operators. We define casting operators
84 : // casting to std::complex and thrust::complex
85 : //
86 : //
87 : // [Operator ""]
88 : //
89 : // std::complex has custom literals `i`, `if` and `il` defined in namespace
90 : // `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if and _id to
// construct float/double complex literals.
94 : //
95 : //
96 : // [real() and imag()]
97 : //
// In C++20, there are two overloads of each of these functions: one returns
// the real/imag part and the other sets it; both are constexpr. We follow
// this design.
101 : //
102 : //
103 : // [Operator +=,-=,*=,/=]
104 : //
105 : // Since C++20, these operators become constexpr. In our implementation, they
106 : // are also constexpr.
107 : //
108 : // There are two types of such operators: operating with a real number, or
109 : // operating with another complex number. For the operating with a real number,
110 : // the generic template form has argument type `const T &`, while the overload
111 : // for float/double/long double has `T`. We will follow the same type as
112 : // float/double/long double in std.
113 : //
114 : // [Unary operator +-]
115 : //
// Since C++20, they are constexpr. We also make them constexpr.
117 : //
118 : // [Binary operators +-*/]
119 : //
120 : // Each operator has three versions (taking + as example):
121 : // - complex + complex
122 : // - complex + real
123 : // - real + complex
124 : //
125 : // [Operator ==, !=]
126 : //
127 : // Each operator has three versions (taking == as example):
128 : // - complex == complex
129 : // - complex == real
130 : // - real == complex
131 : //
// Some of them are removed in C++20, but we decided to keep them.
133 : //
134 : // [Operator <<, >>]
135 : //
136 : // These are implemented by casting to std::complex
137 : //
138 : //
139 : //
140 : // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
141 : // because:
142 : // - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only supports float and double
144 :
145 : template <typename T>
146 : struct alignas(sizeof(T) * 2) complex {
147 : using value_type = T;
148 :
149 : T real_ = T(0);
150 : T imag_ = T(0);
151 :
152 : constexpr complex() = default;
153 : C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
154 : : real_(re), imag_(im) {}
155 : template <typename U>
156 : explicit constexpr complex(const std::complex<U>& other)
157 : : complex(other.real(), other.imag()) {}
158 : #if defined(__CUDACC__) || defined(__HIPCC__)
159 : template <typename U>
160 : explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
161 : : real_(other.real()), imag_(other.imag()) {}
162 : // NOTE can not be implemented as follow due to ROCm bug:
163 : // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
164 : // complex(other.real(), other.imag()) {}
165 : #endif
166 :
167 : // Use SFINAE to specialize casting constructor for c10::complex<float> and
168 : // c10::complex<double>
169 : template <typename U = T>
170 : C10_HOST_DEVICE explicit constexpr complex(
171 : const std::enable_if_t<std::is_same<U, float>::value, complex<double>>&
172 : other)
173 : : real_(other.real_), imag_(other.imag_) {}
174 : template <typename U = T>
175 : C10_HOST_DEVICE constexpr complex(
176 : const std::enable_if_t<std::is_same<U, double>::value, complex<float>>&
177 : other)
178 : : real_(other.real_), imag_(other.imag_) {}
179 :
180 : constexpr complex<T>& operator=(T re) {
181 : real_ = re;
182 : imag_ = 0;
183 : return *this;
184 : }
185 :
186 : constexpr complex<T>& operator+=(T re) {
187 : real_ += re;
188 : return *this;
189 : }
190 :
191 : constexpr complex<T>& operator-=(T re) {
192 : real_ -= re;
193 : return *this;
194 : }
195 :
196 : constexpr complex<T>& operator*=(T re) {
197 : real_ *= re;
198 : imag_ *= re;
199 : return *this;
200 : }
201 :
202 : constexpr complex<T>& operator/=(T re) {
203 : real_ /= re;
204 : imag_ /= re;
205 : return *this;
206 : }
207 :
208 : template <typename U>
209 : constexpr complex<T>& operator=(const complex<U>& rhs) {
210 : real_ = rhs.real();
211 : imag_ = rhs.imag();
212 : return *this;
213 : }
214 :
215 : template <typename U>
216 : constexpr complex<T>& operator+=(const complex<U>& rhs) {
217 : real_ += rhs.real();
218 : imag_ += rhs.imag();
219 : return *this;
220 : }
221 :
222 : template <typename U>
223 : constexpr complex<T>& operator-=(const complex<U>& rhs) {
224 : real_ -= rhs.real();
225 : imag_ -= rhs.imag();
226 : return *this;
227 : }
228 :
229 : template <typename U>
230 : constexpr complex<T>& operator*=(const complex<U>& rhs) {
231 : // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
232 : T a = real_;
233 : T b = imag_;
234 : U c = rhs.real();
235 : U d = rhs.imag();
236 : real_ = a * c - b * d;
237 : imag_ = a * d + b * c;
238 : return *this;
239 : }
240 :
241 : #ifdef __APPLE__
242 : #define FORCE_INLINE_APPLE __attribute__((always_inline))
243 : #else
244 : #define FORCE_INLINE_APPLE
245 : #endif
246 : template <typename U>
247 : constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
248 : __ubsan_ignore_float_divide_by_zero__ {
249 : // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
250 : // the calculation below follows numpy's complex division
251 : T a = real_;
252 : T b = imag_;
253 : U c = rhs.real();
254 : U d = rhs.imag();
255 :
256 : #if defined(__GNUC__) && !defined(__clang__)
257 : // std::abs is already constexpr by gcc
258 : auto abs_c = std::abs(c);
259 : auto abs_d = std::abs(d);
260 : #else
261 : auto abs_c = c < 0 ? -c : c;
262 : auto abs_d = d < 0 ? -d : d;
263 : #endif
264 :
265 : if (abs_c >= abs_d) {
266 : if (abs_c == 0 && abs_d == 0) {
267 : /* divide by zeros should yield a complex inf or nan */
268 : real_ = a / abs_c;
269 : imag_ = b / abs_d;
270 : } else {
271 : auto rat = d / c;
272 : auto scl = 1.0 / (c + d * rat);
273 : real_ = (a + b * rat) * scl;
274 : imag_ = (b - a * rat) * scl;
275 : }
276 : } else {
277 : auto rat = c / d;
278 : auto scl = 1.0 / (d + c * rat);
279 : real_ = (a * rat + b) * scl;
280 : imag_ = (b * rat - a) * scl;
281 : }
282 : return *this;
283 : }
284 : #undef FORCE_INLINE_APPLE
285 :
286 : template <typename U>
287 : constexpr complex<T>& operator=(const std::complex<U>& rhs) {
288 : real_ = rhs.real();
289 : imag_ = rhs.imag();
290 : return *this;
291 : }
292 :
293 : #if defined(__CUDACC__) || defined(__HIPCC__)
294 : template <typename U>
295 : C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
296 : real_ = rhs.real();
297 : imag_ = rhs.imag();
298 : return *this;
299 : }
300 : #endif
301 :
302 : template <typename U>
303 : explicit constexpr operator std::complex<U>() const {
304 : return std::complex<U>(std::complex<T>(real(), imag()));
305 : }
306 :
307 : #if defined(__CUDACC__) || defined(__HIPCC__)
308 : template <typename U>
309 : C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
310 : return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
311 : }
312 : #endif
313 :
314 : // consistent with NumPy behavior
315 : explicit constexpr operator bool() const {
316 : return real() || imag();
317 : }
318 :
319 0 : C10_HOST_DEVICE constexpr T real() const {
320 0 : return real_;
321 : }
322 : constexpr void real(T value) {
323 : real_ = value;
324 : }
325 0 : constexpr T imag() const {
326 0 : return imag_;
327 : }
328 : constexpr void imag(T value) {
329 : imag_ = value;
330 : }
331 : };
332 :
333 : namespace complex_literals {
334 :
335 : constexpr complex<float> operator"" _if(long double imag) {
336 : return complex<float>(0.0f, static_cast<float>(imag));
337 : }
338 :
339 : constexpr complex<double> operator"" _id(long double imag) {
340 : return complex<double>(0.0, static_cast<double>(imag));
341 : }
342 :
343 : constexpr complex<float> operator"" _if(unsigned long long imag) {
344 : return complex<float>(0.0f, static_cast<float>(imag));
345 : }
346 :
347 : constexpr complex<double> operator"" _id(unsigned long long imag) {
348 : return complex<double>(0.0, static_cast<double>(imag));
349 : }
350 :
351 : } // namespace complex_literals
352 :
353 : template <typename T>
354 : constexpr complex<T> operator+(const complex<T>& val) {
355 : return val;
356 : }
357 :
358 : template <typename T>
359 : constexpr complex<T> operator-(const complex<T>& val) {
360 : return complex<T>(-val.real(), -val.imag());
361 : }
362 :
363 : template <typename T>
364 : constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
365 : complex<T> result = lhs;
366 : return result += rhs;
367 : }
368 :
369 : template <typename T>
370 : constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
371 : complex<T> result = lhs;
372 : return result += rhs;
373 : }
374 :
375 : template <typename T>
376 : constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
377 : return complex<T>(lhs + rhs.real(), rhs.imag());
378 : }
379 :
380 : template <typename T>
381 : constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
382 : complex<T> result = lhs;
383 : return result -= rhs;
384 : }
385 :
386 : template <typename T>
387 : constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
388 : complex<T> result = lhs;
389 : return result -= rhs;
390 : }
391 :
392 : template <typename T>
393 : constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
394 : complex<T> result = -rhs;
395 : return result += lhs;
396 : }
397 :
398 : template <typename T>
399 : constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
400 : complex<T> result = lhs;
401 : return result *= rhs;
402 : }
403 :
404 : template <typename T>
405 : constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
406 : complex<T> result = lhs;
407 : return result *= rhs;
408 : }
409 :
410 : template <typename T>
411 : constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
412 : complex<T> result = rhs;
413 : return result *= lhs;
414 : }
415 :
416 : template <typename T>
417 : constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
418 : complex<T> result = lhs;
419 : return result /= rhs;
420 : }
421 :
422 : template <typename T>
423 : constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
424 : complex<T> result = lhs;
425 : return result /= rhs;
426 : }
427 :
428 : template <typename T>
429 : constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
430 : complex<T> result(lhs, T());
431 : return result /= rhs;
432 : }
433 :
434 : // Define operators between integral scalars and c10::complex. std::complex does
435 : // not support this when T is a floating-point number. This is useful because it
436 : // saves a lot of "static_cast" when operate a complex and an integer. This
437 : // makes the code both less verbose and potentially more efficient.
438 : #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
439 : typename std::enable_if_t< \
440 : std::is_floating_point<fT>::value && std::is_integral<iT>::value, \
441 : int> = 0
442 :
443 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
444 : constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
445 : return a + static_cast<fT>(b);
446 : }
447 :
448 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
449 : constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
450 : return static_cast<fT>(a) + b;
451 : }
452 :
453 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
454 : constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
455 : return a - static_cast<fT>(b);
456 : }
457 :
458 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
459 : constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
460 : return static_cast<fT>(a) - b;
461 : }
462 :
463 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
464 : constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
465 : return a * static_cast<fT>(b);
466 : }
467 :
468 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
469 : constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
470 : return static_cast<fT>(a) * b;
471 : }
472 :
473 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
474 : constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
475 : return a / static_cast<fT>(b);
476 : }
477 :
478 : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
479 : constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
480 : return static_cast<fT>(a) / b;
481 : }
482 :
483 : #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
484 :
485 : template <typename T>
486 : constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
487 : return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
488 : }
489 :
490 : template <typename T>
491 : constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
492 : return (lhs.real() == rhs) && (lhs.imag() == T());
493 : }
494 :
495 : template <typename T>
496 : constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
497 : return (lhs == rhs.real()) && (T() == rhs.imag());
498 : }
499 :
500 : template <typename T>
501 : constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
502 : return !(lhs == rhs);
503 : }
504 :
505 : template <typename T>
506 : constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
507 : return !(lhs == rhs);
508 : }
509 :
510 : template <typename T>
511 : constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
512 : return !(lhs == rhs);
513 : }
514 :
515 : template <typename T, typename CharT, typename Traits>
516 : std::basic_ostream<CharT, Traits>& operator<<(
517 : std::basic_ostream<CharT, Traits>& os,
518 : const complex<T>& x) {
519 : return (os << static_cast<std::complex<T>>(x));
520 : }
521 :
522 : template <typename T, typename CharT, typename Traits>
523 : std::basic_istream<CharT, Traits>& operator>>(
524 : std::basic_istream<CharT, Traits>& is,
525 : complex<T>& x) {
526 : std::complex<T> tmp;
527 : is >> tmp;
528 : x = tmp;
529 : return is;
530 : }
531 :
532 : } // namespace c10
533 :
534 : // std functions
535 : //
536 : // The implementation of these functions also follow the design of C++20
537 :
538 : namespace std {
539 :
540 : template <typename T>
541 : constexpr T real(const c10::complex<T>& z) {
542 : return z.real();
543 : }
544 :
545 : template <typename T>
546 : constexpr T imag(const c10::complex<T>& z) {
547 : return z.imag();
548 : }
549 :
550 : template <typename T>
551 : C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
552 : #if defined(__CUDACC__) || defined(__HIPCC__)
553 : return thrust::abs(static_cast<thrust::complex<T>>(z));
554 : #else
555 : return std::abs(static_cast<std::complex<T>>(z));
556 : #endif
557 : }
558 :
559 : #if defined(USE_ROCM)
560 : #define ROCm_Bug(x)
561 : #else
562 : #define ROCm_Bug(x) x
563 : #endif
564 :
565 : template <typename T>
566 : C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
567 : return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
568 : }
569 :
570 : #undef ROCm_Bug
571 :
572 : template <typename T>
573 : constexpr T norm(const c10::complex<T>& z) {
574 : return z.real() * z.real() + z.imag() * z.imag();
575 : }
576 :
577 : // For std::conj, there are other versions of it:
578 : // constexpr std::complex<float> conj( float z );
579 : // template< class DoubleOrInteger >
580 : // constexpr std::complex<double> conj( DoubleOrInteger z );
581 : // constexpr std::complex<long double> conj( long double z );
582 : // These are not implemented
583 : // TODO(@zasdfgbnm): implement them as c10::conj
584 : template <typename T>
585 : constexpr c10::complex<T> conj(const c10::complex<T>& z) {
586 : return c10::complex<T>(z.real(), -z.imag());
587 : }
588 :
589 : // Thrust does not have complex --> complex version of thrust::proj,
590 : // so this function is not implemented at c10 right now.
591 : // TODO(@zasdfgbnm): implement it by ourselves
592 :
593 : // There is no c10 version of std::polar, because std::polar always
594 : // returns std::complex. Use c10::polar instead;
595 :
596 : } // namespace std
597 :
598 : namespace c10 {
599 :
600 : template <typename T>
601 : C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
602 : #if defined(__CUDACC__) || defined(__HIPCC__)
603 : return static_cast<complex<T>>(thrust::polar(r, theta));
604 : #else
605 : // std::polar() requires r >= 0, so spell out the explicit implementation to
606 : // avoid a branch.
607 : return complex<T>(r * std::cos(theta), r * std::sin(theta));
608 : #endif
609 : }
610 :
611 : } // namespace c10
612 :
613 : C10_CLANG_DIAGNOSTIC_POP()
614 :
615 : #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
616 : // math functions are included in a separate file
617 : #include <c10/util/complex_math.h> // IWYU pragma: keep
618 : // utilities for complex types
619 : #include <c10/util/complex_utils.h> // IWYU pragma: keep
620 : #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
|