#ifndef BLAS_H
#define BLAS_H

#include "Array.h"
#include "Core.h"
#include "Macros.h"

namespace CudaTools {

namespace BLAS {

struct BatchInfo {
    uint32_t strideA, strideB, strideC; // Strides between consecutive matrices; 0 broadcasts that operand.
    uint32_t size;                      // Number of matrix operations in the batch.
};

template <typename T> struct Check {
    static void isAtLeast2D(const Array<T>& arr, const std::string& name = "Array") {
        CT_ERROR_IF(arr.shape().axes(), <, 2, (name + " needs to be at least 2D").c_str());
    };

    static void isSquare(const Array<T>& arr, const std::string& name = "Array") {
        isAtLeast2D(arr, name);
        CT_ERROR_IF(arr.shape().rows(), !=, arr.shape().cols(), (name + " is not square").c_str());
    };

    static void isValidMatmul(const Array<T>& A, const Array<T>& B, const Array<T>& C,
                              const std::string& nameA = "A", const std::string& nameB = "B",
                              const std::string& nameC = "C") {
        isAtLeast2D(A, nameA);
        isAtLeast2D(B, nameB);
        isAtLeast2D(C, nameC);
        CT_ERROR_IF(A.shape().cols(), !=, B.shape().rows(),
                    (nameA + nameB + " is not a valid matrix multiplication").c_str());

        Shape ABshape({A.shape().rows(), B.shape().cols()});
        Shape Cshape({C.shape().rows(), C.shape().cols()});

        CT_ERROR_IF(
            ABshape, !=, Cshape,
            ("The shape of " + nameA + nameB + " does not match the shape of " + nameC).c_str());
    };

    static uint32_t getUpperItems(const Array<T>& arr) {
        uint32_t upperItems = 1;
        for (uint32_t iAxis = 0; iAxis < arr.shape().axes() - 2; ++iAxis) {
            upperItems *= arr.shape().dim(iAxis);
        }
        return upperItems;
    };

    static void matchUpperShape(const Array<T>& A, const Array<T>& B,
                                const std::string& nameA = "A", const std::string& nameB = "B") {
        CT_ERROR_IF(A.shape().axes(), !=, B.shape().axes(),
                    (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        for (uint32_t iAxis = 0; iAxis < A.shape().axes() - 2; ++iAxis) {
            uint32_t Adim = A.shape().dim(iAxis);
            uint32_t Bdim = B.shape().dim(iAxis);
            CT_ERROR_IF(
                Adim, !=, Bdim,
                (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        }
    };

    static BatchInfo isBroadcastable(const Array<T>& A, const Array<T>& B, const Array<T>& C,
                                     const std::string& nameA = "A", const std::string& nameB = "B",
                                     const std::string& nameC = "C") {
        isValidMatmul(A, B, C, nameA, nameB, nameC);
        uint32_t itemsA = getUpperItems(A);
        uint32_t itemsB = getUpperItems(B);
        uint32_t itemsC = getUpperItems(C);

        uint32_t Asize = A.shape().rows() * A.shape().cols();
        uint32_t Bsize = B.shape().rows() * B.shape().cols();
        uint32_t Csize = C.shape().rows() * C.shape().cols();

        if (itemsA == itemsB) {
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, B, nameA, nameB);
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, Bsize, Csize, itemsC};
        } else if (itemsA > itemsB) {
            CT_ERROR_IF(
                itemsB, !=, 1,
                ("Cannot broadcast operation to " + nameB + " with non-matching " + nameA).c_str());
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, 0, Csize, itemsC};
        } else {
            CT_ERROR_IF(
                itemsA, !=, 1,
                ("Cannot broadcast operation to " + nameA + " with non-matching " + nameB).c_str());
            CT_ERROR_IF(itemsB, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(B, C, nameB, nameC);
            return BatchInfo{0, Bsize, Csize, itemsC};
        }
    };
};
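
/*
 * A minimal sketch of the broadcasting rules enforced by Check<T>::isBroadcastable
 * (shapes are illustrative, namespace qualifiers elided):
 *
 *   Array<float> A({10, 4, 4}), B({10, 4, 4}), C({10, 4, 4});
 *   Check<float>::isBroadcastable(A, B, C); // {16, 16, 16, 10}: matching batches.
 *
 *   Array<float> A2({10, 4, 4}), B2({4, 4}), C2({10, 4, 4});
 *   Check<float>::isBroadcastable(A2, B2, C2); // {16, 0, 16, 10}: B2 is broadcast,
 *                                              // stride 0 reuses the same matrix.
 */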

/**
 * Represents a Batch of Arrays with the same shape. Mainly used for cuBLAS functions.
 */
template <typename T> class Batch {
  protected:
    Array<T*> mBatch;
    Shape mShape;

    uint32_t mCount = 0;
    uint32_t mBatchSize;

  public:
    Batch() = delete;

    /**
     * Constructs a batch from a given size.
     */
    Batch(const uint32_t size) : mBatch({size}), mBatchSize(size){};

    /**
     * Constructs a batch from a non-view Array.
     */
    Batch(const Array<T>& arr) {
        CT_ERROR(arr.isView(), "Array cannot be a view");
        mShape = Shape({arr.shape().rows(), arr.shape().cols()});
        mBatchSize = mCount = Check<T>::getUpperItems(arr);

        mBatch = Array<T*>({mBatchSize});

        Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
        for (uint32_t i = 0; i < mBatchSize; ++i) {
#ifdef CUDA
            mBatch[i] = batch[i].dataDevice();
#else
            mBatch[i] = batch[i].data();
#endif
        }

        mBatch.updateDevice().wait();
    };

    /**
     * Adds a matrix to the batch. Array must be a view.
     */
    void add(const Array<T>& arr) {
        CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
        CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
#ifdef CUDA
        mBatch[mCount] = arr.dataDevice();
#else
        mBatch[mCount] = arr.data();
#endif
        if (mCount == 0) {
            mShape = arr.shape();
        } else {
            CT_ERROR_IF(arr.shape(), !=, mShape, "Cannot add matrix of different shape to batch");
        }
        ++mCount;

        if (mCount == mBatchSize) {
            mBatch.updateDevice().wait();
        }
    };

    /**
     * Indexing operator which returns a view of the Array in the Batch at the given index.
     */
    Array<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, mBatchSize, "Index exceeds batch size");
        return Array<T>(mBatch[index], {mShape.rows(), mShape.cols()});
    };

    /**
     * Returns the batch Array of pointers.
     */
    Array<T*> batch() const { return mBatch.view(); };
    Shape shape() const { return mShape; } /**< Gets the shape of the matrices in the batch. */
    uint32_t size() const { return mBatchSize; }        /**< Gets the batch size.*/
    bool full() const { return mBatchSize == mCount; }; /**< Gets if the batch is full. */
};
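
/*
 * A minimal usage sketch (namespace qualifiers elided; assumes Array's operator[]
 * yields a view along the leading axis, as the reshaped loop above suggests):
 *
 *   Array<float> stack({8, 3, 3});    // 8 matrices of shape 3x3.
 *   Batch<float> batch(stack);        // Pointers to each 3x3 slice.
 *
 *   Batch<float> manual(2);           // Or build one view at a time.
 *   manual.add(stack[0]);
 *   manual.add(stack[1]);             // Batch is now full; pointers are synced.
 *   Array<float> first = manual[0];   // View of the first matrix.
 */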

////////////////
// cuBLAS API //
////////////////

template <typename T, typename F1, typename F2, typename... Args>
constexpr void invoke(F1 f1, F2 f2, Args&&... args) {
    if constexpr (std::is_same<T, float>::value) {
        CUBLAS_CHECK(f1(args...));
    } else if constexpr (std::is_same<T, double>::value) {
        CUBLAS_CHECK(f2(args...));
    } else {
        CT_ERROR(true, "BLAS functions are not callable with that type");
    }
}

/**
 * Computes the matrix-vector product: \f$ y = \alpha Ax + \beta y \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta, const Array<T>& y,
              const StreamID& stream = DEF_CUBLAS_STREAM) {

    BatchInfo bi = Check<T>::isBroadcastable(A, x, y, "A", "x", "y");
    CT_ERROR_IF(x.shape().cols(), !=, 1, "x must be a column vector");
    CT_ERROR_IF(y.shape().cols(), !=, 1, "y must be a column vector");

    uint32_t rows = A.shape().rows();
    uint32_t cols = A.shape().cols();
    T a = alpha, b = beta;
#ifdef CUDA
    CUBLAS_CHECK(
        cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
    if (bi.size == 1) {
        invoke<T>(cublasSgemv, cublasDgemv, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols,
                  &a, A.dataDevice(), rows, x.dataDevice(), 1, &b, y.dataDevice(), 1);

    } else { // More than 2 axes, so broadcast.
        invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched,
                  Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, &a, A.dataDevice(), rows,
                  bi.strideA, x.dataDevice(), 1, bi.strideB, &b, y.dataDevice(), 1, bi.strideC,
                  bi.size);
    }

#else
    if (bi.size == 1) {
        y.eigenMap() = a * (A.eigenMap() * x.eigenMap()) + b * y.eigenMap();
    } else { // More than 2 axes, so broadcast.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {rows, cols}, i * bi.strideA).eigenMap();
            auto xi = Array<T>(x, {cols, 1}, i * bi.strideB).eigenMap();
            auto yi = Array<T>(y, {rows, 1}, i * bi.strideC).eigenMap();
            yi = a * (Ai * xi) + b * yi;
        }
    }
#endif
    return StreamID{stream};
}
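
/*
 * A minimal GEMV sketch (shapes are illustrative, namespace qualifiers elided):
 *
 *   Array<double> A({4, 3}), x({3, 1}), y({4, 1});
 *   // ... fill A and x ...
 *   GEMV<double>(1.0, A, x, 0.0, y); // y = A * x.
 *
 * With a leading batch axis, e.g. A({10, 4, 3}), x({10, 3, 1}), y({10, 4, 1}),
 * the same call dispatches the strided-batched kernel.
 */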

/**
 * Computes the matrix-matrix product: \f$ C = \alpha AB + \beta C \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta, const Array<T>& C,
              const StreamID& stream = DEF_CUBLAS_STREAM) {

    BatchInfo bi = Check<T>::isBroadcastable(A, B, C, "A", "B", "C");
    // A is m x k, B is k x n.
    uint32_t m = A.shape().rows();
    uint32_t k = A.shape().cols();
    uint32_t n = B.shape().cols();

    T a = alpha, b = beta;
#ifdef CUDA
    CUBLAS_CHECK(
        cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
    if (bi.size == 1) {
        invoke<T>(cublasSgemm, cublasDgemm, Manager::get()->cublasHandle(), CUBLAS_OP_N,
                  CUBLAS_OP_N, m, n, k, &a, A.dataDevice(), m, B.dataDevice(), k, &b,
                  C.dataDevice(), m);

    } else { // More than 2 axes, so broadcast.
        invoke<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched,
                  Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &a,
                  A.dataDevice(), m, bi.strideA, B.dataDevice(), k, bi.strideB, &b, C.dataDevice(),
                  m, bi.strideC, bi.size);
    }

#else
    if (bi.size == 1) {
        C.eigenMap() = a * (A.eigenMap() * B.eigenMap()) + b * C.eigenMap();
    } else { // More than 2 axes, so broadcast.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {m, k}, i * bi.strideA).eigenMap();
            auto Bi = Array<T>(B, {k, n}, i * bi.strideB).eigenMap();
            auto Ci = Array<T>(C, {m, n}, i * bi.strideC).eigenMap();
            Ci = a * (Ai * Bi) + b * Ci;
        }
    }
#endif
    return StreamID{stream};
}
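
/*
 * A minimal GEMM sketch: C = alpha*A*B + beta*C with A (m x k), B (k x n).
 * Shapes are illustrative, namespace qualifiers elided:
 *
 *   Array<float> A({2, 3}), B({3, 4}), C({2, 4});
 *   GEMM<float>(1.0f, A, B, 0.0f, C); // C = A * B.
 *
 * A batched call broadcasts automatically: A({5, 2, 3}) against a single
 * B({3, 4}) into C({5, 2, 4}) reuses B for every slice (strideB == 0).
 */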

/**
 * Computes the diagonal matrix multiplication: \f$ C = A\mathrm{diag}(X) \f$, or \f$ C =
 * \mathrm{diag}(X)A \f$ if left = true.
 */
template <typename T>
StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const bool left = false,
              const StreamID& stream = DEF_CUBLAS_STREAM) {
    CT_ERROR_IF(X.shape().cols(), !=, 1, "'X' must be a column vector.");
    if (left) {
        CT_ERROR_IF(A.shape().rows(), !=, X.shape().rows(),
                    "Rows of 'A' and length of 'X' need to match.");
    } else {
        CT_ERROR_IF(A.shape().cols(), !=, X.shape().rows(),
                    "Columns of 'A' and length of 'X' need to match.");
    }
    CT_ERROR_IF(A.shape().rows(), !=, C.shape().rows(),
                "Rows of 'A' and rows of 'C' need to match.");
    CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(),
                "Columns of 'A' and columns of 'C' need to match.");

#ifdef CUDA
    uint32_t m = C.shape().rows();
    uint32_t n = C.shape().cols();
    auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
    CUBLAS_CHECK(
        cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
    invoke<T>(cublasSdgmm, cublasDdgmm, Manager::get()->cublasHandle(), m, n, A.dataDevice(),
              A.shape().rows(), X.dataDevice(), 1, C.dataDevice(), m);
#else
    if (left) {
        C.eigenMap() = X.eigenMap().asDiagonal() * A.eigenMap();
    } else {
        C.eigenMap() = A.eigenMap() * X.eigenMap().asDiagonal();
    }
#endif
    return StreamID{stream};
}
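
/*
 * A minimal DGMM sketch: scale the columns (right) or rows (left) of A by X.
 * Shapes are illustrative, namespace qualifiers elided:
 *
 *   Array<float> A({3, 4}), X({4, 1}), C({3, 4});
 *   DGMM<float>(A, X, C);    // C = A * diag(X): scales column j by X[j].
 *
 * With left = true, X would instead need shape {3, 1}, and the call computes
 * C = diag(X) * A, scaling row i by X[i].
 */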

//////////////////////////////
// PLUArray Related Objects //
//////////////////////////////

///////////////////////////
// PartialPivLU Wrapper  //
///////////////////////////

// This class is just a workaround to use Eigen's internals directly.
template <typename T> class PartialPivLU;
namespace internal {
template <typename T> static Array<T> empty({1, 1});
template <typename T> static EigenMapMat<T> empty_map = empty<T>.eigenMap();
}; // namespace internal

template <typename T, ENABLE_IF(IS_FLOAT(T)) = true> class PLUArray;
// This is a wrapper class for Eigen's class so we have more controlled access to
// the underlying data.
template <typename T> class PartialPivLU : public Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>> {
  private:
    using Base = Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>>;
    template <typename U, ENABLE_IF(IS_FLOAT(U))> friend class PLUArray;

    EigenMapMat<T> mMapLU;
    EigenMapMat<int32_t> mMapPivots;

  public:
    PartialPivLU()
        : Base(internal::empty_map<T>), mMapLU(internal::empty_map<T>),
          mMapPivots(internal::empty_map<int32_t>){};

    void make(const Array<T>& lu, const Array<int32_t>& pivots) {

        new (&mMapLU) EigenMapMat<T>(lu.eigenMap());
        new (&mMapPivots) EigenMapMat<int32_t>(pivots.atLeast2D().eigenMap());

        new (&this->m_lu) decltype(Base::m_lu)(mMapLU.derived());
        new (&this->m_p) decltype(Base::m_p)(mMapPivots.derived());

        // new (&this->m_rowsTranspositions) decltype(Base::m_rowsTranspositions)(
        //     mMapPivots.derived());

        this->m_l1_norm = 0;
        this->m_det_p = 0;
        this->m_isInitialized = true;
    };
};

namespace internal {
// We only create one and copy-construct to avoid the re-initialization.
template <typename T> static PartialPivLU<T> BlankPPLU = PartialPivLU<T>();
}; // namespace internal

/**
 * Class for storing the PLU decomposition of an Array. This is restricted to floating point types.
 */
template <typename T, ENABLE_IF(IS_FLOAT(T))> class PLUArray {
  private:
    Array<T> mLU;
    Array<int32_t> mPivots;
    PartialPivLU<T> mPPLU = internal::BlankPPLU<T>;

  public:
    PLUArray() = delete;

    /**
     * Constructor for a PLUArray given the matrix dimension.
     */
    PLUArray(const uint32_t n) : mLU({n, n}), mPivots({n}) { mPPLU.make(mLU, mPivots); };

    /**
     * Constructor for a PLUArray given an existing array.
     */
    PLUArray(const Array<T>& arr)
        : mLU((arr.isView()) ? arr.view() : arr), mPivots({arr.shape().rows()}) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    /**
     * Constructor for a PLUArray given an existing location in memory for both the matrix and
     * the pivots.
     */
    PLUArray(const Array<T>& arr, const Array<int32_t>& pivots)
        : mLU(arr.view()), mPivots(pivots.view()) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    uint32_t rank() const { return mLU.shape().rows(); }; /**< Gets the rank of the LU matrix. */
    Array<T> LU() const { return mLU.view(); };           /**< Gets the LU matrix. */
    Array<int32_t> pivots() const { return mPivots.view(); }; /**< Gets the pivots. */

    /**
     * Computes the in-place LU factorization for this array on the CPU.
     */
    void computeLU() {
        mPPLU.compute();
        mPPLU.mMapPivots = mPPLU.permutationP().indices();
    };

    /**
     * Solves the system \f$ LUx = b \f$ and returns \f$x\f$.
     */
    Array<T> solve(const Array<T>& b) {
        Array<T> x(b.shape());
        x.eigenMap() = mPPLU.solve(b.eigenMap());
        return x;
    };
};
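
/*
 * A minimal PLUArray sketch (CPU path; values are illustrative):
 *
 *   Array<double> A({3, 3});
 *   // ... fill A with an invertible matrix ...
 *   PLUArray<double> plu(A);        // Wraps A; factorization happens below.
 *   plu.computeLU();                // In-place LU with partial pivoting.
 *   Array<double> b({3, 1});
 *   // ... fill b ...
 *   Array<double> x = plu.solve(b); // Solves LUx = b.
 */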

/**
 * This is a batch version of PLUArray, to enable usage of the cuBLAS API. This is restricted to
 * floating point types.
 */
template <typename T, ENABLE_IF(IS_FLOAT(T)) = true>
class PLUBatch : public Batch<T> {
  private:
    Array<int32_t> mPivotsBatch;
    Array<int32_t> mInfoLU;
    int32_t mInfoSolve;

    bool mInitialized = false;

  public:
    /**
     * Constructor of a PLUBatch from a given batch size.
     */
    PLUBatch(const uint32_t size) : Batch<T>(size), mInfoLU({size}){};

    /**
     * Constructor of a PLUBatch from a multi-dimensional array, batched across upper dimensions.
     */
    PLUBatch(const Array<T>& arr) : Batch<T>(arr) {
        Check<T>::isSquare(arr, "LU Array");

        mPivotsBatch = Array<int32_t>({this->mBatchSize * this->mShape.rows()});
        mInfoLU = Array<int32_t>({this->mBatchSize});
    };

    /**
     * Indexing operator which returns the PLUArray in the PLUBatch at the given index.
     */
    PLUArray<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, this->mBatchSize, "Index exceeds batch size");
        Array<T> lu(this->mBatch[index], {this->mShape.rows(), this->mShape.cols()});
        Array<int32_t> pivots(mPivotsBatch.data() + index * this->mShape.rows(),
                              {this->mShape.rows()});
        return PLUArray<T>(lu, pivots);
    };

    /**
     * Computes the in-place PLU decomposition of a batch of arrays.
     */
    StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) {
#ifdef CUDA
        uint32_t n = this->mShape.rows();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
        invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, Manager::get()->cublasHandle(), n,
                  this->mBatch.dataDevice(), n, mPivotsBatch.dataDevice(), mInfoLU.dataDevice(),
                  this->mBatchSize);

#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            (*this)[i].computeLU();
        }
#endif
        mInitialized = true;
        return stream;
    };

    /**
     * Solves the batched system \f$LUx = b\f$ in place. The solution \f$x\f$ is written back
     * into \f$b\f$.
     */
    StreamID solve(const Batch<T>& b, const StreamID& stream = DEF_CUBLAS_STREAM) {
        CT_ERROR(not mInitialized,
                 "Cannot solve system if PLUBatch has not yet computed its LU decomposition");
        CT_ERROR_IF(b.size(), !=, this->mBatchSize,
                    "Upper dimensions of b do not match batch size");
        CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(),
                    "The length of each column of b must match the matrix rank");

#ifdef CUDA
        uint32_t n = b.shape().rows();
        uint32_t nrhs = b.shape().cols();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
        invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, Manager::get()->cublasHandle(),
                  CUBLAS_OP_N, n, nrhs, this->mBatch.dataDevice(), n, mPivotsBatch.dataDevice(),
                  b.batch().dataDevice(), n, &mInfoSolve, this->mBatchSize);

#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            b[i] = (*this)[i].solve(b[i]);
        }
#endif
        return stream;
    };

    /**
     * Gets the pivots data from the device to the host. Does nothing for CPU.
     */
    StreamID getPivots(const StreamID& stream = DEF_MEM_STREAM) const {
        mPivotsBatch.updateHost(stream);
        return stream;
    };

    /**
     * Gets the info array for the LU decomposition from the device to the host. Does not
     * return useful information for CPU.
     */
    Array<int32_t> getLUInfo() const {
        mInfoLU.updateHost().wait();
        return mInfoLU;
    };

    /**
     * Checks whether the solve operation succeeded. Does not return useful information for CPU.
     */
    bool validSolve() const { return mInfoSolve == 0; }
};
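
/*
 * A minimal PLUBatch sketch: factorize and solve many small systems at once
 * (shapes are illustrative, namespace qualifiers elided):
 *
 *   Array<float> As({100, 8, 8});   // 100 systems of size 8x8.
 *   Array<float> bs({100, 8, 1});   // One right-hand side per system.
 *   PLUBatch<float> plu(As);
 *   plu.computeLU();
 *   Batch<float> rhs(bs);
 *   plu.solve(rhs);                 // Solutions overwrite bs.
 */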

// /**
//  * Gets the inverse of each A[i], using an already PLU-factorized A[i].
//  * Only available if compiling with CUDA.
//  */
// template <typename T>
// void inverseBatch(const Array<T*>& batchA, const Array<T*>& batchC, const Array<int>& pivots,
//                   const Array<int>& info, const Shape shapeA, const Shape shapeC,
//                   const uint stream = 0) {
// #ifdef CUDA
//     CT_ERROR_IF(shapeA.rows(), !=, shapeA.cols(),
//                 "'A' needs to be square; rows and columns need to match.");
//     CT_ERROR_IF(shapeA.rows(), !=, shapeC.cols(), "'A' needs to be the same shape as 'C'.");
//     CT_ERROR_IF(shapeA.rows(), !=, shapeC.rows(), "'A' needs to be the same shape as 'C'.");
//
//     CT_ERROR_IF(shapeA.rows(), !=, pivots.shape().rows(),
//                 "Rows/columns of 'A' and rows of pivots need to match.");
//     CT_ERROR_IF(batchA.shape().rows(), !=, pivots.shape().cols(),
//                 "Batch size and columns of pivots need to match.");
//     CT_ERROR_IF(info.shape().cols(), !=, 1, "Info needs to be a column vector.");
//     CT_ERROR_IF(batchA.shape().rows(), !=, info.shape().rows(),
//                 "Batch size and length of info need to match.");
//     CT_ERROR_IF(batchA.shape().rows(), !=, batchC.shape().rows(),
//                 "Batches 'A[i]' and 'C[i]' need to match.");
//
//     std::string s = "cublas" + std::to_string(stream);
//     CUBLAS_CHECK(
//         cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(s)));
//     invoke<T>(cublasSgetriBatched, cublasDgetriBatched, Manager::get()->cublasHandle(),
//               shapeA.rows(), batchA.dataDevice(), shapeA.rows(), pivots.dataDevice(),
//               batchC.dataDevice(), shapeC.rows(), info.dataDevice(),
//               batchA.shape().rows());
// #else
//     CT_ERROR(true, "inverseBatch is not callable without CUDA.");
// #endif
// }

}; // namespace BLAS
}; // namespace CudaTools

#endif