|
|
|
@ -270,7 +270,7 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta, |
|
|
|
|
uint32_t rows = A.shape().rows(); |
|
|
|
|
uint32_t cols = A.shape().cols(); |
|
|
|
|
T a = alpha, b = beta; |
|
|
|
|
#ifdef CUDA |
|
|
|
|
#ifdef CUDACC |
|
|
|
|
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); |
|
|
|
|
if (bi.size == 1) { |
|
|
|
|
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv, |
|
|
|
@ -315,7 +315,7 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta, |
|
|
|
|
uint32_t n = B.shape().cols(); |
|
|
|
|
|
|
|
|
|
T a = alpha, b = beta; |
|
|
|
|
#ifdef CUDA |
|
|
|
|
#ifdef CUDACC |
|
|
|
|
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); |
|
|
|
|
|
|
|
|
|
if (bi.size == 1) { |
|
|
|
@ -368,7 +368,7 @@ StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const boo |
|
|
|
|
CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(), |
|
|
|
|
"Rows of 'A' and columns of 'C' need to match."); |
|
|
|
|
|
|
|
|
|
#ifdef CUDA |
|
|
|
|
#ifdef CUDACC |
|
|
|
|
uint32_t m = C.shape().rows(); |
|
|
|
|
uint32_t n = C.shape().cols(); |
|
|
|
|
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; |
|
|
|
@ -544,7 +544,7 @@ class PLUBatch : public Batch<T> { |
|
|
|
|
* Computes the in-place PLU decomposition of a batch of arrays. |
|
|
|
|
*/ |
|
|
|
|
StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) { |
|
|
|
|
#ifdef CUDA |
|
|
|
|
#ifdef CUDACC |
|
|
|
|
uint32_t n = this->mShape.rows(); |
|
|
|
|
CUBLAS_CHECK( |
|
|
|
|
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); |
|
|
|
@ -575,7 +575,7 @@ class PLUBatch : public Batch<T> { |
|
|
|
|
CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(), |
|
|
|
|
"The length of each column of b must match the matrix rank"); |
|
|
|
|
|
|
|
|
|
#ifdef CUDA |
|
|
|
|
#ifdef CUDACC |
|
|
|
|
uint32_t n = b.shape().rows(); |
|
|
|
|
uint32_t nrhs = b.shape().cols(); |
|
|
|
|
CUBLAS_CHECK( |
|
|
|
|