A library and framework for developing CPU-CUDA compatible applications under one unified code.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

166 lines
5.6 KiB

#ifndef CUDATOOLS_COMPLEX_H
#define CUDATOOLS_COMPLEX_H
#include "Macros.h"
#include <cmath>
#include <complex>
/**
* This is directly adapted from cuComplex.h, except placed into a C++ friendly format.
*/
namespace CudaTools {
namespace Types {
/** 32-bit floating point type. */
using real32 = float;
/** 64-bit floating point type. */
using real64 = double;
#if defined(CUDACC)
/** 16-bit floating point type; aliases __half when CUDACC is defined. */
using real16 = __half;
/** 16-bit bfloat type; aliases __nv_bfloat16 when CUDACC is defined. */
using realb16 = __nv_bfloat16;
#else
/** Without CUDACC, both 16-bit aliases fall back to float. */
using real16 = float;
using realb16 = float;
#endif // CUDACC
/**
 * Lightweight complex number with real and imaginary parts of type T.
 *
 * The arithmetic follows the routines in cuComplex.h. Every member is tagged
 * with the HD macro from Macros.h (presumably __host__ __device__ under CUDA;
 * confirm there).
 *
 * @tparam T Underlying real type. Division and abs() call std::abs, std::sqrt
 *           and std::isinf on it, so it should be a floating point type.
 */
template <typename T> class complex {
  private:
    T r = 0; /**< Real part. */
    T i = 0; /**< Imaginary part. */

  public:
    /** Constructs 0 + 0i. */
    HD complex() = default;
    /** Constructs real + imag*i. */
    HD complex(T real, T imag) : r(real), i(imag) {}
    /** Constructs x + 0i. Not explicit: a real value converts implicitly to complex. */
    HD complex(T x) : r(x), i(0) {}

    /** Component-wise sum. */
    HD complex<T> operator+(const complex<T> z) const { return complex(r + z.r, i + z.i); }
    /** Component-wise difference. */
    HD complex<T> operator-(const complex<T> z) const { return complex(r - z.r, i - z.i); }
    /** Scales both components by the real value y. */
    HD complex<T> operator*(const T y) const { return complex(r * y, i * y); }
    /** Divides both components by the real value y. */
    HD complex<T> operator/(const T y) const { return complex(r / y, i / y); }
    /** Complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
    HD complex<T> operator*(const complex<T> z) const {
        return complex(r * z.r - i * z.i, r * z.i + i * z.r);
    }
    /** Complex quotient; see operator/=(complex) for the scaling scheme. */
    HD complex<T> operator/(const complex<T> z) const {
        complex<T> quotient(*this);
        quotient /= z;
        return quotient;
    }

    /** Adds z in place and returns *this. */
    HD complex<T>& operator+=(const complex<T> z) {
        r += z.r;
        i += z.i;
        return *this;
    }
    /** Subtracts z in place and returns *this. */
    HD complex<T>& operator-=(const complex<T> z) {
        r -= z.r;
        i -= z.i;
        return *this;
    }
    /** Scales by the real value y in place and returns *this. */
    HD complex<T>& operator*=(const T y) {
        r *= y;
        i *= y;
        return *this;
    }
    /** Divides by the real value y in place and returns *this. */
    HD complex<T>& operator/=(const T y) {
        r /= y;
        i /= y;
        return *this;
    }
    /** Multiplies by z in place and returns *this. */
    HD complex<T>& operator*=(const complex<T> z) {
        // Compute both components before storing: each one reads the old r and i.
        const T re = r * z.r - i * z.i;
        const T im = r * z.i + i * z.r;
        r = re;
        i = im;
        return *this;
    }
    /** Divides by z in place and returns *this. */
    HD complex<T>& operator/=(const complex<T> z) {
        // Pre-scale both operands by 1 / (|z.r| + |z.i|) so that squaring the
        // divisor's components below does not overflow or underflow early.
        const T scale = T(1) / (std::abs(z.r) + std::abs(z.i));
        const T ars = r * scale;
        const T ais = i * scale;
        const T brs = z.r * scale;
        const T bis = z.i * scale;
        const T inv = T(1) / (brs * brs + bis * bis);
        r = (ars * brs + ais * bis) * inv;
        i = (ais * brs - ars * bis) * inv;
        return *this;
    }

    /**
     * Magnitude sqrt(r^2 + i^2), evaluated as v * sqrt(1 + (w / v)^2) with
     * v the larger and w the smaller absolute component, so the intermediate
     * square cannot overflow.
     *
     * @return The magnitude; v + w when v is zero or a component is infinite.
     */
    HD T abs() const {
        const T a = std::abs(r);
        const T b = std::abs(i);
        const T v = (a > b) ? a : b;
        const T w = (a > b) ? b : a;
        // Zero would divide 0 / 0 and infinities would give inf / inf, so both
        // are answered directly. The infinity test is independent of T: a
        // float-sized threshold would misclassify large finite doubles.
        if (v == T(0) || std::isinf(v) || std::isinf(w)) {
            return v + w;
        }
        const T t = w / v;
        return v * std::sqrt(T(1) + t * t);
    }

    /** Complex conjugate r - i*i. */
    HD complex<T> conj() const { return complex(r, -i); }
    /** Real part. */
    HD T real() const { return r; }
    /** Imaginary part. */
    HD T imag() const { return i; }
};
// Explicit instantiation definitions of complex for the real32 and real64 aliases.
// NOTE(review): these sit in a header, so every including translation unit
// repeats them -- confirm this is intended.
template class complex<real32>;
template class complex<real64>;
// NOTE(review): unlike the member operators of complex, these free operators
// carry no HD marker -- confirm whether they are meant to be host-only.

/**
 * Real-times-complex product y * z.
 * Multiplication by a real scalar commutes, so this forwards to z * y.
 */
template <class T> complex<T> operator*(const T y, const complex<T> z) { return z * y; }

/**
 * Real-over-complex quotient y / z.
 * Division does not commute (y / z is not z / y), so the scalar is promoted to
 * a complex value with zero imaginary part and divided by z.
 */
template <class T> complex<T> operator/(const T y, const complex<T> z) {
    return complex<T>(y) / z;
}

// Explicit instantiations of the free operators for real32 and real64.
template complex<real32> operator*<real32>(const real32, const complex<real32>);
template complex<real64> operator*<real64>(const real64, const complex<real64>);
template complex<real32> operator/<real32>(const real32, const complex<real32>);
template complex<real64> operator/<real64>(const real64, const complex<real64>);
#if defined(CUDACC)
/** 64-bit complex type: the complex class template over real32 when CUDACC is defined. */
using complex64 = complex<real32>;
/** 128-bit complex type: the complex class template over real64 when CUDACC is defined. */
using complex128 = complex<real64>;
#else
/** Without CUDACC, the complex aliases are the standard library std::complex types. */
using complex64 = std::complex<real32>;
using complex128 = std::complex<real64>;
#endif
/** Type aliases and lots of metaprogramming definitions, primarily dealing with
* the different numeric types and overrides. */
/** Maps a numeric type to its underlying real type; non-complex types map to themselves. */
template <typename T> struct ComplexUnderlying_S {
    using type = T;
};
/** complex64 is built on float. */
template <> struct ComplexUnderlying_S<complex64> {
    using type = float;
};
/** complex128 is built on double. */
template <> struct ComplexUnderlying_S<complex128> {
    using type = double;
};
/** Shorthand for ComplexUnderlying_S<T>::type. */
template <typename T> using ComplexUnderlying = typename ComplexUnderlying_S<T>::type;

/** Maps a numeric type to its std::complex counterpart; non-complex types map to themselves. */
template <typename T> struct ComplexConversion_S {
    using type = T;
};
/** complex64 converts to std::complex<float>. */
template <> struct ComplexConversion_S<complex64> {
    using type = std::complex<float>;
};
/** complex128 converts to std::complex<double>. */
template <> struct ComplexConversion_S<complex128> {
    using type = std::complex<double>;
};
/** Shorthand for ComplexConversion_S<T>::type. */
template <typename T> using ComplexConversion = typename ComplexConversion_S<T>::type;
/** True when T is an integral type (per std::is_integral). */
template <typename T> inline constexpr bool is_int = std::is_integral_v<T>;
/** True when T is a floating point type (per std::is_floating_point). */
template <typename T> inline constexpr bool is_float = std::is_floating_point_v<T>;
template <typename T>
inline constexpr bool is_complex =
std::is_same<T, complex64>::value or std::is_same<T, complex128>::value;
template <typename T> inline constexpr bool is_host_num = is_int<T> or is_float<T> or is_complex<T>;
}; // namespace Types
}; // namespace CudaTools
#endif