// CudaTools: a library and framework for developing CPU-CUDA compatible applications
// under one unified code.
#ifndef CUDATOOLS_BLAS_H
#define CUDATOOLS_BLAS_H
#ifndef CUDATOOLS_USE_EIGEN
#error "Cannot use CudaTools BLAS.h header without Eigen."
#endif
#include "Array.h"
#include "Core.h"
#include "Macros.h"
#include "Types.h"
using namespace CudaTools::Types;
namespace CudaTools {
namespace BLAS {
struct BatchInfo {
uint32_t strideA, strideB, strideC;
uint32_t size;
};
struct Check {
template <typename T>
static void isAtLeast2D(const Array<T>& arr, const std::string& name = "Array") {
CT_ERROR_IF(arr.shape().axes(), <, 2, (name + " needs to be at least 2D").c_str());
};
template <typename T>
static void isSquare(const Array<T>& arr, const std::string& name = "Array") {
isAtLeast2D(arr, name);
CT_ERROR_IF(arr.shape().rows(), !=, arr.shape().cols(), (name + " is not square").c_str());
};
template <typename T, typename U, typename V>
static void isValidMatmul(const Array<T>& A, const Array<U>& B, const Array<V>& C,
const std::string& nameA = "A", const std::string& nameB = "B",
const std::string& nameC = "C") {
isAtLeast2D(A, nameA);
isAtLeast2D(B, nameB);
isAtLeast2D(C, nameC);
CT_ERROR_IF(A.shape().cols(), !=, B.shape().rows(),
(nameA + nameB + " is not a valid matrix multiplication").c_str());
Shape ABshape({A.shape().rows(), B.shape().cols()});
Shape Cshape({C.shape().rows(), C.shape().cols()});
CT_ERROR_IF(
ABshape, !=, Cshape,
("The shape of " + nameA + nameB + " does not match the shape of " + nameC).c_str());
};
template <typename T> static uint32_t getUpperItems(const Array<T>& arr) {
uint32_t upperItems = 1;
for (uint32_t iAxis = 0; iAxis < arr.shape().axes() - 2; ++iAxis) {
upperItems *= arr.shape().dim(iAxis);
}
return upperItems;
};
template <typename T, typename U>
static void matchUpperShape(const Array<T>& A, const Array<U>& B,
const std::string& nameA = "A", const std::string& nameB = "B") {
CT_ERROR_IF(A.shape().axes(), !=, B.shape().axes(),
(nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
for (uint32_t iAxis = 0; iAxis < A.shape().axes() - 2; ++iAxis) {
uint32_t Adim = A.shape().dim(iAxis);
uint32_t Bdim = B.shape().dim(iAxis);
CT_ERROR_IF(
Adim, !=, Bdim,
(nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
}
};
template <typename T, typename U, typename V>
static BatchInfo isBroadcastable(const Array<T>& A, const Array<U>& B, const Array<V>& C,
const std::string& nameA = "A", const std::string& nameB = "B",
const std::string& nameC = "C") {
isValidMatmul(A, B, C, nameA, nameB, nameC);
uint32_t itemsA = getUpperItems(A);
uint32_t itemsB = getUpperItems(B);
uint32_t itemsC = getUpperItems(C);
uint32_t Asize = A.shape().rows() * A.shape().cols();
uint32_t Bsize = B.shape().rows() * B.shape().cols();
uint32_t Csize = C.shape().rows() * C.shape().cols();
if (itemsA == itemsB) {
CT_ERROR_IF(itemsA, !=, itemsC,
("Incorrect dimensions to broadcast to output " + nameC).c_str());
matchUpperShape(A, B, nameA, nameB);
matchUpperShape(A, C, nameA, nameC);
return BatchInfo{Asize, Bsize, Csize, itemsC};
} else if (itemsA > itemsB) {
CT_ERROR_IF(
itemsB, !=, 1,
("Cannot broadcast operation to " + nameB + " with non-matching " + nameA).c_str());
CT_ERROR_IF(itemsA, !=, itemsC,
("Incorrect dimensions to broadcast to output " + nameC).c_str());
matchUpperShape(A, C, nameA, nameC);
return BatchInfo{Asize, 0, Csize, itemsC};
} else {
CT_ERROR_IF(
itemsA, !=, 1,
("Cannot broadcast operation to " + nameA + " with non-matching " + nameB).c_str());
CT_ERROR_IF(itemsB, !=, itemsC,
("Incorrect dimensions to broadcast to output " + nameC).c_str());
matchUpperShape(B, C, nameB, nameC);
return BatchInfo{0, Bsize, Csize, itemsC};
}
};
};
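/*
 * Broadcast rules enforced by Check::isBroadcastable, illustrated on shapes (a sketch;
 * the checks above are authoritative). For 3x3 matrices A and 3x1 vectors B, C:
 *   A: (10, 3, 3), B: (10, 3, 1), C: (10, 3, 1) -> 10 independent products, all strides nonzero.
 *   A: (10, 3, 3), B: (3, 1),     C: (10, 3, 1) -> B is reused for every product (strideB = 0).
 *   A: (3, 3),     B: (10, 3, 1), C: (10, 3, 1) -> A is reused for every product (strideA = 0).
 */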
/**
* Represents a Batch of Arrays with the same shape. Mainly used for cuBLAS functions.
*/
template <typename T> class Batch {
protected:
Array<T*> mBatch;
Shape mShape;
uint32_t mCount = 0;
uint32_t mBatchSize;
public:
Batch() = delete;
/**
* Constructs a batch from a given size.
*/
Batch(const uint32_t size) : mBatch({size}), mBatchSize(size){};
/**
* Constructs a batch from a non-view Array.
*/
Batch(const Array<T>& arr) {
CT_ERROR(arr.isView(), "Array cannot be a view");
mShape = Shape({arr.shape().rows(), arr.shape().cols()});
mBatchSize = mCount = Check::getUpperItems(arr);
mBatch = Array<T*>({mBatchSize});
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
for (uint32_t i = 0; i < mBatchSize; ++i) {
#ifdef CUDACC
mBatch[i] = batch[i].dataDevice();
#else
mBatch[i] = batch[i].data();
#endif
}
mBatch.updateDevice().wait();
};
/**
* Adds a matrix to the batch. Array must be a view.
*/
void add(const Array<T>& arr) {
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
#ifdef CUDACC
mBatch[mCount] = arr.dataDevice();
#else
mBatch[mCount] = arr.data();
#endif
if (mCount == 0) {
mShape = arr.shape();
} else {
CT_ERROR_IF(arr.shape(), !=, mShape, "Cannot add matrix of different shape to batch");
}
++mCount;
if (mCount == mBatchSize) {
mBatch.updateDevice().wait();
}
};
/**
* Indexing operator which returns a view of the Array in the Batch at the given index.
*/
Array<T> operator[](const uint32_t index) const {
CT_ERROR_IF(index, >=, mBatchSize, "Index exceeds batch size");
return Array<T>(mBatch[index], {mShape.rows(), mShape.cols()});
};
/**
* Returns the batch Array of pointers.
*/
Array<T*> batch() const { return mBatch.view(); };
Shape shape() const { return mShape; } /**< Gets the shape of the matrices in the batch. */
uint32_t size() const { return mBatchSize; } /**< Gets the batch size.*/
bool full() const { return mBatchSize == mCount; }; /**< Gets if the batch is full. */
};
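/**
 * Example usage of Batch (a minimal sketch; the Array API is assumed from Array.h,
 * and types come from CudaTools::Types):
 * @code
 * CudaTools::Array<real32> arrs({4, 3, 3}); // Four 3x3 matrices stacked on one upper axis.
 * CudaTools::BLAS::Batch<real32> batch(arrs); // Batches across the upper dimension.
 * CudaTools::Array<real32> first = batch[0]; // View of the first 3x3 matrix.
 * @endcode
 */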
////////////////
// cuBLAS API //
////////////////
template <typename T> struct CudaComplexConversion_S { typedef T type; };
#ifdef CUDACC
template <> struct CudaComplexConversion_S<complex64> { typedef cuComplex type; };
template <> struct CudaComplexConversion_S<complex128> { typedef cuDoubleComplex type; };
#endif
template <typename T> using CudaComplexConversion = typename CudaComplexConversion_S<T>::type;
template <typename T> struct CublasTypeLetter_S {}; // Unsupported types have no letter.
template <> struct CublasTypeLetter_S<real32> { static constexpr char letter = 'S'; };
template <> struct CublasTypeLetter_S<real64> { static constexpr char letter = 'D'; };
template <> struct CublasTypeLetter_S<complex64> { static constexpr char letter = 'C'; };
template <> struct CublasTypeLetter_S<complex128> { static constexpr char letter = 'Z'; };
#ifdef CUDACC
template <> struct CublasTypeLetter_S<real16> { static constexpr char letter = 'H'; };
#endif
template <typename T> constexpr char CublasTypeLetter = CublasTypeLetter_S<T>::letter;
// Shorthands to reduce clutter.
#define CAST(var) reinterpret_cast<CudaComplexConversion<T>*>(var)
#define DCAST(var) reinterpret_cast<CudaComplexConversion<T>**>(var)
// Note: the letter-pasting shorthand below cannot actually dispatch on T, since the
// preprocessor pastes tokens before template substitution; the invoke() helpers below
// perform the dispatch instead.
#define cublas(T, func) cublas##CublasTypeLetter<T>##func
template <typename T, typename F1, typename F2, typename F3, typename F4, typename... Args>
constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
if constexpr (std::is_same<T, real32>::value) {
CUBLAS_CHECK(f1(args...));
} else if constexpr (std::is_same<T, real64>::value) {
CUBLAS_CHECK(f2(args...));
} else if constexpr (std::is_same<T, complex64>::value) {
CUBLAS_CHECK(f3(args...));
} else if constexpr (std::is_same<T, complex128>::value) {
CUBLAS_CHECK(f4(args...));
} else {
CT_ERROR(true, "This BLAS function is not callable with that type");
}
}
// If someone can think of a better solution, please tell me.
template <typename T, typename F1, typename F2, typename F3, typename F4, typename F5,
typename... Args>
constexpr void invoke5(F1 f1, F2 f2, F3 f3, F4 f4, F5 f5, Args&&... args) {
if constexpr (std::is_same<T, real32>::value) {
CUBLAS_CHECK(f1(args...));
} else if constexpr (std::is_same<T, real64>::value) {
CUBLAS_CHECK(f2(args...));
} else if constexpr (std::is_same<T, complex64>::value) {
CUBLAS_CHECK(f3(args...));
} else if constexpr (std::is_same<T, complex128>::value) {
CUBLAS_CHECK(f4(args...));
} else if constexpr (std::is_same<T, real16>::value) {
CUBLAS_CHECK(f5(args...));
} else {
CT_ERROR(true, "This BLAS function is not callable with that type");
}
}
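// Example: invoke<real64>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv, args...)
// resolves to CUBLAS_CHECK(cublasDgemv(args...)) at compile time; the branches for the
// other types are discarded by 'if constexpr' and never called.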
/**
* Computes the matrix-vector product: \f$ y = \alpha Ax + \beta y \f$. It will automatically
* broadcast the operation if applicable.
*/
template <typename T>
StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta, const Array<T>& y,
const StreamID& stream = DEF_CUBLAS_STREAM) {
BatchInfo bi = Check::isBroadcastable(A, x, y, "A", "x", "y");
CT_ERROR_IF(x.shape().cols(), !=, 1, "x must be a column vector");
CT_ERROR_IF(y.shape().cols(), !=, 1, "y must be a column vector");
uint32_t rows = A.shape().rows();
uint32_t cols = A.shape().cols();
T a = alpha, b = beta;
#ifdef CUDACC
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
if (bi.size == 1) {
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
CAST(A.dataDevice()), rows, CAST(x.dataDevice()), 1, CAST(&b),
CAST(y.dataDevice()), 1);
} else { // More than 2 axes, so broadcast.
invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched, cublasCgemvStridedBatched,
cublasZgemvStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows,
cols, CAST(&a), CAST(A.dataDevice()), rows, bi.strideA, CAST(x.dataDevice()), 1,
bi.strideB, CAST(&b), CAST(y.dataDevice()), 1, bi.strideC, bi.size);
}
#else
if (bi.size == 1) {
y.eigenMap() = a * (A.eigenMap() * x.eigenMap()) + b * y.eigenMap();
} else { // More than 2 axes, so broadcast.
#pragma omp parallel for
for (uint32_t i = 0; i < bi.size; ++i) {
auto Ai = Array<T>(A, {rows, cols}, i * bi.strideA).eigenMap();
auto xi = Array<T>(x, {cols, 1}, i * bi.strideB).eigenMap();
auto yi = Array<T>(y, {rows, 1}, i * bi.strideC).eigenMap();
yi = a * (Ai * xi) + b * yi;
}
}
#endif
return StreamID{stream};
}
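/**
 * Example: compute \f$ y = 2Ax \f$ (a minimal sketch; assumes device copies are current
 * when compiling for CUDA, and that the returned StreamID supports wait() as used
 * elsewhere in this header):
 * @code
 * CudaTools::Array<real32> A({3, 3}), x({3, 1}), y({3, 1});
 * // ... fill A and x; A.updateDevice(), x.updateDevice() under CUDA ...
 * CudaTools::BLAS::GEMV<real32>(2.0f, A, x, 0.0f, y).wait();
 * y.updateHost().wait(); // Bring the result back when compiling for CUDA.
 * @endcode
 */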
/**
* Computes the matrix-matrix product: \f$ C = \alpha AB + \beta C \f$. It will automatically
* broadcast the operation if applicable.
*/
template <typename T>
StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta, const Array<T>& C,
const StreamID& stream = DEF_CUBLAS_STREAM) {
BatchInfo bi = Check::isBroadcastable(A, B, C, "A", "B", "C");
// A is m x k, B is k x n.
uint32_t m = A.shape().rows();
uint32_t k = A.shape().cols();
uint32_t n = B.shape().cols();
T a = alpha, b = beta;
#ifdef CUDACC
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
if (bi.size == 1) {
invoke5<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm, cublasHgemm,
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
CAST(A.dataDevice()), m, CAST(B.dataDevice()), k, CAST(&b), CAST(C.dataDevice()),
m);
} else { // More than 2 axes, so broadcast.
invoke5<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched, cublasCgemmStridedBatched,
cublasZgemmStridedBatched, cublasHgemmStridedBatched,
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
CAST(A.dataDevice()), m, bi.strideA, CAST(B.dataDevice()), k, bi.strideB,
CAST(&b), CAST(C.dataDevice()), m, bi.strideC, bi.size);
}
#else
if (bi.size == 1) {
C.eigenMap() = a * (A.eigenMap() * B.eigenMap()) + b * C.eigenMap();
} else { // More than 2 axes, so broadcast.
#pragma omp parallel for
for (uint32_t i = 0; i < bi.size; ++i) {
auto Ai = Array<T>(A, {m, k}, i * bi.strideA).eigenMap();
auto Bi = Array<T>(B, {k, n}, i * bi.strideB).eigenMap();
auto Ci = Array<T>(C, {m, n}, i * bi.strideC).eigenMap();
Ci = a * (Ai * Bi) + b * Ci;
}
}
#endif
return StreamID{stream};
}
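/**
 * Example: batched \f$ C = AB \f$ over ten matrix pairs (a minimal sketch):
 * @code
 * CudaTools::Array<real32> A({10, 4, 4}), B({10, 4, 4}), C({10, 4, 4});
 * // ... fill A and B; update device copies under CUDA ...
 * CudaTools::BLAS::GEMM<real32>(1.0f, A, B, 0.0f, C).wait(); // Broadcasts across the upper axis.
 * @endcode
 */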
/**
* Computes the diagonal matrix multiplication: \f$ C = A\mathrm{diag}(X) \f$, or \f$ C =
* \mathrm{diag}(X)A \f$ if left = true.
*/
template <typename T>
StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const bool left = false,
const StreamID& stream = DEF_CUBLAS_STREAM) {
CT_ERROR_IF(X.shape().cols(), !=, 1, "'X' must be a column vector.");
if (left) {
CT_ERROR_IF(A.shape().rows(), !=, X.shape().rows(),
"Rows of 'A' and length of 'X' need to match.");
} else {
CT_ERROR_IF(A.shape().cols(), !=, X.shape().rows(),
"Columns of 'A' and length of 'X' need to match.");
}
CT_ERROR_IF(A.shape().rows(), !=, C.shape().rows(),
"Rows of 'A' and rows of 'C' need to match.");
CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(),
"Columns of 'A' and columns of 'C' need to match.");
#ifdef CUDACC
uint32_t m = C.shape().rows();
uint32_t n = C.shape().cols();
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(), m,
n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
CAST(C.dataDevice()), m);
#else
if (left) {
C.eigenMap() = X.eigenMap().asDiagonal() * A.eigenMap();
} else {
C.eigenMap() = A.eigenMap() * X.eigenMap().asDiagonal();
}
#endif
return StreamID{stream};
}
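/**
 * Example: scale the columns of A by the entries of X (a minimal sketch):
 * @code
 * CudaTools::Array<real32> A({3, 3}), X({3, 1}), C({3, 3});
 * CudaTools::BLAS::DGMM<real32>(A, X, C).wait(); // C = A diag(X); pass left = true for C = diag(X) A.
 * @endcode
 */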
//////////////////////////////
// PLUArray Related Objects //
//////////////////////////////
///////////////////////////
// PartialPivLU Wrapper //
///////////////////////////
// This class is just a workaround to use Eigen's internals directly.
template <typename T> class PartialPivLU;
namespace internal {
template <typename T> static Array<T> empty({1, 1});
template <typename T> static EigenMapMat<T> empty_map = empty<T>.eigenMap();
}; // namespace internal
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true> class PLUArray;
// This is a wrapper class for Eigen's class so we have more controlled access to
// the underlying data.
template <typename T> class PartialPivLU : public Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>> {
private:
using Base = Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>>;
template <typename U, std::enable_if_t<is_float<U> or is_complex<U>, bool>>
friend class PLUArray;
EigenMapMat<T> mMapLU;
EigenMapMat<int32_t> mMapPivots;
public:
PartialPivLU()
: Base(internal::empty_map<T>), mMapLU(internal::empty_map<T>),
mMapPivots(internal::empty_map<int32_t>){};
void make(const Array<T>& lu, const Array<int32_t>& pivots) {
new (&mMapLU) EigenMapMat<T>(lu.eigenMap());
new (&mMapPivots) EigenMapMat<int32_t>(pivots.atLeast2D().eigenMap());
new (&this->m_lu) decltype(Base::m_lu)(mMapLU.derived());
new (&this->m_p) decltype(Base::m_p)(mMapPivots.derived());
// new (&this->m_rowsTranspositions) decltype(Base::m_rowsTranspositions)(
// mMapPivots.derived());
this->m_l1_norm = 0;
this->m_det_p = 0;
this->m_isInitialized = true;
};
};
namespace internal {
// We only create one and copy-construct to avoid the re-initialization.
template <typename T> static PartialPivLU<T> BlankPPLU = PartialPivLU<T>();
}; // namespace internal
/**
* Class for storing the PLU decomposition of an Array. This is restricted to floating point types.
*/
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool>> class PLUArray {
private:
Array<T> mLU;
Array<int32_t> mPivots;
PartialPivLU<T> mPPLU = internal::BlankPPLU<T>;
public:
PLUArray() = delete;
/**
* Constructor for a PLUArray given the matrix dimension.
*/
PLUArray(const uint32_t n) : mLU({n, n}), mPivots({n}) { mPPLU.make(mLU, mPivots); };
/**
* Constructor for a PLUArray given an existing array.
*/
PLUArray(const Array<T>& arr)
: mLU((arr.isView()) ? arr.view() : arr), mPivots({arr.shape().rows()}) {
CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
mPPLU.make(mLU, mPivots);
};
/**
* Constructor for a PLUArray given an existing location in memory for both the matrix and
* the pivots.
*/
PLUArray(const Array<T>& arr, const Array<int32_t>& pivots)
: mLU(arr.view()), mPivots(pivots.view()) {
CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
mPPLU.make(mLU, mPivots);
};
uint32_t rank() { return mLU.shape().rows(); }; /**< Gets the rank of the LU matrix. */
Array<T> LU() const { return mLU.view(); }; /**< Gets the LU matrix. */
Array<int32_t> pivots() const { return mPivots.view(); }; /**< Gets the pivots. */
/**
* Computes the in-place LU factorization for this array on CPU.
*/
void computeLU() {
mPPLU.compute();
mPPLU.mMapPivots = mPPLU.permutationP().indices();
};
/**
* Solves the system \f$ LUx = b \f$ and returns \f$x\f$.
*/
Array<T> solve(const Array<T>& b) {
Array<T> x(b.shape());
x.eigenMap() = mPPLU.solve(b.eigenMap());
return x;
};
};
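/**
 * Example: factor a small system and solve it on the CPU (a minimal sketch):
 * @code
 * CudaTools::Array<real64> A({3, 3}), b({3, 1});
 * // ... fill A and b ...
 * CudaTools::BLAS::PLUArray<real64> plu(A); // Builds the LU workspace from A.
 * plu.computeLU();
 * CudaTools::Array<real64> x = plu.solve(b);
 * @endcode
 */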
/**
* This is a batch version of PLUArray, to enable usage of the cuBLAS API. This is restricted to
* floating point types.
*/
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true>
class PLUBatch : public Batch<T> {
private:
Array<int32_t> mPivotsBatch;
Array<int32_t> mInfoLU;
int32_t mInfoSolve;
bool mInitialized = false;
public:
/**
* Constructor of a PLUBatch from a given batch size.
*/
PLUBatch(const uint32_t size) : Batch<T>(size), mInfoLU({size}){};
/**
* Constructor of a PLUBatch from a multi-dimensional array, batched across upper dimensions.
*/
PLUBatch(const Array<T>& arr) : Batch<T>(arr) {
Check::isSquare(arr, "LU Array");
mPivotsBatch = Array<int32_t>({this->mBatchSize * this->mShape.rows()});
mInfoLU = Array<int32_t>({this->mBatchSize});
};
/**
* Indexing operator which returns the PLUArray in the PLUBatch at the given index.
*/
PLUArray<T> operator[](const uint32_t index) const {
CT_ERROR_IF(index, >=, this->mBatchSize, "Index exceeds batch size");
Array<T> lu(this->mBatch[index], {this->mShape.rows(), this->mShape.cols()});
Array<int32_t> pivots(mPivotsBatch.data() + index * this->mShape.rows(),
{this->mShape.rows()});
return PLUArray<T>(lu, pivots);
};
/**
* Computes the in-place PLU decomposition of a batch of arrays.
*/
StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) {
#ifdef CUDACC
uint32_t n = this->mShape.rows();
CUBLAS_CHECK(
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
mInfoLU.dataDevice(), this->mBatchSize);
#else
#pragma omp parallel for
for (uint32_t i = 0; i < this->mBatchSize; ++i) {
(*this)[i].computeLU();
}
#endif
mInitialized = true;
return stream;
};
/**
* Solves the batched system \f$LUx = b\f$ in place. The solution \f$x\f$ is written back into
* \f$b\f$.
*/
StreamID solve(const Batch<T>& b, const StreamID& stream = DEF_CUBLAS_STREAM) {
CT_ERROR(not mInitialized,
"Cannot solve system if PLUBatch has not yet computed its LU decomposition");
CT_ERROR_IF(b.size(), !=, this->mBatchSize,
"Upper dimensions of b do not match batch size");
CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(),
"The length of each column of b must match the matrix rank");
#ifdef CUDACC
uint32_t n = b.shape().rows();
uint32_t nrhs = b.shape().cols();
CUBLAS_CHECK(
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
DCAST(b.batch().dataDevice()), n, &mInfoSolve, this->mBatchSize);
#else
#pragma omp parallel for
for (uint32_t i = 0; i < this->mBatchSize; ++i) {
b[i] = (*this)[i].solve(b[i]);
}
#endif
return stream;
};
/**
* Gets the pivots data from the device to the host. Does nothing for CPU.
*/
StreamID getPivots(const StreamID& stream = DEF_MEM_STREAM) const {
mPivotsBatch.updateHost(stream);
return stream;
};
/**
* Gets the info array for the LU decomposition from the device to the host. Does not
* return useful information for CPU.
*/
Array<int32_t> getLUInfo() const {
mInfoLU.updateHost().wait();
return mInfoLU;
};
/**
* Checks validity of the solve operation. Does not return useful information for CPU.
*/
bool validSolve() const { return mInfoSolve == 0; }
};
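/**
 * Example: batched factorization and solve (a minimal sketch; device data assumed
 * current when compiling for CUDA):
 * @code
 * CudaTools::Array<real64> A({100, 4, 4}), b({100, 4, 1});
 * // ... fill A and b; update device copies under CUDA ...
 * CudaTools::BLAS::PLUBatch<real64> pluBatch(A);
 * pluBatch.computeLU().wait();
 * CudaTools::BLAS::Batch<real64> bBatch(b);
 * pluBatch.solve(bBatch).wait(); // Solutions are written back into b.
 * @endcode
 */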
}; // namespace BLAS
}; // namespace CudaTools
#endif