#ifndef MACROS_H
#define MACROS_H

#include <exception>
#include <sstream>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>

// CUDACC: set only when the project-level CUDA flag is defined AND the compiler
// announces itself through __CUDACC__.
#ifdef CUDA
#ifdef __CUDACC__
#define CUDACC
#endif
#endif

// DEVICE: set only while __CUDA_ARCH__ is defined with a positive value.
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ > 0
#define DEVICE
#endif
#endif

#ifdef CUDATOOLS_DOXYGEN
/**
 * \def CUDACC
 * This macro is defined when this code is being compiled by nvcc and the CUDA compilation
 * flag is set. This should be used to enclose code where CUDA specific libraries and syntax are
 * being used.
 */
#define CUDACC

/**
 * \def DEVICE
 * This macro is defined when this code is being compiled for the device. The difference between
 * this and CUDACC is that this should exclusively be used to decide if code is being compiled
 * to execute on the device. CUDACC only determines what compiler is being used.
 */
#define DEVICE

/**
 * \def HD
 * Mark a function in front with this if it needs to be callable on both the
 * CPU and CUDA device. Expands to nothing when CUDACC is not defined.
 */
#define HD

/**
 * \def SHARED
 * Mark a variable as static shared memory. Expands to nothing when CUDACC is not defined.
 */
#define SHARED

/**
 * \def DECLARE_KERNEL(call, ...)
 * Used to declare (in header) a CUDA kernel.
 * \param call the name of the kernel
 * \param ... the arguments of the kernel
 */
#define DECLARE_KERNEL(call, ...)

/**
 * \def DEFINE_KERNEL(call, ...)
 * Used to define (in implementation) a CUDA kernel. Besides opening the kernel definition, this
 * explicitly instantiates CudaTools::runKernel for the kernel's argument list.
 * \param call the name of the kernel
 * \param ... the arguments of the kernel
 */
#define DEFINE_KERNEL(call, ...)

/**
 * \def KERNEL(call, settings, ...)
 * Used to call a CUDA kernel. Expands to a call of CudaTools::runKernel.
 * \param call the name of the kernel
 * \param settings the associated CudaTools::Kernel::Settings to initialize the kernel with
 * \param ... the arguments of the kernel
 */
#define KERNEL(call, settings, ...)

/**
 * \def BASIC_LOOP(N)
 * Can be used in conjunction with CudaTools::Kernel::Basic, which is mainly used for
 * embarrassingly parallel situations. Exposes the loop/thread number as iThread (a uint32_t).
 *
 * In device code, iThread is computed from the block and thread indices and the statement that
 * follows the macro only runs when iThread is less than N. In host code, the macro expands to an
 * OpenMP "parallel for" loop that runs iThread from 0 up to (but not including) N.
 * \param N number of iterations
 */
#define BASIC_LOOP(N)

/**
 * \def DEVICE_CLASS(name)
 * Can be used inside a class declaration (header) which generates boilerplate code to allow this
 * class to be used on the device.
 *
 * This macro creates a few functions:\n
 * name* that(): returns the pointer to this instance on the device. When CUDA is not defined,
 * this is simply the host instance itself.
 *
 * void allocateDevice(): allocates the memory on the device for this class instance. When CUDA
 * is not defined, this does nothing.
 *
 * CudaTools::StreamID updateHost(const CudaTools::StreamID& stream): updates the host instance
 * of the class.
 *
 * CudaTools::StreamID updateDevice(const CudaTools::StreamID& stream): updates
 * the device instance of the class.
 *
 * Note that the macro ends inside a "public:" section of the class.
 * \param name the name of the class
 */
#define DEVICE_CLASS(name)

/**
 * \def CT_ERROR_IF(a, op, b, msg)
 * Used for throwing runtime errors given a condition with an operator. The error is raised when
 * "a op b" evaluates to true. In host code the error is printed and std::exception is thrown;
 * in device code the error is only printed.
 * \param a the left operand of the condition
 * \param op the operator placed between the two operands
 * \param b the right operand of the condition
 * \param msg the message to print along with the failed condition
 */
