#ifndef CUDATOOLS_COMPLEX_H
#define CUDATOOLS_COMPLEX_H

#include "Macros.h"
#include <cmath>
#include <complex>

/**
 * This is directly adapted from cuComplex.h, but rewritten in a C++-friendly form
 * (a class template with overloaded operators).
 */

namespace CudaTools {

template <typename T> class complex {
  private:
    T r = 0;
    T i = 0;

  public:
    HD complex() = default;
    HD complex(T real, T imag) : r(real), i(imag){};
    HD complex(T x) : r(x), i(0){};

    HD complex<T> operator+(const complex<T> z) const { return complex(r + z.r, i + z.i); };
    HD complex<T> operator-(const complex<T> z) const { return complex(r - z.r, i - z.i); };
    HD complex<T> operator*(const T y) const { return complex(r * y, i * y); };
    HD complex<T> operator/(const T y) const { return complex(r / y, i / y); };

    HD complex<T> operator*(const complex<T> z) const {
        return complex(r * z.r - i * z.i, r * z.i + i * z.r);
    };
    HD complex<T> operator/(const complex<T> z) const {
        T s = std::abs(z.r) + std::abs(z.i);
        T oos = 1.0f / s;
        T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos;
        s = (brs * brs) + (bis * bis);
        oos = 1.0f / s;
        return complex(ars * brs + ais * bis, ais * brs - ars * bis) * oos;
    };

    HD void operator+=(const complex<T> z) {
        r += z.r;
        i += z.i;
    };
    HD void operator-=(const complex<T> z) {
        r -= z.r;
        i -= z.i;
    };
    HD void operator*=(const T y) {
        r *= y;
        i *= y;
    };
    HD void operator/=(const T y) {
        r /= y;
        i /= y;
    };

    HD void operator*=(const complex<T> z) {
        T a = r * z.r - i * z.i, b = r * z.i + i * z.r;
        r = a;
        i = b;
    }

    HD void operator/=(const complex<T> z) {
        T s = std::abs(z.r) + std::abs(z.i);
        T oos = 1.0f / s;
        T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos;
        s = (brs * brs) + (bis * bis);
        oos = 1.0f / s;
        r = (ars * brs + ais * bis) * oos;
        i = (ais * brs - ars * bis) * oos;
    };

    HD T abs() const {
        T a = std::abs(r), b = std::abs(i);
        T v, w;
        if (a > b) {
            v = a;
            w = b;
        } else {
            v = b;
            w = a;
        }
        T t = w / v;
        t = 1.0f + t * t;
        t = v * std::sqrt(t);
        if ((v == 0.0f) || (v > 3.402823466e38f) || (w > 3.402823466e38f)) {
            t = v + w;
        }
        return t;
    }

    HD complex<T> conj() const { return complex(r, -1 * i); }

    HD T real() const { return r; };
    HD T imag() const { return i; };
};

template class complex<real32>;
template class complex<real64>;

template <class T> complex<T> operator*(const T y, const complex<T> z) { return z * y; };
template <class T> complex<T> operator/(const T y, const complex<T> z) { return z / y; };

template complex<real32> operator*<real32>(const real32, const complex<real32>);
template complex<real64> operator*<real64>(const real64, const complex<real64>);
template complex<real32> operator/<real32>(const real32, const complex<real32>);
template complex<real64> operator/<real64>(const real64, const complex<real64>);

}; // namespace CudaTools

#ifdef CUDA
/** 64-bit complex floating point type (two real32 parts); CudaTools::complex<real32> in CUDA builds. */
using complex64 = CudaTools::complex<real32>;

/** 128-bit complex floating point type (two real64 parts); CudaTools::complex<real64> in CUDA builds. */
using complex128 = CudaTools::complex<real64>;
#else
/** 64-bit complex floating point type (two real32 parts); std::complex<real32> when CUDA is not defined. */
using complex64 = std::complex<real32>;

/** 128-bit complex floating point type (two real64 parts); std::complex<real64> when CUDA is not defined. */
using complex128 = std::complex<real64>;
#endif

#endif