You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
126 lines
3.8 KiB
126 lines
3.8 KiB
2 years ago
|
#ifndef CUDATOOLS_COMPLEX_H
|
||
|
#define CUDATOOLS_COMPLEX_H
|
||
|
|
||
|
#include "Macros.h"
|
||
|
#include <cmath>
|
||
|
#include <complex>
|
||
|
|
||
|
/**
|
||
|
* This is directly adapted from cuComplex.h, except rearranged into a C++-friendly format.
|
||
|
*/
|
||
|
|
||
|
namespace CudaTools {
|
||
|
|
||
|
/**
 * @brief Complex number class template mirroring the arithmetic of cuComplex.h.
 *
 * Every member is tagged HD (defined in Macros.h; presumably __host__ __device__
 * when compiling with CUDA).
 *
 * Complex division pre-scales both operands by 1 / (|Re z| + |Im z|). abs() uses a
 * scaled hypotenuse formula instead of sqrt(re^2 + im^2). Both choices limit
 * intermediate overflow/underflow.
 *
 * @tparam T Underlying real type (this header uses real32 and real64).
 */
template <typename T> class complex {
  private:
    T r = 0; ///< Real part.
    T i = 0; ///< Imaginary part.

  public:
    /** Constructs 0 + 0i. */
    HD complex() = default;

    /** Constructs real + imag * i. */
    HD complex(T real, T imag) : r(real), i(imag) {}

    /** Constructs x + 0i. Left implicit (as in std::complex) so real scalars convert to complex. */
    HD complex(T x) : r(x), i(0) {}

    /** Component-wise sum. */
    HD complex<T> operator+(const complex<T> z) const { return complex(r + z.r, i + z.i); }

    /** Component-wise difference. */
    HD complex<T> operator-(const complex<T> z) const { return complex(r - z.r, i - z.i); }

    /** Scales both components by the real value y. */
    HD complex<T> operator*(const T y) const { return complex(r * y, i * y); }

    /** Divides both components by the real value y. */
    HD complex<T> operator/(const T y) const { return complex(r / y, i / y); }

    /** Complex product; delegates to operator*=(complex). */
    HD complex<T> operator*(const complex<T> z) const {
        complex<T> out = *this;
        out *= z;
        return out;
    }

    /** Complex quotient; delegates to operator/=(complex) for the scaled algorithm. */
    HD complex<T> operator/(const complex<T> z) const {
        complex<T> out = *this;
        out /= z;
        return out;
    }

    /** Adds z in place. @return *this, so assignments can be chained. */
    HD complex<T>& operator+=(const complex<T> z) {
        r += z.r;
        i += z.i;
        return *this;
    }

    /** Subtracts z in place. @return *this. */
    HD complex<T>& operator-=(const complex<T> z) {
        r -= z.r;
        i -= z.i;
        return *this;
    }

    /** Scales both components by the real value y in place. @return *this. */
    HD complex<T>& operator*=(const T y) {
        r *= y;
        i *= y;
        return *this;
    }

    /** Divides both components by the real value y in place. @return *this. */
    HD complex<T>& operator/=(const T y) {
        r /= y;
        i /= y;
        return *this;
    }

    /**
     * Multiplies in place: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
     * @return *this.
     */
    HD complex<T>& operator*=(const complex<T> z) {
        // Compute both parts before writing, since each depends on the old r and i.
        const T re = r * z.r - i * z.i;
        const T im = r * z.i + i * z.r;
        r = re;
        i = im;
        return *this;
    }

    /**
     * Divides in place by z.
     *
     * Both operands are first scaled by 1 / (|Re z| + |Im z|). The squared magnitude
     * of the divisor is then computed on values whose absolute sum is 1, which limits
     * intermediate overflow/underflow. Dividing by 0 + 0i yields non-finite components.
     *
     * @return *this.
     */
    HD complex<T>& operator/=(const complex<T> z) {
        T scale = std::abs(z.r) + std::abs(z.i);
        T inv = T(1) / scale;
        const T ars = r * inv;
        const T ais = i * inv;
        const T brs = z.r * inv;
        const T bis = z.i * inv;
        scale = (brs * brs) + (bis * bis);
        inv = T(1) / scale;
        r = (ars * brs + ais * bis) * inv;
        i = (ais * brs - ars * bis) * inv;
        return *this;
    }

    /**
     * Magnitude |z| = sqrt(re^2 + im^2).
     *
     * Evaluated as v * sqrt(1 + (w / v)^2), where v = max(|re|, |im|) and
     * w = min(|re|, |im|). Because w / v <= 1, squaring it cannot overflow.
     * Falls back to v + w when v is zero (avoids 0 / 0) or a part is infinite.
     * That fallback yields 0 or +inf respectively.
     */
    HD T abs() const {
        const T a = std::abs(r);
        const T b = std::abs(i);
        const T v = (a > b) ? a : b;
        const T w = (a > b) ? b : a;
        // std::isinf rather than a float-max threshold, so the test is also correct for
        // double (a float threshold misroutes large but finite doubles into v + w).
        if ((v == T(0)) || std::isinf(v) || std::isinf(w)) {
            return v + w;
        }
        const T t = w / v;
        return v * std::sqrt(T(1) + t * t);
    }

    /** Complex conjugate re - im * i. */
    HD complex<T> conj() const { return complex(r, -i); }

    /** Real part. */
    HD T real() const { return r; }

    /** Imaginary part. */
    HD T imag() const { return i; }
};

// No explicit instantiations here. An explicit instantiation definition may appear at
// most once per program, so placing one in a header is an ODR violation once two
// translation units include it. complex<real32> and complex<real64> are instantiated
// implicitly wherever they are used.
|
||
|
|
||
|
template <class T> complex<T> operator*(const T y, const complex<T> z) { return z * y; };
|
||
|
template <class T> complex<T> operator/(const T y, const complex<T> z) { return z / y; };
|
||
|
|
||
|
template complex<real32> operator*<real32>(const real32, const complex<real32>);
|
||
|
template complex<real64> operator*<real64>(const real64, const complex<real64>);
|
||
|
template complex<real32> operator/<real32>(const real32, const complex<real32>);
|
||
|
template complex<real64> operator/<real64>(const real64, const complex<real64>);
|
||
|
|
||
|
}; // namespace CudaTools
|
||
|
|
||
|
#ifdef CUDA
|
||
|
using complex64 = CudaTools::complex<real32>;
|
||
|
using complex128 = CudaTools::complex<real64>;
|
||
|
#else
|
||
|
using complex64 = std::complex<real32>; /**< Type alias for 64-bit complex floating point datatype.
|
||
|
* This adapts depending on the CUDA compilation flag, and
|
||
|
* will automatically switch CudaTools::complex<real32>. */
|
||
|
using complex128 =
|
||
|
std::complex<real64>; /**< Type alias for 128-bit complex floating point datatype. This adapts
|
||
|
* depending on the CUDA compilation flag, and will automatically switch
|
||
|
* CudaTools::complex<real64>. */
|
||
|
#endif
|
||
|
|
||
|
#endif
|