#ifndef CUDATOOLS_COMPLEX_H #define CUDATOOLS_COMPLEX_H #include "Macros.h" #include #include /** * This is directly adapated from cuComplex.h, except placed into a C++ friendly format. */ namespace CudaTools { template class complex { private: T r = 0; T i = 0; public: HD complex() = default; HD complex(T real, T imag) : r(real), i(imag){}; HD complex(T x) : r(x), i(0){}; HD complex operator+(const complex z) const { return complex(r + z.r, i + z.i); }; HD complex operator-(const complex z) const { return complex(r - z.r, i - z.i); }; HD complex operator*(const T y) const { return complex(r * y, i * y); }; HD complex operator/(const T y) const { return complex(r / y, i / y); }; HD complex operator*(const complex z) const { return complex(r * z.r - i * z.i, r * z.i + i * z.r); }; HD complex operator/(const complex z) const { T s = std::abs(z.r) + std::abs(z.i); T oos = 1.0f / s; T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos; s = (brs * brs) + (bis * bis); oos = 1.0f / s; return complex(ars * brs + ais * bis, ais * brs - ars * bis) * oos; }; HD void operator+=(const complex z) { r += z.r; i += z.i; }; HD void operator-=(const complex z) { r -= z.r; i -= z.i; }; HD void operator*=(const T y) { r *= y; i *= y; }; HD void operator/=(const T y) { r /= y; i /= y; }; HD void operator*=(const complex z) { T a = r * z.r - i * z.i, b = r * z.i + i * z.r; r = a; i = b; } HD void operator/=(const complex z) { T s = std::abs(z.r) + std::abs(z.i); T oos = 1.0f / s; T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos; s = (brs * brs) + (bis * bis); oos = 1.0f / s; r = (ars * brs + ais * bis) * oos; i = (ais * brs - ars * bis) * oos; }; HD T abs() const { T a = std::abs(r), b = std::abs(i); T v, w; if (a > b) { v = a; w = b; } else { v = b; w = a; } T t = w / v; t = 1.0f + t * t; t = v * std::sqrt(t); if ((v == 0.0f) || (v > 3.402823466e38f) || (w > 3.402823466e38f)) { t = v + w; } return t; } HD complex conj() const { return complex(r, -1 * i); } HD T real() const { return r; }; HD T imag() const { return i; }; }; template class complex; template class complex; template complex operator*(const T y, const complex z) { return z * y; }; template complex operator/(const T y, const complex z) { return z / y; }; template complex operator*(const real32, const complex); template complex operator*(const real64, const complex); template complex operator/(const real32, const complex); template complex operator/(const real64, const complex); }; // namespace CudaTools #ifdef CUDA using complex64 = CudaTools::complex; using complex128 = CudaTools::complex; #else using complex64 = std::complex; /**< Type alias for 64-bit complex floating point datatype. * This adapts depending on the CUDA compilation flag, and * will automatically switch CudaTools::complex. */ using complex128 = std::complex; /**< Type alias for 128-bit complex floating point datatype. This adapts * depending on the CUDA compilation flag, and will automatically switch * CudaTools::complex. */ #endif #endif