#define CT_ERROR_IF(a, op, b, msg)

/**
 * \def CT_ERROR(a, msg)
 * Used for throwing runtime errors given a bool. The error is raised when the bool is true.
 * In host code the error is printed and std::exception is thrown; in device code the error is
 * only printed.
 * \param a the condition which triggers the error
 * \param msg the message to print along with the failed condition
 */
#define CT_ERROR(a, msg)

/**
 * \def CUDA_CHECK(call)
 * Gets the error generated by a CUDA function call if there is one. If the call does not return
 * cudaSuccess, the error string is printed and std::exception is thrown.
 * \param call CUDA function to check if there are errors when running.
 */
#define CUDA_CHECK(call)

/**
 * \def CUBLAS_CHECK(call)
 * Gets the error generated by a cuBLAS function call if there is one. If the call does not
 * return CUBLAS_STATUS_SUCCESS, the error string is printed and std::exception is thrown.
 * \param call cuBLAS function to check if there are errors when running.
 */
#define CUBLAS_CHECK(call)

/**
 * \def CUDA_MEM(call, name)
 * Runs a function call and prints the change in free GPU memory (in MiB) measured around it.
 * \param call function to measure memory usage.
 * \param name an identifier to use as a variable and when printing. Must satisfy variable naming.
 */
#define CUDA_MEM(call, name)
#endif

///////////////////
// KERNEL MACROS //
///////////////////

// Compiler-dependent pieces: only the qualifiers differ between the CUDA and host builds,
// so they are selected here and the kernel macros themselves are written once below.
#ifdef CUDACC

#include <cublas_v2.h>
#include <cuda_runtime.h>

#define HD __host__ __device__
#define SHARED __shared__
// Qualifier placed in front of every kernel signature.
#define CT_KERNEL_QUALIFIER __global__

#else

#define HD
#define SHARED
// Host build: a kernel is an ordinary function, so no qualifier is emitted.
#define CT_KERNEL_QUALIFIER

#endif // CUDACC

// Declares the kernel `call` taking the given argument types.
#define DECLARE_KERNEL(call, ...) CT_KERNEL_QUALIFIER void call(__VA_ARGS__)

// Explicitly instantiates CudaTools::runKernel for this kernel's argument list, then opens
// the definition of the kernel itself (the function body follows the macro).
#define DEFINE_KERNEL(call, ...)                                                                   \
    template CudaTools::StreamID CudaTools::runKernel(                                             \
        void (*)(__VA_ARGS__), const CudaTools::Kernel::Settings&, __VA_ARGS__);                   \
    CT_KERNEL_QUALIFIER void call(__VA_ARGS__)

// Launches the kernel `call` through CudaTools::runKernel with the given settings and arguments.
#define KERNEL(call, settings, ...) CudaTools::runKernel(call, settings, __VA_ARGS__)

///////////////////
// DEVICE MACROS //
///////////////////

// BASIC_LOOP(N) exposes the current iteration number as `iThread` (a uint32_t) and runs the
// statement that follows it once per iteration below N.
//  - Device code: iThread is this thread's flat index (blockIdx.x * blockDim.x + threadIdx.x)
//    and the following statement is guarded by `iThread < (N)`.
//  - Host code: an OpenMP "parallel for" loop over iThread = 0 .. N-1.
// N is parenthesised so that an expression argument (e.g. `n & 3` or `a ? b : c`) is compared
// against iThread as a whole instead of being split by operator precedence.
#ifdef DEVICE

#define BASIC_LOOP(N)                                                                              \
    uint32_t iThread = blockIdx.x * blockDim.x + threadIdx.x;                                      \
    if (iThread < (N))
#else
#define BASIC_LOOP(N)                                                                              \
    _Pragma("omp parallel for") for (uint32_t iThread = 0; iThread < (N); ++iThread)

#endif

//////////////////
// CLASS MACROS //
//////////////////

