diff --git a/BLAS.h b/BLAS.h index 2eb17cc..e315b2a 100644 --- a/BLAS.h +++ b/BLAS.h @@ -138,7 +138,7 @@ template class Batch { Array batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()}); for (uint32_t i = 0; i < mBatchSize; ++i) { -#ifdef CUDA +#ifdef CUDACC mBatch[i] = batch[i].dataDevice(); #else mBatch[i] = batch[i].data(); @@ -154,7 +154,7 @@ template class Batch { void add(const Array& arr) { CT_ERROR(not arr.isView(), "Cannot add non-view Arrays"); CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays"); -#ifdef CUDA +#ifdef CUDACC mBatch[mCount] = arr.dataDevice(); #else mBatch[mCount] = arr.data();