#ifndef CUDATOOLS_BLAS_H
#define CUDATOOLS_BLAS_H

#include "Array.h"
#include "Core.h"
#include "Macros.h"
#include "Types.h"

using namespace CudaTools::Types;

namespace CudaTools {
namespace BLAS {

struct BatchInfo {
    uint32_t strideA, strideB, strideC;
    uint32_t size;
};

struct Check {
    template <typename T>
    static void isAtLeast2D(const Array<T>& arr, const std::string& name = "Array") {
        CT_ERROR_IF(arr.shape().axes(), <, 2, (name + " needs to be at least 2D").c_str());
    };

    template <typename T>
    static void isSquare(const Array<T>& arr, const std::string& name = "Array") {
        isAtLeast2D(arr, name);
        CT_ERROR_IF(arr.shape().rows(), !=, arr.shape().cols(), (name + " is not square").c_str());
    };

    template <typename T>
    static void isValidMatmul(const Array<T>& A, const Array<T>& B, const Array<T>& C,
                              const std::string& nameA = "A", const std::string& nameB = "B",
                              const std::string& nameC = "C") {
        isAtLeast2D(A, nameA);
        isAtLeast2D(B, nameB);
        isAtLeast2D(C, nameC);
        CT_ERROR_IF(A.shape().cols(), !=, B.shape().rows(),
                    (nameA + nameB + " is not a valid matrix multiplication").c_str());

        Shape ABshape({A.shape().rows(), B.shape().cols()});
        Shape Cshape({C.shape().rows(), C.shape().cols()});
        CT_ERROR_IF(ABshape, !=, Cshape,
                    ("The shape of " + nameA + nameB + " does not match the shape of " + nameC)
                        .c_str());
    };

    template <typename T> static uint32_t getUpperItems(const Array<T>& arr) {
        uint32_t upperItems = 1;
        for (uint32_t iAxis = 0; iAxis < arr.shape().axes() - 2; ++iAxis) {
            upperItems *= arr.shape().dim(iAxis);
        }
        return upperItems;
    };

    template <typename T>
    static void matchUpperShape(const Array<T>& A, const Array<T>& B,
                                const std::string& nameA = "A", const std::string& nameB = "B") {
        CT_ERROR_IF(A.shape().axes(), !=, B.shape().axes(),
                    (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        for (uint32_t iAxis = 0; iAxis < A.shape().axes() - 2; ++iAxis) {
            uint32_t Adim = A.shape().dim(iAxis);
            uint32_t Bdim = B.shape().dim(iAxis);
            CT_ERROR_IF(
                Adim, !=, Bdim,
                (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        }
    };

    template <typename T>
    static BatchInfo isBroadcastable(const Array<T>& A, const Array<T>& B, const Array<T>& C,
                                     const std::string& nameA = "A",
                                     const std::string& nameB = "B",
                                     const std::string& nameC = "C") {
        isValidMatmul(A, B, C, nameA, nameB, nameC);
        uint32_t itemsA = getUpperItems(A);
        uint32_t itemsB = getUpperItems(B);
        uint32_t itemsC = getUpperItems(C);

        uint32_t Asize = A.shape().rows() * A.shape().cols();
        uint32_t Bsize = B.shape().rows() * B.shape().cols();
        uint32_t Csize = C.shape().rows() * C.shape().cols();

        if (itemsA == itemsB) {
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, B, nameA, nameB);
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, Bsize, Csize, itemsC};
        } else if (itemsA > itemsB) {
            CT_ERROR_IF(
                itemsB, !=, 1,
                ("Cannot broadcast operation to " + nameB + " with non-matching " + nameA)
                    .c_str());
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, 0, Csize, itemsC};
        } else {
            CT_ERROR_IF(
                itemsA, !=, 1,
                ("Cannot broadcast operation to " + nameA + " with non-matching " + nameB)
                    .c_str());
            CT_ERROR_IF(itemsB, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(B, C, nameB, nameC);
            return BatchInfo{0, Bsize, Csize, itemsC};
        }
    };
};
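/*
 * A sketch of the broadcasting rules that Check::isBroadcastable enforces. This example is
 * illustrative only and not part of the original header; the shapes are hypothetical.
 *
 * \code
 * Array<double> A({10, 4, 4}); // Ten 4 x 4 matrices stacked along the upper dimension.
 * Array<double> x({10, 4, 1}); // Ten matching column vectors: strideB = 4 per item.
 * Array<double> y({10, 4, 1}); // The output must always match the larger batch (ten items).
 *
 * Array<double> xs({4, 1});    // A single vector instead: it is broadcast against all
 *                              // ten matrices, so its stride in BatchInfo becomes 0.
 * \endcode
 */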
/**
 * Represents a Batch of Arrays with the same shape. Mainly used for cuBLAS functions.
 */
template <typename T> class Batch {
  protected:
    Array<T*> mBatch;
    Shape mShape;
    uint32_t mCount = 0;
    uint32_t mBatchSize;

  public:
    Batch() = delete;

    /**
     * Constructs a batch from a given size.
     */
    Batch(const uint32_t size) : mBatchSize(size) {
        mBatch = Array<T*>({size}); // Allocate the pointer array up front so add() can fill it.
    };

    /**
     * Constructs a batch from a non-view Array.
     */
    Batch(const Array<T>& arr) {
        CT_ERROR(arr.isView(), "Array cannot be a view");
        mShape = Shape({arr.shape().rows(), arr.shape().cols()});
        mBatchSize = mCount = Check::getUpperItems(arr);

        mBatch = Array<T*>({mBatchSize});
        Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
        for (uint32_t i = 0; i < mBatchSize; ++i) {
#ifdef CUDA
            mBatch[i] = batch[i].dataDevice();
#else
            mBatch[i] = batch[i].data();
#endif
        }
        mBatch.updateDevice().wait();
    };

    /**
     * Adds a matrix to the batch. Array must be a view.
     */
    void add(const Array<T>& arr) {
        CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
        CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
#ifdef CUDA
        mBatch[mCount] = arr.dataDevice();
#else
        mBatch[mCount] = arr.data();
#endif
        if (mCount == 0) {
            mShape = arr.shape();
        } else {
            CT_ERROR_IF(arr.shape(), !=, mShape, "Cannot add matrix of different shape to batch");
        }
        ++mCount;
        if (mCount == mBatchSize) {
            mBatch.updateDevice().wait();
        }
    };

    /**
     * Indexing operator which returns a view of the Array in the Batch at the given index.
     */
    Array<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, mBatchSize, "Index exceeds batch size");
        return Array<T>(mBatch[index], {mShape.rows(), mShape.cols()});
    };

    /**
     * Returns the batch Array of pointers.
     */
    Array<T*> batch() const { return mBatch.view(); };

    Shape shape() const { return mShape; }   /**< Gets the shape of the matrices in the batch. */
    uint32_t size() const { return mBatchSize; }         /**< Gets the batch size. */
    bool full() const { return mBatchSize == mCount; };  /**< Gets if the batch is full. */
};
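/*
 * Usage sketch for Batch, illustrative and not part of the original header. It assumes that
 * Array supports the Shape construction and view indexing used elsewhere in this file.
 *
 * \code
 * Array<float> arrs({16, 8, 8});  // Sixteen 8 x 8 matrices.
 * Batch<float> batch(arrs);       // Batches across the upper dimension in one step.
 *
 * Batch<float> manual(2);         // Alternatively, fill a batch by hand from views.
 * manual.add(arrs[0]);            // Pointers are staged one by one, and the pointer
 * manual.add(arrs[1]);            // array is uploaded once the batch becomes full.
 * Array<float> first = manual[0]; // A view of the first matrix in the batch.
 * \endcode
 */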
////////////////
// cuBLAS API //
////////////////

template <typename T> struct CudaComplexConversion_S { typedef T type; };
#ifdef CUDACC
template <> struct CudaComplexConversion_S<complex64> { typedef cuComplex type; };
template <> struct CudaComplexConversion_S<complex128> { typedef cuDoubleComplex type; };
#endif

template <typename T> using CudaComplexConversion = typename CudaComplexConversion_S<T>::type;

template <typename T> struct CublasTypeLetter_S { char letter; };
template <> struct CublasTypeLetter_S<float> { static constexpr char letter = 'S'; };
template <> struct CublasTypeLetter_S<double> { static constexpr char letter = 'D'; };
template <> struct CublasTypeLetter_S<complex64> { static constexpr char letter = 'C'; };
template <> struct CublasTypeLetter_S<complex128> { static constexpr char letter = 'Z'; };
#ifdef CUDACC
template <> struct CublasTypeLetter_S<real16> { static constexpr char letter = 'H'; };
#endif

template <typename T> char CublasTypeLetter = CublasTypeLetter_S<T>::letter;

// Shorthands to reduce clutter.
#define CAST(var) reinterpret_cast<CudaComplexConversion<T>*>(var)
#define DCAST(var) reinterpret_cast<CudaComplexConversion<T>**>(var)
#define cublas(T, func) cublas##CublasTypeLetter<T>##func

template <typename T, typename F1, typename F2, typename F3, typename F4, typename... Args>
constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
    if constexpr (std::is_same<T, float>::value) {
        CUBLAS_CHECK(f1(args...));
    } else if constexpr (std::is_same<T, double>::value) {
        CUBLAS_CHECK(f2(args...));
    } else if constexpr (std::is_same<T, complex64>::value) {
        CUBLAS_CHECK(f3(args...));
    } else if constexpr (std::is_same<T, complex128>::value) {
        CUBLAS_CHECK(f4(args...));
    } else {
        CT_ERROR(true, "This BLAS function is not callable with that type");
    }
}

// If someone can think of a better solution, please tell me.
template <typename T, typename F1, typename F2, typename F3, typename F4, typename F5,
          typename... Args>
constexpr void invoke5(F1 f1, F2 f2, F3 f3, F4 f4, F5 f5, Args&&... args) {
    if constexpr (std::is_same<T, float>::value) {
        CUBLAS_CHECK(f1(args...));
    } else if constexpr (std::is_same<T, double>::value) {
        CUBLAS_CHECK(f2(args...));
    } else if constexpr (std::is_same<T, complex64>::value) {
        CUBLAS_CHECK(f3(args...));
    } else if constexpr (std::is_same<T, complex128>::value) {
        CUBLAS_CHECK(f4(args...));
    } else if constexpr (std::is_same<T, real16>::value) { // Half-precision alias from Types.h.
        CUBLAS_CHECK(f5(args...));
    } else {
        CT_ERROR(true, "This BLAS function is not callable with that type");
    }
}

/**
 * Computes the matrix-vector product: \f$ y = \alpha Ax + \beta y \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
              const Array<T>& y, const StreamID& stream = DEF_CUBLAS_STREAM) {
    BatchInfo bi = Check::isBroadcastable(A, x, y, "A", "x", "y");
    CT_ERROR_IF(x.shape().cols(), !=, 1, "x must be a column vector");
    CT_ERROR_IF(y.shape().cols(), !=, 1, "y must be a column vector");

    uint32_t rows = A.shape().rows();
    uint32_t cols = A.shape().cols();
    T a = alpha, b = beta;
#ifdef CUDA
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
    if (bi.size == 1) {
        invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
                  Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
                  CAST(A.dataDevice()), rows, CAST(x.dataDevice()), 1, CAST(&b),
                  CAST(y.dataDevice()), 1);
    } else { // More than two axes, so broadcast across the upper dimensions.
        invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched, cublasCgemvStridedBatched,
                  cublasZgemvStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows,
                  cols, CAST(&a), CAST(A.dataDevice()), rows, bi.strideA, CAST(x.dataDevice()), 1,
                  bi.strideB, CAST(&b), CAST(y.dataDevice()), 1, bi.strideC, bi.size);
    }
#else
    if (bi.size == 1) {
        y.eigenMap() = a * (A.eigenMap() * x.eigenMap()) + b * y.eigenMap();
    } else { // More than two axes, so broadcast across the upper dimensions.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {rows, cols}, i * bi.strideA).eigenMap();
            auto xi = Array<T>(x, {cols, 1}, i * bi.strideB).eigenMap();
            auto yi = Array<T>(y, {rows, 1}, i * bi.strideC).eigenMap();
            yi = a * (Ai * xi) + b * yi;
        }
    }
#endif
    return StreamID{stream};
}
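/*
 * Usage sketch for GEMV, illustrative and not part of the original header. Synchronization is
 * left to the caller; the returned StreamID identifies the stream that was used.
 *
 * \code
 * Array<double> A({3, 3}), x({3, 1}), y({3, 1});
 * // ... fill A and x, and copy them to the device when compiled with CUDA ...
 * GEMV<double>(1.0, A, x, 0.0, y); // y = Ax.
 *
 * // Batched version: stack matrices along the upper dimension. Here the single x is
 * // broadcast against all five matrices, as checked by Check::isBroadcastable.
 * Array<double> As({5, 3, 3}), ys({5, 3, 1});
 * GEMV<double>(1.0, As, x, 0.0, ys);
 * \endcode
 */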
/**
 * Computes the matrix-matrix product: \f$ C = \alpha AB + \beta C \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
              const Array<T>& C, const StreamID& stream = DEF_CUBLAS_STREAM) {
    BatchInfo bi = Check::isBroadcastable(A, B, C, "A", "B", "C");

    // A is m x k, B is k x n.
    uint32_t m = A.shape().rows();
    uint32_t k = A.shape().cols();
    uint32_t n = B.shape().cols();
    T a = alpha, b = beta;
#ifdef CUDA
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
    if (bi.size == 1) {
        invoke5<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm, cublasHgemm,
                   Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
                   CAST(A.dataDevice()), m, CAST(B.dataDevice()), k, CAST(&b),
                   CAST(C.dataDevice()), m);
    } else { // More than two axes, so broadcast across the upper dimensions.
        invoke5<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched,
                   cublasCgemmStridedBatched, cublasZgemmStridedBatched,
                   cublasHgemmStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N,
                   CUBLAS_OP_N, m, n, k, CAST(&a), CAST(A.dataDevice()), m, bi.strideA,
                   CAST(B.dataDevice()), k, bi.strideB, CAST(&b), CAST(C.dataDevice()), m,
                   bi.strideC, bi.size);
    }
#else
    if (bi.size == 1) {
        C.eigenMap() = a * (A.eigenMap() * B.eigenMap()) + b * C.eigenMap();
    } else { // More than two axes, so broadcast across the upper dimensions.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {m, k}, i * bi.strideA).eigenMap();
            auto Bi = Array<T>(B, {k, n}, i * bi.strideB).eigenMap();
            auto Ci = Array<T>(C, {m, n}, i * bi.strideC).eigenMap();
            Ci = a * (Ai * Bi) + b * Ci;
        }
    }
#endif
    return StreamID{stream};
}

/**
 * Computes the diagonal matrix multiplication: \f$ C = A\,\mathrm{diag}(X) \f$, or \f$ C =
 * \mathrm{diag}(X)A \f$ if left = true.
 */
template <typename T>
StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const bool left = false,
              const StreamID& stream = DEF_CUBLAS_STREAM) {
    CT_ERROR_IF(X.shape().cols(), !=, 1, "'X' must be a column vector.");
    if (left) {
        CT_ERROR_IF(A.shape().rows(), !=, X.shape().rows(),
                    "Rows of 'A' and length of 'X' need to match.");
    } else {
        CT_ERROR_IF(A.shape().cols(), !=, X.shape().rows(),
                    "Columns of 'A' and length of 'X' need to match.");
    }
    CT_ERROR_IF(A.shape().rows(), !=, C.shape().rows(),
                "Rows of 'A' and rows of 'C' need to match.");
    CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(),
                "Columns of 'A' and columns of 'C' need to match.");
#ifdef CUDA
    uint32_t m = C.shape().rows();
    uint32_t n = C.shape().cols();
    auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
    invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(),
              mode, m, n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
              CAST(C.dataDevice()), m);
#else
    if (left) {
        C.eigenMap() = X.eigenMap().asDiagonal() * A.eigenMap();
    } else {
        C.eigenMap() = A.eigenMap() * X.eigenMap().asDiagonal();
    }
#endif
    return StreamID{stream};
}
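/*
 * Usage sketch for GEMM and DGMM, illustrative and not part of the original header.
 *
 * \code
 * Array<float> A({2, 4, 3}), B({2, 3, 5}), C({2, 4, 5});
 * GEMM<float>(1.0f, A, B, 0.0f, C); // C_i = A_i B_i for both items in the batch.
 *
 * Array<float> M({4, 5}), X({5, 1}), D({4, 5});
 * DGMM<float>(M, X, D);             // D = M diag(X). With left = true, X must instead
 *                                   // have one entry per row of M.
 * \endcode
 */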
//////////////////////////////
// PLUArray Related Objects //
//////////////////////////////

///////////////////////////
// PartialPivLU Wrapper  //
///////////////////////////

// This class is just a workaround to use Eigen's internals directly.
template <typename T> class PartialPivLU;

namespace internal {
template <typename T> static Array<T> empty({1, 1});
template <typename T> static EigenMapMat<T> empty_map = empty<T>.eigenMap();
}; // namespace internal

template <typename T,
          std::enable_if_t<std::is_floating_point_v<T> or is_complex<T>, bool> = true>
class PLUArray;

// This is a wrapper class for Eigen's class so we have more controlled access to
// the underlying data.
template <typename T>
class PartialPivLU : public Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>> {
  private:
    using Base = Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>>;
    template <typename U, std::enable_if_t<std::is_floating_point_v<U> or is_complex<U>, bool>>
    friend class PLUArray;

    EigenMapMat<T> mMapLU;
    EigenMapMat<int32_t> mMapPivots;

  public:
    PartialPivLU()
        : Base(internal::empty_map<T>), mMapLU(internal::empty_map<T>),
          mMapPivots(internal::empty_map<int32_t>){};

    void make(const Array<T>& lu, const Array<int32_t>& pivots) {
        new (&mMapLU) EigenMapMat<T>(lu.eigenMap());
        new (&mMapPivots) EigenMapMat<int32_t>(pivots.atLeast2D().eigenMap());

        new (&this->m_lu) decltype(Base::m_lu)(mMapLU.derived());
        new (&this->m_p) decltype(Base::m_p)(mMapPivots.derived());
        // new (&this->m_rowsTranspositions) decltype(Base::m_rowsTranspositions)(
        //     mMapPivots.derived());

        this->m_l1_norm = 0;
        this->m_det_p = 0;
        this->m_isInitialized = true;
    };
};

namespace internal {
// We only create one and copy-construct to avoid the re-initialization.
template <typename T> static PartialPivLU<T> BlankPPLU = PartialPivLU<T>();
}; // namespace internal

/**
 * Class for storing the PLU decomposition of an Array. This is restricted to floating point
 * types.
 */
template <typename T, std::enable_if_t<std::is_floating_point_v<T> or is_complex<T>, bool>>
class PLUArray {
  private:
    Array<T> mLU;
    Array<int32_t> mPivots;
    PartialPivLU<T> mPPLU = internal::BlankPPLU<T>;

  public:
    PLUArray() = delete;

    /**
     * Constructor for a PLUArray given the matrix dimension.
     */
    PLUArray(const uint32_t n) : mLU({n, n}), mPivots({n}) { mPPLU.make(mLU, mPivots); };

    /**
     * Constructor for a PLUArray given an existing array.
     */
    PLUArray(const Array<T>& arr)
        : mLU((arr.isView()) ? arr.view() : arr), mPivots({arr.shape().rows()}) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    /**
     * Constructor for a PLUArray given an existing location in memory for both the matrix and
     * the pivots.
     */
    PLUArray(const Array<T>& arr, const Array<int32_t>& pivots)
        : mLU(arr.view()), mPivots(pivots.view()) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    uint32_t rank() { return mLU.shape().rows(); };  /**< Gets the rank of the LU matrix. */
    Array<T> LU() const { return mLU.view(); };      /**< Gets the LU matrix. */
    Array<int32_t> pivots() const { return mPivots.view(); }; /**< Gets the pivots. */

    /**
     * Computes the inplace LU factorization for this array on CPU.
     */
    void computeLU() {
        mPPLU.compute();
        mPPLU.mMapPivots = mPPLU.permutationP().indices();
    };

    /**
     * Solves the system \f$ LUx = b \f$ and returns \f$ x \f$.
     */
    Array<T> solve(const Array<T>& b) {
        Array<T> x(b.shape());
        x.eigenMap() = mPPLU.solve(b.eigenMap());
        return x;
    };
};

/**
 * This is a batch version of PLUArray, to enable usage of the cuBLAS API. This is restricted to
 * floating point types.
 */
template <typename T,
          std::enable_if_t<std::is_floating_point_v<T> or is_complex<T>, bool> = true>
class PLUBatch : public Batch<T> {
  private:
    Array<int32_t> mPivotsBatch;
    Array<int32_t> mInfoLU;
    int32_t mInfoSolve;

    bool mInitialized = false;

  public:
    /**
     * Constructor of a PLUBatch from a given batch size.
     */
    PLUBatch(const uint32_t size) : Batch<T>(size), mInfoLU({size}){};

    /**
     * Constructor of a PLUBatch from a multi-dimensional array, batched across upper dimensions.
     */
    PLUBatch(const Array<T>& arr) : Batch<T>(arr) {
        Check::isSquare(arr, "LU Array");
        mPivotsBatch = Array<int32_t>({this->mBatchSize * this->mShape.rows()});
        mInfoLU = Array<int32_t>({this->mBatchSize});
    };

    /**
     * Indexing operator which returns the PLUArray in the PLUBatch at the given index.
     */
    PLUArray<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, this->mBatchSize, "Index exceeds batch size");
        Array<T> lu(this->mBatch[index], {this->mShape.rows(), this->mShape.cols()});
        Array<int32_t> pivots(mPivotsBatch.data() + index * this->mShape.rows(),
                              {this->mShape.rows()});
        return PLUArray<T>(lu, pivots);
    };

    /**
     * Computes the inplace PLU decomposition of batch of arrays.
     */
    StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) {
#ifdef CUDA
        uint32_t n = this->mShape.rows();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
        invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
                  cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
                  DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
                  mInfoLU.dataDevice(), this->mBatchSize);
#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            (*this)[i].computeLU();
        }
#endif
        mInitialized = true;
        return stream;
    };

    /**
     * Solves the batched system \f$ LUx = b \f$ inplace. The solution \f$ x \f$ is written back
     * into \f$ b \f$.
     */
    StreamID solve(const Batch<T>& b, const StreamID& stream = DEF_CUBLAS_STREAM) {
        CT_ERROR(not mInitialized,
                 "Cannot solve system if PLUBatch has not yet computed its LU decomposition");
        CT_ERROR_IF(b.size(), !=, this->mBatchSize,
                    "Upper dimensions of b do not match batch size");
        CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(),
                    "The length of each column of b must match the matrix rank");
#ifdef CUDA
        uint32_t n = b.shape().rows();
        uint32_t nrhs = b.shape().cols();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
        invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
                  cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
                  DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
                  DCAST(b.batch().dataDevice()), n, &mInfoSolve, this->mBatchSize);
#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            b[i] = (*this)[i].solve(b[i]);
        }
#endif
        return stream;
    };

    /**
     * Gets the pivots data from the device to the host. Does nothing for CPU.
     */
    StreamID getPivots(const StreamID& stream = DEF_MEM_STREAM) const {
        mPivotsBatch.updateHost(stream);
        return stream;
    };

    /**
     * Gets the info array for the LU decomposition from the device to the host. Does not
     * return useful information for CPU.
     */
    Array<int32_t> getLUInfo() const {
        mInfoLU.updateHost().wait();
        return mInfoLU;
    };

    /**
     * Checks the validity of the solve operation. Does not return useful information for CPU.
     */
    bool validSolve() const { return mInfoSolve == 0; }
};
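/*
 * Usage sketch for PLUArray and PLUBatch, illustrative and not part of the original header.
 * Under CUDA it assumes the matrix and right-hand-side data are already on the device.
 *
 * \code
 * Array<double> mats({8, 4, 4}); // Eight 4 x 4 systems.
 * Array<double> rhs({8, 4, 1});  // One right-hand side per system.
 *
 * PLUBatch<double> plu(mats);    // Views the matrices; the factorization is inplace.
 * plu.computeLU();               // Batched getrf on GPU, Eigen PartialPivLU per matrix on CPU.
 *
 * Batch<double> b(rhs);
 * plu.solve(b);                  // Solutions overwrite the contents of rhs.
 * plu.getLUInfo();               // Optional: per-matrix factorization status from the device.
 * \endcode
 */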
}; // namespace BLAS
}; // namespace CudaTools

#endif // CUDATOOLS_BLAS_H