// UPDATE_FUNC(name) emits two member functions for the class `name`:
//   updateHost(stream)   -> returns CudaTools::pull(this, that(), sizeof(name))
//   updateDevice(stream) -> returns CudaTools::push(this, that(), sizeof(name))
// Both rely on a member `that()` being available in the enclosing class.
// NOTE(review): `stream` is accepted (defaulting to CudaTools::DEF_MEM_STREAM) but is not
// forwarded to pull/push here - confirm against their signatures whether it should be.
#define UPDATE_FUNC(name)                                                                          \
    CudaTools::StreamID updateHost(                                                                \
        const CudaTools::StreamID& stream = CudaTools::DEF_MEM_STREAM) {                           \
        return CudaTools::pull(this, that(), sizeof(name));                                        \
    }                                                                                              \
    CudaTools::StreamID updateDevice(                                                              \
        const CudaTools::StreamID& stream = CudaTools::DEF_MEM_STREAM) {                           \
        return CudaTools::push(this, that(), sizeof(name));                                        \
    }

// DEVICE_CLASS(name) injects the device-mirroring members into class `name`.
// Members defined inside a class body are implicitly inline, so the keyword is omitted.
// The expansion ends in a `public:` section.
#ifdef CUDA

// CUDA build: the class carries a pointer to its device-side copy, which stays nullptr
// until allocateDevice() stores the result of CudaTools::malloc(sizeof(name)) in it.
#define DEVICE_CLASS(name)                                                                         \
  private:                                                                                         \
    name* __deviceInstance__ = nullptr;                                                            \
                                                                                                   \
  public:                                                                                          \
    name* that() { return __deviceInstance__; }                                                    \
    void allocateDevice() {                                                                        \
        __deviceInstance__ = (name*)CudaTools::malloc(sizeof(name));                               \
    }                                                                                              \
    UPDATE_FUNC(name)

#else

// Non-CUDA build: the object acts as its own "device" instance and nothing is allocated.
#define DEVICE_CLASS(name)                                                                         \
  public:                                                                                          \
    name* that() { return this; }                                                                  \
    void allocateDevice() {}                                                                       \
    UPDATE_FUNC(name)

#endif

#ifndef CUDATOOLS_ARRAY_MAX_AXES
/**
 * \def CUDATOOLS_ARRAY_MAX_AXES
 * The maximum number of axes/dimensions a CudaTools::Array can have. The default is
 * set to 4, but it can be manually set to fit the program's needs by defining this macro
 * before this header is included.
 */
#define CUDATOOLS_ARRAY_MAX_AXES 4
#endif

////////////////////
// Error Checking //
////////////////////

