#ifndef CUDATOOLS_COMPLEX_H #define CUDATOOLS_COMPLEX_H #include "Macros.h" #include #include /** * This is directly adapated from cuComplex.h, except placed into a C++ friendly format. */ namespace CudaTools { namespace Types { using real32 = float; /**< Type alias for 32-bit floating point datatype. */ using real64 = double; /**< Type alias for 64-bit floating point datatype. */ #ifdef CUDACC using real16 = __half; /**< Type alias for 16-bit floating point datatype, when using GPU. Otherwise, defaults to float. */ using realb16 = __nv_bfloat16; /**< Type alias for the 16-bit bfloat datatype, when using GPU. Otherwise, defaults to float. */ #else using real16 = float; using realb16 = float; #endif // CUDACC template class complex { private: T r = 0; T i = 0; public: HD complex() = default; HD complex(T real, T imag) : r(real), i(imag){}; HD complex(T x) : r(x), i(0){}; HD complex operator+(const complex z) const { return complex(r + z.r, i + z.i); }; HD complex operator-(const complex z) const { return complex(r - z.r, i - z.i); }; HD complex operator*(const T y) const { return complex(r * y, i * y); }; HD complex operator/(const T y) const { return complex(r / y, i / y); }; HD complex operator*(const complex z) const { return complex(r * z.r - i * z.i, r * z.i + i * z.r); }; HD complex operator/(const complex z) const { T s = std::abs(z.r) + std::abs(z.i); T oos = 1.0f / s; T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos; s = (brs * brs) + (bis * bis); oos = 1.0f / s; return complex(ars * brs + ais * bis, ais * brs - ars * bis) * oos; }; HD void operator+=(const complex z) { r += z.r; i += z.i; }; HD void operator-=(const complex z) { r -= z.r; i -= z.i; }; HD void operator*=(const T y) { r *= y; i *= y; }; HD void operator/=(const T y) { r /= y; i /= y; }; HD void operator*=(const complex z) { T a = r * z.r - i * z.i, b = r * z.i + i * z.r; r = a; i = b; } HD void operator/=(const complex z) { T s = std::abs(z.r) + std::abs(z.i); T oos = 1.0f / s; T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos; s = (brs * brs) + (bis * bis); oos = 1.0f / s; r = (ars * brs + ais * bis) * oos; i = (ais * brs - ars * bis) * oos; }; HD T abs() const { T a = std::abs(r), b = std::abs(i); T v, w; if (a > b) { v = a; w = b; } else { v = b; w = a; } T t = w / v; t = 1.0f + t * t; t = v * std::sqrt(t); if ((v == 0.0f) || (v > 3.402823466e38f) || (w > 3.402823466e38f)) { t = v + w; } return t; } HD complex conj() const { return complex(r, -1 * i); } HD T real() const { return r; }; HD T imag() const { return i; }; }; template class complex; template class complex; template complex operator*(const T y, const complex z) { return z * y; }; template complex operator/(const T y, const complex z) { return z / y; }; template complex operator*(const real32, const complex); template complex operator*(const real64, const complex); template complex operator/(const real32, const complex); template complex operator/(const real64, const complex); #ifdef CUDACC using complex64 = complex; /**< Type alias for 64-bit complex floating point datatype. * This adapts depending on the CUDA compilation flag, and * will automatically switch std::complex. */ using complex128 = complex; /**< Type alias for 128-bit complex floating point datatype. * This adapts depending on the CUDA compilation flag, and will * automatically switch std::complex. */ #else using complex64 = std::complex; using complex128 = std::complex; #endif /** Type alises and lots of metaprogramming definitions, primarily dealing with * the different numeric types and overrides. */ template struct ComplexUnderlying_S { typedef T type; }; template <> struct ComplexUnderlying_S { typedef float type; }; template <> struct ComplexUnderlying_S { typedef double type; }; template using ComplexUnderlying = typename ComplexUnderlying_S::type; template struct ComplexConversion_S { typedef T type; }; template <> struct ComplexConversion_S { typedef std::complex type; }; template <> struct ComplexConversion_S { typedef std::complex type; }; template using ComplexConversion = typename ComplexConversion_S::type; template inline constexpr bool is_int = std::is_integral::value; template inline constexpr bool is_float = std::is_floating_point::value; template inline constexpr bool is_complex = std::is_same::value or std::is_same::value; template inline constexpr bool is_host_num = is_int or is_float or is_complex; }; // namespace Types }; // namespace CudaTools #endif