|
|
@ -138,7 +138,7 @@ template <typename T> class Batch { |
|
|
|
|
|
|
|
|
|
|
|
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()}); |
|
|
|
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()}); |
|
|
|
for (uint32_t i = 0; i < mBatchSize; ++i) { |
|
|
|
for (uint32_t i = 0; i < mBatchSize; ++i) { |
|
|
|
#ifdef CUDA |
|
|
|
#ifdef CUDACC |
|
|
|
mBatch[i] = batch[i].dataDevice(); |
|
|
|
mBatch[i] = batch[i].dataDevice(); |
|
|
|
#else |
|
|
|
#else |
|
|
|
mBatch[i] = batch[i].data(); |
|
|
|
mBatch[i] = batch[i].data(); |
|
|
@ -154,7 +154,7 @@ template <typename T> class Batch { |
|
|
|
void add(const Array<T>& arr) { |
|
|
|
void add(const Array<T>& arr) { |
|
|
|
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays"); |
|
|
|
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays"); |
|
|
|
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays"); |
|
|
|
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays"); |
|
|
|
#ifdef CUDA |
|
|
|
#ifdef CUDACC |
|
|
|
mBatch[mCount] = arr.dataDevice(); |
|
|
|
mBatch[mCount] = arr.dataDevice(); |
|
|
|
#else |
|
|
|
#else |
|
|
|
mBatch[mCount] = arr.data(); |
|
|
|
mBatch[mCount] = arr.data(); |
|
|
|