// CT_ERROR_IF(a, op, b, msg): reports an error when `(a) op (b)` is true.
// CT_ERROR(a, msg):           reports an error when `a` is true.
//
// Host code prints the report and throws std::exception; device code only prints it.
// Defining NO_DIMENSION_CHECK turns both macros into no-ops so call sites still compile.
//
// Implementation notes:
//  - The stringified operands/operator are passed as "%s" arguments rather than being
//    spliced into the format string, so a '%' inside them (e.g. `x % 2`) cannot be
//    misread as a conversion specifier. The printed text is unchanged.
//  - Operands are parenthesised so that expression arguments (e.g. `x & y`) are compared
//    and streamed as a whole.
//  - The operands are evaluated again when building the report, so they should be free of
//    side effects.
#ifndef NO_DIMENSION_CHECK
#ifdef DEVICE
#define CT_ERROR_IF(a, op, b, msg)                                                                 \
    if ((a) op (b)) {                                                                              \
        printf("[ERROR] %s:%d\n | %s: (%s) %s (%s).\n", __FILE__, __LINE__, msg, #a, #op, #b);     \
    }

#define CT_ERROR(a, msg)                                                                           \
    if (a) {                                                                                       \
        printf("[ERROR] %s:%d\n | %s: %s.\n", __FILE__, __LINE__, msg, #a);                        \
    }
#else

#define CT_ERROR_IF(a, op, b, msg)                                                                 \
    if ((a) op (b)) {                                                                              \
        std::ostringstream ctErrLhs_;                                                              \
        std::ostringstream ctErrRhs_;                                                              \
        ctErrLhs_ << (a);                                                                          \
        ctErrRhs_ << (b);                                                                          \
        printf("[ERROR] %s:%d\n | %s: (%s)%s %s (%s)%s.\n", __FILE__, __LINE__, msg, #a,           \
               ctErrLhs_.str().c_str(), #op, #b, ctErrRhs_.str().c_str());                         \
        throw std::exception();                                                                    \
    }

#define CT_ERROR(a, msg)                                                                           \
    if (a) {                                                                                       \
        printf("[ERROR] %s:%d\n | %s: %s.\n", __FILE__, __LINE__, msg, #a);                        \
        throw std::exception();                                                                    \
    }
#endif

#else // NO_DIMENSION_CHECK

// Checks disabled: expand to nothing so existing call sites remain valid.
#define CT_ERROR_IF(a, op, b, msg)
#define CT_ERROR(a, msg)

#endif // NO_DIMENSION_CHECK

// CUDA_CHECK(call) / CUBLAS_CHECK(call): evaluate `call` once; if the returned status is not
// the success value, print the file, line and error string, then throw std::exception.
// CUDA_MEM(call, name): run `call` and print how much free GPU memory changed around it.
//
// These checked variants exist only when compiling CUDA code (CUDACC) and NO_CUDA_CHECK is
// not defined; otherwise the macros just evaluate `call`.
#if defined(CUDACC) && !defined(NO_CUDA_CHECK)

// The status local uses a reserved-looking name so it cannot capture a caller variable
// (such as `err`) that appears inside `call`.
#define CUDA_CHECK(call)                                                                           \
    do {                                                                                           \
        const cudaError_t ctCudaStatus_ = (call);                                                  \
        if (ctCudaStatus_ != cudaSuccess) {                                                        \
            printf("[CUDA] %s:%d\n | %s\n", __FILE__, __LINE__,                                    \
                   cudaGetErrorString(ctCudaStatus_));                                             \
            throw std::exception();                                                                \
        }                                                                                          \
    } while (0)

#define CUBLAS_CHECK(call)                                                                         \
    do {                                                                                           \
        const cublasStatus_t ctCublasStatus_ = (call);                                             \
        if (ctCublasStatus_ != CUBLAS_STATUS_SUCCESS) {                                            \
            printf("[cuBLAS] %s:%d\n | %s\n", __FILE__, __LINE__,                                  \
                   CudaTools::cublasGetErrorString(ctCublasStatus_));                              \
            throw std::exception();                                                                \
        }                                                                                          \
    } while (0)

// Expands to several statements in the caller's scope (not a single expression), and leaves
// free_bef_<name>, free_aft_<name> and total_<name> declared there.
//  - Both cudaMemGetInfo queries are checked and receive a real `total` pointer instead of NULL.
//  - The difference is computed as a signed value and printed with a matching format, so a
//    call that frees memory reports a negative number instead of a wrapped unsigned one.
#define CUDA_MEM(call, name)                                                                       \
    size_t free_bef_##name = 0, free_aft_##name = 0, total_##name = 0;                             \
    CUDA_CHECK(cudaMemGetInfo(&free_bef_##name, &total_##name));                                   \
    call;                                                                                          \
    CudaTools::Manager::get()->sync();                                                             \
    CUDA_CHECK(cudaMemGetInfo(&free_aft_##name, &total_##name));                                   \
    printf("[%s] GPU Memory Usage: %lldMiB\n", #name,                                              \
           (static_cast<long long>(free_bef_##name) - static_cast<long long>(free_aft_##name)) /   \
               (1024 * 1024));

#else
#define CUDA_CHECK(call) (call)
#define CUBLAS_CHECK(call) (call)
// Statement form, matching the checked variant above, so the same call sites (including
// `CUDA_MEM(auto x = f(), tag)`) compile in both configurations.
#define CUDA_MEM(call, name) call;
#endif

#endif // MACROS_H