Fixed compilation macro name issue

main
Kenneth Jao 2 years ago
parent f0b7f113e3
commit 917c95d70c
  1. 10
      BLAS.h
  2. 2
      Core.h

@ -270,7 +270,7 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
uint32_t rows = A.shape().rows();
uint32_t cols = A.shape().cols();
T a = alpha, b = beta;
#ifdef CUDA
#ifdef CUDACC
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
if (bi.size == 1) {
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
@ -315,7 +315,7 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
uint32_t n = B.shape().cols();
T a = alpha, b = beta;
#ifdef CUDA
#ifdef CUDACC
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
if (bi.size == 1) {
@ -368,7 +368,7 @@ StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const boo
CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(),
"Rows of 'A' and columns of 'C' need to match.");
#ifdef CUDA
#ifdef CUDACC
uint32_t m = C.shape().rows();
uint32_t n = C.shape().cols();
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
@ -544,7 +544,7 @@ class PLUBatch : public Batch<T> {
* Computes the inplace PLU decomposition of batch of arrays.
*/
StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) {
#ifdef CUDA
#ifdef CUDACC
uint32_t n = this->mShape.rows();
CUBLAS_CHECK(
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
@ -575,7 +575,7 @@ class PLUBatch : public Batch<T> {
CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(),
"The length of each column of b must match the matrix rank");
#ifdef CUDA
#ifdef CUDACC
uint32_t n = b.shape().rows();
uint32_t nrhs = b.shape().cols();
CUBLAS_CHECK(

@ -143,7 +143,7 @@ Settings basic(const size_t threads, const StreamID& stream = DEF_KERNEL_STREAM)
template <typename F, typename... Args>
StreamID launch(F func, const Kernel::Settings& sett, Args... args) {
#ifdef CUDA
#ifdef CUDACC
func<<<sett.blockGrid, sett.threadBlock, sett.sharedMemoryBytes,
Manager::get()->stream(sett.stream.mId)>>>(args...);
#else

Loading…
Cancel
Save