diff --git a/BLAS.h b/BLAS.h index 1164ce9..2eb17cc 100644 --- a/BLAS.h +++ b/BLAS.h @@ -270,7 +270,7 @@ StreamID GEMV(const T alpha, const Array& A, const Array& x, const T beta, uint32_t rows = A.shape().rows(); uint32_t cols = A.shape().cols(); T a = alpha, b = beta; -#ifdef CUDA +#ifdef CUDACC CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); if (bi.size == 1) { invoke(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv, @@ -315,7 +315,7 @@ StreamID GEMM(const T alpha, const Array& A, const Array& B, const T beta, uint32_t n = B.shape().cols(); T a = alpha, b = beta; -#ifdef CUDA +#ifdef CUDACC CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); if (bi.size == 1) { @@ -368,7 +368,7 @@ StreamID DGMM(const Array& A, const Array& X, const Array& C, const boo CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(), "Rows of 'A' and columns of 'C' need to match."); -#ifdef CUDA +#ifdef CUDACC uint32_t m = C.shape().rows(); uint32_t n = C.shape().cols(); auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; @@ -544,7 +544,7 @@ class PLUBatch : public Batch { * Computes the inplace PLU decomposition of batch of arrays. */ StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) { -#ifdef CUDA +#ifdef CUDACC uint32_t n = this->mShape.rows(); CUBLAS_CHECK( cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream))); @@ -575,7 +575,7 @@ class PLUBatch : public Batch { CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(), "The length of each column of b must match the matrix rank"); -#ifdef CUDA +#ifdef CUDACC uint32_t n = b.shape().rows(); uint32_t nrhs = b.shape().cols(); CUBLAS_CHECK( diff --git a/Core.h b/Core.h index fd28d3d..17148c7 100644 --- a/Core.h +++ b/Core.h @@ -143,7 +143,7 @@ Settings basic(const size_t threads, const StreamID& stream = DEF_KERNEL_STREAM) template StreamID launch(F func, const Kernel::Settings& sett, Args... args) { -#ifdef CUDA +#ifdef CUDACC func<<stream(sett.stream.mId)>>>(args...); #else