@@ -2,16 +2,13 @@
 #define CUDATOOLS_BLAS_H
 
 #include "Array.h"
 #include "Complex.h"
 #include "Core.h"
 #include "Macros.h"
 #include "Types.h"
 
 #ifdef CUDACC
 #include <cuComplex.h>
 #endif
 
 using namespace CudaTools::Types;
 
 namespace CudaTools {
 
 namespace BLAS {
 
 struct BatchInfo {
@@ -19,17 +16,20 @@ struct BatchInfo {
     uint32_t size;
 };
 
-template <typename T> struct Check {
+struct Check {
+    template <typename T>
     static void isAtLeast2D(const Array<T>& arr, const std::string& name = "Array") {
         CT_ERROR_IF(arr.shape().axes(), <, 2, (name + " needs to be at least 2D").c_str());
     };
 
+    template <typename T>
     static void isSquare(const Array<T>& arr, const std::string& name = "Array") {
         isAtLeast2D(arr, name);
         CT_ERROR_IF(arr.shape().rows(), !=, arr.shape().cols(), (name + " is not square").c_str())
     };
 
-    static void isValidMatmul(const Array<T>& A, const Array<T>& B, const Array<T>& C,
+    template <typename T, typename U, typename V>
+    static void isValidMatmul(const Array<T>& A, const Array<U>& B, const Array<V>& C,
                               const std::string& nameA = "A", const std::string& nameB = "B",
                               const std::string nameC = "C") {
         isAtLeast2D(A, nameA);
@@ -46,7 +46,7 @@ template <typename T> struct Check {
         ("The shape of " + nameA + nameB + " does not match the shape of " + nameC).c_str());
     };
 
-    static uint32_t getUpperItems(const Array<T>& arr) {
+    template <typename T> static uint32_t getUpperItems(const Array<T>& arr) {
         uint32_t upperItems = 1;
         for (uint32_t iAxis = 0; iAxis < arr.shape().axes() - 2; ++iAxis) {
             upperItems *= arr.shape().dim(iAxis);
@@ -54,7 +54,8 @@ template <typename T> struct Check {
         return upperItems;
     };
 
-    static void matchUpperShape(const Array<T>& A, const Array<T>& B,
+    template <typename T, typename U>
+    static void matchUpperShape(const Array<T>& A, const Array<U>& B,
                                 const std::string& nameA = "A", const std::string& nameB = "B") {
         CT_ERROR_IF(A.shape().axes(), !=, B.shape().axes(),
                     (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
@@ -67,7 +68,8 @@ template <typename T> struct Check {
         }
     };
 
-    static BatchInfo isBroadcastable(const Array<T>& A, const Array<T>& B, const Array<T>& C,
+    template <typename T, typename U, typename V>
+    static BatchInfo isBroadcastable(const Array<T>& A, const Array<U>& B, const Array<V>& C,
                                      const std::string& nameA = "A", const std::string& nameB = "B",
                                      const std::string nameC = "C") {
         isValidMatmul(A, B, C, nameA, nameB, nameC);
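
// Illustrative sketch (not part of the patch): with the template parameters moved onto the
// member functions, one Check can validate operands whose element types differ, which the
// old class template Check<T> could not express in a single call.
inline void exampleMixedTypeCheck() {
    Array<real32> A({4, 8}), B({8, 2}); // the shape-list constructor is assumed from the usage above
    Array<real64> C({4, 2});
    Check::isValidMatmul(A, B, C, "A", "B", "C"); // deduces T = real32, U = real32, V = real64
}
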
@@ -130,7 +132,7 @@ template <typename T> class Batch {
     Batch(const Array<T>& arr) {
         CT_ERROR(arr.isView(), "Array cannot be a view");
         mShape = Shape({arr.shape().rows(), arr.shape().cols()});
-        mBatchSize = mCount = Check<T>::getUpperItems(arr);
+        mBatchSize = mCount = Check::getUpperItems(arr);
 
         mBatch = Array<T*>({mBatchSize});
@@ -159,7 +161,7 @@ template <typename T> class Batch {
 #endif
         if (mCount == 0) {
             mShape = arr.shape();
-            mBatchSize = mCount = Check<T>::getUpperItems(arr);
+            mBatchSize = mCount = Check::getUpperItems(arr);
         } else {
             CT_ERROR_IF(arr.shape(), !=, mShape, "Cannot add matrix of different shape to batch");
         }
@@ -195,15 +197,30 @@ template <typename T> struct CudaComplexConversion_S { typedef T type; };
 #ifdef CUDACC
 template <> struct CudaComplexConversion_S<complex64> { typedef cuComplex type; };
 template <> struct CudaComplexConversion_S<complex128> { typedef cuDoubleComplex type; };
 #else
 
 #endif
 
 template <typename T> using CudaComplexConversion = typename CudaComplexConversion_S<T>::type;
 
+template <typename T> struct CublasTypeLetter_S { char letter; };
+template <> struct CublasTypeLetter_S<real32> { char letter = 'S'; };
+template <> struct CublasTypeLetter_S<real64> { char letter = 'D'; };
+template <> struct CublasTypeLetter_S<complex64> { char letter = 'C'; };
+template <> struct CublasTypeLetter_S<complex128> { char letter = 'Z'; };
+#ifdef CUDACC
+template <> struct CublasTypeLetter_S<real16> { char letter = 'H'; };
+#endif
+
+template <typename T> char CublasTypeLetter = CublasTypeLetter_S<T>::letter;
 
 // Shorthands to reduce clutter.
 #define CAST(var) reinterpret_cast<CudaComplexConversion<T>*>(var)
 #define DCAST(var) reinterpret_cast<CudaComplexConversion<T>**>(var)
 
+#define cublas(T, func) cublas##CublasTypeLetter<T>##func
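
// Note (not part of the patch): the cublas(T, func) macro above cannot expand to a real
// cuBLAS symbol such as cublasSgemm, because ## token pasting happens in the preprocessor,
// long before CublasTypeLetter<T> has a value; it pastes the literal token
// "CublasTypeLetter" rather than 'S'/'D'/'C'/'Z'/'H'. (As written, letter is also a
// non-static member of CublasTypeLetter_S, so the variable template would not compile if it
// were ever instantiated.) The invoke()/invoke5() helpers below do the selection instead,
// at compile time via if constexpr, with the caller passing every candidate function.
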
 template <typename T, typename F1, typename F2, typename F3, typename F4, typename... Args>
 constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
     if constexpr (std::is_same<T, real32>::value) {
@@ -215,7 +232,26 @@ constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
     } else if constexpr (std::is_same<T, complex128>::value) {
         CUBLAS_CHECK(f4(args...));
     } else {
-        CT_ERROR(true, "BLAS functions are not callable with that type");
+        CT_ERROR(true, "This BLAS function is not callable with that type");
     }
 }
 
+// If someone can think of a better solution, please tell me.
+template <typename T, typename F1, typename F2, typename F3, typename F4, typename F5,
+          typename... Args>
+constexpr void invoke5(F1 f1, F2 f2, F3 f3, F4 f4, F5 f5, Args&&... args) {
+    if constexpr (std::is_same<T, real32>::value) {
+        CUBLAS_CHECK(f1(args...));
+    } else if constexpr (std::is_same<T, real64>::value) {
+        CUBLAS_CHECK(f2(args...));
+    } else if constexpr (std::is_same<T, complex64>::value) {
+        CUBLAS_CHECK(f3(args...));
+    } else if constexpr (std::is_same<T, complex128>::value) {
+        CUBLAS_CHECK(f4(args...));
+    } else if constexpr (std::is_same<T, real16>::value) {
+        CUBLAS_CHECK(f5(args...));
+    } else {
+        CT_ERROR(true, "This BLAS function is not callable with that type");
+    }
+}
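
// Note on the dispatch (not part of the patch): because the non-matching branches of
// if constexpr are discarded, only the candidate selected for T has to be callable with the
// forwarded args...; the remaining function pointers are merely passed along. invoke5 exists
// because cuBLAS ships half-precision 'H' variants only for a few routines such as gemm, so
// the fifth slot is supplied per call site. A hypothetical scal-style use:
//     invoke<T>(cublasSscal, cublasDscal, cublasCscal, cublasZscal,
//               Manager::get()->cublasHandle(), n, CAST(&alpha), CAST(x.dataDevice()), 1);
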
@@ -227,7 +263,7 @@ template <typename T>
 StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta, const Array<T>& y,
               const StreamID& stream = DEF_CUBLAS_STREAM) {
 
-    BatchInfo bi = Check<T>::isBroadcastable(A, x, y, "A", "x", "y");
+    BatchInfo bi = Check::isBroadcastable(A, x, y, "A", "x", "y");
     CT_ERROR_IF(x.shape().cols(), !=, 1, "x must be a column vector");
     CT_ERROR_IF(y.shape().cols(), !=, 1, "y must be a column vector");
@@ -241,7 +277,6 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
                   Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
                   CAST(A.dataDevice()), rows, CAST(x.dataDevice()), 1, CAST(&b),
                   CAST(y.dataDevice()), 1);
-
     } else { // Greater than 2, so broadcast.
         invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched, cublasCgemvStridedBatched,
                   cublasZgemvStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows,
@@ -269,11 +304,11 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
  * Computes the matrix-matrix product: \f$ C = \alpha AB + \beta C \f$. It will automatically
  * broadcast the operation if applicable.
  */
-template <typename T>
-StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta, const Array<T>& C,
+template <typename T, typename U, typename V>
+StreamID GEMM(const T alpha, const Array<U>& A, const Array<U>& B, const T beta, const Array<V>& C,
               const StreamID& stream = DEF_CUBLAS_STREAM) {
 
-    BatchInfo bi = Check<T>::isBroadcastable(A, B, C, "A", "B", "C");
+    BatchInfo bi = Check::isBroadcastable(A, B, C, "A", "B", "C");
     // A is m x k, B is k x n.
     uint32_t m = A.shape().rows();
     uint32_t k = A.shape().cols();
@@ -282,18 +317,19 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
     T a = alpha, b = beta;
 #ifdef CUDA
     CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
 
     if (bi.size == 1) {
-        invoke<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm,
+        invoke5<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm, cublasHgemm,
                   Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
                   CAST(A.dataDevice()), m, CAST(B.dataDevice()), k, CAST(&b), CAST(C.dataDevice()),
                   m);
 
     } else { // Greater than 2, so broadcast.
-        invoke<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched, cublasCgemmStridedBatched,
-                  cublasZgemmStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N,
-                  CUBLAS_OP_N, m, n, k, CAST(&a), CAST(A.dataDevice()), m, bi.strideA,
-                  CAST(B.dataDevice()), k, bi.strideB, CAST(&b), CAST(C.dataDevice()), m,
-                  bi.strideC, bi.size);
+        invoke5<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched, cublasCgemmStridedBatched,
+                   cublasZgemmStridedBatched, cublasHgemmStridedBatched,
+                   Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
+                   CAST(A.dataDevice()), m, bi.strideA, CAST(B.dataDevice()), k, bi.strideB,
+                   CAST(&b), CAST(C.dataDevice()), m, bi.strideC, bi.size);
     }
 
 #else
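
// Illustrative sketch (not part of the patch): a batched product C = alpha*A*B + beta*C with
// shapes {batch, m, k} x {batch, k, n} -> {batch, m, n}; the shape-list constructor is
// assumed from the Array usage above. With invoke5 wiring in cublasHgemm and
// cublasHgemmStridedBatched, the same call can also be instantiated for real16.
inline void exampleBatchedGemm() {
    Array<real32> A({8, 64, 32}), B({8, 32, 16}), C({8, 64, 16});
    GEMM(1.0f, A, B, 0.0f, C); // T = real32 deduced; runs on DEF_CUBLAS_STREAM
}
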
@@ -487,7 +523,7 @@ class PLUBatch : public Batch<T> {
      * Constructor of a PLUBatch from a multi-dimensional array, batched across upper dimensions.
      */
     PLUBatch(const Array<T>& arr) : Batch<T>(arr) {
-        Check<T>::isSquare(arr, "LU Array");
+        Check::isSquare(arr, "LU Array");
 
         mPivotsBatch = Array<int32_t>({this->mBatchSize * this->mShape.rows()});
         mInfoLU = Array<int32_t>({this->mBatchSize});
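
// Illustrative sketch (not part of the patch; Array construction assumed from the usage
// above): batching sixteen independent 4x4 systems. The constructor validates squareness
// through the reworked Check::isSquare and sizes the pivot and info arrays per batch entry.
inline void examplePLUBatch() {
    Array<real64> lu({16, 4, 4});
    PLUBatch<real64> plu(lu);
}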