You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
5.6 KiB
166 lines
5.6 KiB
#ifndef CUDATOOLS_COMPLEX_H
#define CUDATOOLS_COMPLEX_H

#include "Macros.h"

#include <cmath>
#include <complex>
#include <type_traits>
|
|
|
|
/**
 * This is directly adapted from cuComplex.h, except placed into a C++-friendly format.
 */
|
|
|
|
namespace CudaTools {
|
|
|
|
namespace Types {
|
|
|
|
/// Type alias for the 32-bit floating point datatype.
using real32 = float;

/// Type alias for the 64-bit floating point datatype.
using real64 = double;

// NOTE(review): CUDACC is presumably defined by Macros.h when compiling with CUDA — confirm.
#ifdef CUDACC

/// 16-bit floating point datatype (GPU builds); host-only builds use float instead.
using real16 = __half;

/// 16-bit bfloat datatype (GPU builds); host-only builds use float instead.
using realb16 = __nv_bfloat16;

#else

// Without CUDA there are no native 16-bit types, so both aliases fall back to float.
using real16 = float;
using realb16 = float;

#endif // CUDACC
|
|
|
|
template <typename T> class complex {
|
|
private:
|
|
T r = 0;
|
|
T i = 0;
|
|
|
|
public:
|
|
HD complex() = default;
|
|
HD complex(T real, T imag) : r(real), i(imag){};
|
|
HD complex(T x) : r(x), i(0){};
|
|
|
|
HD complex<T> operator+(const complex<T> z) const { return complex(r + z.r, i + z.i); };
|
|
HD complex<T> operator-(const complex<T> z) const { return complex(r - z.r, i - z.i); };
|
|
HD complex<T> operator*(const T y) const { return complex(r * y, i * y); };
|
|
HD complex<T> operator/(const T y) const { return complex(r / y, i / y); };
|
|
|
|
HD complex<T> operator*(const complex<T> z) const {
|
|
return complex(r * z.r - i * z.i, r * z.i + i * z.r);
|
|
};
|
|
HD complex<T> operator/(const complex<T> z) const {
|
|
T s = std::abs(z.r) + std::abs(z.i);
|
|
T oos = 1.0f / s;
|
|
T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos;
|
|
s = (brs * brs) + (bis * bis);
|
|
oos = 1.0f / s;
|
|
return complex(ars * brs + ais * bis, ais * brs - ars * bis) * oos;
|
|
};
|
|
|
|
HD void operator+=(const complex<T> z) {
|
|
r += z.r;
|
|
i += z.i;
|
|
};
|
|
HD void operator-=(const complex<T> z) {
|
|
r -= z.r;
|
|
i -= z.i;
|
|
};
|
|
HD void operator*=(const T y) {
|
|
r *= y;
|
|
i *= y;
|
|
};
|
|
HD void operator/=(const T y) {
|
|
r /= y;
|
|
i /= y;
|
|
};
|
|
|
|
HD void operator*=(const complex<T> z) {
|
|
T a = r * z.r - i * z.i, b = r * z.i + i * z.r;
|
|
r = a;
|
|
i = b;
|
|
}
|
|
|
|
HD void operator/=(const complex<T> z) {
|
|
T s = std::abs(z.r) + std::abs(z.i);
|
|
T oos = 1.0f / s;
|
|
T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos;
|
|
s = (brs * brs) + (bis * bis);
|
|
oos = 1.0f / s;
|
|
r = (ars * brs + ais * bis) * oos;
|
|
i = (ais * brs - ars * bis) * oos;
|
|
};
|
|
|
|
HD T abs() const {
|
|
T a = std::abs(r), b = std::abs(i);
|
|
T v, w;
|
|
if (a > b) {
|
|
v = a;
|
|
w = b;
|
|
} else {
|
|
v = b;
|
|
w = a;
|
|
}
|
|
T t = w / v;
|
|
t = 1.0f + t * t;
|
|
t = v * std::sqrt(t);
|
|
if ((v == 0.0f) || (v > 3.402823466e38f) || (w > 3.402823466e38f)) {
|
|
t = v + w;
|
|
}
|
|
return t;
|
|
}
|
|
|
|
HD complex<T> conj() const { return complex(r, -1 * i); }
|
|
|
|
HD T real() const { return r; };
|
|
HD T imag() const { return i; };
|
|
};
|
|
|
|
template class complex<real32>;
|
|
template class complex<real64>;
|
|
|
|
template <class T> complex<T> operator*(const T y, const complex<T> z) { return z * y; };
|
|
template <class T> complex<T> operator/(const T y, const complex<T> z) { return z / y; };
|
|
|
|
template complex<real32> operator*<real32>(const real32, const complex<real32>);
|
|
template complex<real64> operator*<real64>(const real64, const complex<real64>);
|
|
template complex<real32> operator/<real32>(const real32, const complex<real32>);
|
|
template complex<real64> operator/<real64>(const real64, const complex<real64>);
|
|
|
|
#ifdef CUDACC
|
|
using complex64 = complex<real32>; /**< Type alias for 64-bit complex floating point datatype.
|
|
* This adapts depending on the CUDA compilation flag, and
|
|
* will automatically switch std::complex<real32>. */
|
|
|
|
using complex128 = complex<real64>; /**< Type alias for 128-bit complex floating point datatype.
|
|
* This adapts depending on the CUDA compilation flag, and will
|
|
* automatically switch std::complex<real64>. */
|
|
|
|
#else
|
|
using complex64 = std::complex<real32>;
|
|
using complex128 = std::complex<real64>;
|
|
#endif
|
|
|
|
/** Type aliases and lots of metaprogramming definitions, primarily dealing with
 * the different numeric types and overrides. */
|
|
|
|
template <typename T> struct ComplexUnderlying_S { typedef T type; };
|
|
template <> struct ComplexUnderlying_S<complex64> { typedef float type; };
|
|
template <> struct ComplexUnderlying_S<complex128> { typedef double type; };
|
|
template <typename T> using ComplexUnderlying = typename ComplexUnderlying_S<T>::type;
|
|
|
|
template <typename T> struct ComplexConversion_S { typedef T type; };
|
|
template <> struct ComplexConversion_S<complex64> { typedef std::complex<float> type; };
|
|
template <> struct ComplexConversion_S<complex128> { typedef std::complex<double> type; };
|
|
template <typename T> using ComplexConversion = typename ComplexConversion_S<T>::type;
|
|
|
|
template <typename T> inline constexpr bool is_int = std::is_integral<T>::value;
|
|
template <typename T> inline constexpr bool is_float = std::is_floating_point<T>::value;
|
|
template <typename T>
|
|
inline constexpr bool is_complex =
|
|
std::is_same<T, complex64>::value or std::is_same<T, complex128>::value;
|
|
template <typename T> inline constexpr bool is_host_num = is_int<T> or is_float<T> or is_complex<T>;
|
|
|
|
}; // namespace Types
|
|
}; // namespace CudaTools
|
|
|
|
#endif
|
|
|