#ifndef CUDATOOLS_BLAS_H
#define CUDATOOLS_BLAS_H

#ifndef CUDATOOLS_USE_EIGEN
#error "Cannot use CudaTools BLAS.h header without Eigen."
#endif

#include "Array.h"
#include "Core.h"
#include "Macros.h"
#include "Types.h"

using namespace CudaTools::Types;

namespace CudaTools {
namespace BLAS {

struct BatchInfo {
    uint32_t strideA, strideB, strideC;
    uint32_t size;
};

struct Check {
    template <typename T>
    static void isAtLeast2D(const Array<T>& arr, const std::string& name = "Array") {
        CT_ERROR_IF(arr.shape().axes(), <, 2, (name + " needs to be at least 2D").c_str());
    };

    template <typename T>
    static void isSquare(const Array<T>& arr, const std::string& name = "Array") {
        isAtLeast2D(arr, name);
        CT_ERROR_IF(arr.shape().rows(), !=, arr.shape().cols(), (name + " is not square").c_str());
    };

    template <typename T, typename U, typename V>
    static void isValidMatmul(const Array<T>& A, const Array<U>& B, const Array<V>& C,
                              const std::string& nameA = "A", const std::string& nameB = "B",
                              const std::string& nameC = "C") {
        isAtLeast2D(A, nameA);
        isAtLeast2D(B, nameB);
        isAtLeast2D(C, nameC);
        CT_ERROR_IF(A.shape().cols(), !=, B.shape().rows(),
                    (nameA + nameB + " is not a valid matrix multiplication").c_str());

        Shape ABshape({A.shape().rows(), B.shape().cols()});
        Shape Cshape({C.shape().rows(), C.shape().cols()});

        CT_ERROR_IF(
            ABshape, !=, Cshape,
            ("The shape of " + nameA + nameB + " does not match the shape of " + nameC).c_str());
    };

    template <typename T> static uint32_t getUpperItems(const Array<T>& arr) {
        uint32_t upperItems = 1;
        for (uint32_t iAxis = 0; iAxis < arr.shape().axes() - 2; ++iAxis) {
            upperItems *= arr.shape().dim(iAxis);
        }
        return upperItems;
    };

    template <typename T, typename U>
    static void matchUpperShape(const Array<T>& A, const Array<U>& B,
                                const std::string& nameA = "A", const std::string& nameB = "B") {
        CT_ERROR_IF(A.shape().axes(), !=, B.shape().axes(),
                    (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        for (uint32_t iAxis = 0; iAxis < A.shape().axes() - 2; ++iAxis) {
            uint32_t Adim = A.shape().dim(iAxis);
            uint32_t Bdim = B.shape().dim(iAxis);
            CT_ERROR_IF(
                Adim, !=, Bdim,
                (nameA + " and " + nameB + " shapes do not match for broadcasting").c_str());
        }
    };

    template <typename T, typename U, typename V>
    static BatchInfo isBroadcastable(const Array<T>& A, const Array<U>& B, const Array<V>& C,
                                     const std::string& nameA = "A", const std::string& nameB = "B",
                                     const std::string& nameC = "C") {
        isValidMatmul(A, B, C, nameA, nameB, nameC);
        uint32_t itemsA = getUpperItems(A);
        uint32_t itemsB = getUpperItems(B);
        uint32_t itemsC = getUpperItems(C);

        uint32_t Asize = A.shape().rows() * A.shape().cols();
        uint32_t Bsize = B.shape().rows() * B.shape().cols();
        uint32_t Csize = C.shape().rows() * C.shape().cols();

        if (itemsA == itemsB) {
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, B, nameA, nameB);
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, Bsize, Csize, itemsC};
        } else if (itemsA > itemsB) {
            CT_ERROR_IF(
                itemsB, !=, 1,
                ("Cannot broadcast operation to " + nameB + " with non-matching " + nameA).c_str());
            CT_ERROR_IF(itemsA, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(A, C, nameA, nameC);
            return BatchInfo{Asize, 0, Csize, itemsC};
        } else {
            CT_ERROR_IF(
                itemsA, !=, 1,
                ("Cannot broadcast operation to " + nameA + " with non-matching " + nameB).c_str());
            CT_ERROR_IF(itemsB, !=, itemsC,
                        ("Incorrect dimensions to broadcast to output " + nameC).c_str());
            matchUpperShape(B, C, nameB, nameC);
            return BatchInfo{0, Bsize, Csize, itemsC};
        }
    };
};
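
// Illustrative example (not compiled): suppose A has shape {4, 3, 3} (four 3x3
// matrices), B has shape {3, 3}, and C has shape {4, 3, 3}. Then
// Check::isBroadcastable(A, B, C) returns BatchInfo{9, 0, 9, 4}: consecutive
// matrices of A and C are 9 elements apart, strideB = 0 marks the single B that
// is reused across the batch, and size = 4 is the number of matrix products.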

/**
 * Represents a Batch of Arrays with the same shape. Mainly used for cuBLAS functions.
 */
template <typename T> class Batch {
  protected:
    Array<T*> mBatch;
    Shape mShape;

    uint32_t mCount = 0;
    uint32_t mBatchSize;

  public:
    Batch() = delete;

    /**
     * Constructs a batch of a given size.
     */
    Batch(const uint32_t size) : mBatch({size}), mBatchSize(size){};

    /**
     * Constructs a batch from a non-view Array.
     */
    Batch(const Array<T>& arr) {
        CT_ERROR(arr.isView(), "Array cannot be a view");
        mShape = Shape({arr.shape().rows(), arr.shape().cols()});
        mBatchSize = mCount = Check::getUpperItems(arr);

        mBatch = Array<T*>({mBatchSize});

        Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
        for (uint32_t i = 0; i < mBatchSize; ++i) {
#ifdef CUDACC
            mBatch[i] = batch[i].dataDevice();
#else
            mBatch[i] = batch[i].data();
#endif
        }

        mBatch.updateDevice().wait();
    };

    /**
     * Adds a matrix to the batch. The Array must be a view.
     */
    void add(const Array<T>& arr) {
        CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
        CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
#ifdef CUDACC
        mBatch[mCount] = arr.dataDevice();
#else
        mBatch[mCount] = arr.data();
#endif
        if (mCount == 0) {
            mShape = arr.shape();
        } else {
            CT_ERROR_IF(arr.shape(), !=, mShape, "Cannot add matrix of different shape to batch");
        }
        ++mCount;

        if (mCount == mBatchSize) {
            mBatch.updateDevice().wait();
        }
    };

    /**
     * Indexing operator which returns a view of the Array in the Batch at the given index.
     */
    Array<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, mBatchSize, "Index exceeds batch size");
        return Array<T>(mBatch[index], {mShape.rows(), mShape.cols()});
    };

    /**
     * Returns the batch Array of pointers.
     */
    Array<T*> batch() const { return mBatch.view(); };
    Shape shape() const { return mShape; } /**< Gets the shape of the matrices in the batch. */
    uint32_t size() const { return mBatchSize; } /**< Gets the batch size. */
    bool full() const { return mBatchSize == mCount; }; /**< Gets whether the batch is full. */
};
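
// Hypothetical usage sketch (assumes an Array<real32> named data with shape
// {16, 8, 8}, i.e. sixteen 8x8 matrices):
//
//     Batch<real32> batch(data);      // batches across the upper dimension
//     Array<real32> first = batch[0]; // view of the first 8x8 matrix
//
// A batch can also be filled incrementally from views:
//
//     Batch<real32> views(16);
//     for (uint32_t i = 0; i < 16; ++i) views.add(data[i]);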

////////////////
// cuBLAS API //
////////////////

template <typename T> struct CudaComplexConversion_S { typedef T type; };
#ifdef CUDACC
template <> struct CudaComplexConversion_S<complex64> { typedef cuComplex type; };
template <> struct CudaComplexConversion_S<complex128> { typedef cuDoubleComplex type; };
#endif

template <typename T> using CudaComplexConversion = typename CudaComplexConversion_S<T>::type;

template <typename T> struct CublasTypeLetter_S {};
template <> struct CublasTypeLetter_S<real32> { static constexpr char letter = 'S'; };
template <> struct CublasTypeLetter_S<real64> { static constexpr char letter = 'D'; };
template <> struct CublasTypeLetter_S<complex64> { static constexpr char letter = 'C'; };
template <> struct CublasTypeLetter_S<complex128> { static constexpr char letter = 'Z'; };
#ifdef CUDACC
template <> struct CublasTypeLetter_S<real16> { static constexpr char letter = 'H'; };
#endif

template <typename T> constexpr char CublasTypeLetter = CublasTypeLetter_S<T>::letter;

// Shorthands to reduce clutter.

#define CAST(var) reinterpret_cast<CudaComplexConversion<T>*>(var)
#define DCAST(var) reinterpret_cast<CudaComplexConversion<T>**>(var)

// Note: token pasting happens before template substitution, so this shorthand cannot
// expand to a real cuBLAS symbol; the invoke() helpers below are used instead.
#define cublas(T, func) cublas##CublasTypeLetter<T>##func

template <typename T, typename F1, typename F2, typename F3, typename F4, typename... Args>
constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
    if constexpr (std::is_same<T, real32>::value) {
        CUBLAS_CHECK(f1(args...));
    } else if constexpr (std::is_same<T, real64>::value) {
        CUBLAS_CHECK(f2(args...));
    } else if constexpr (std::is_same<T, complex64>::value) {
        CUBLAS_CHECK(f3(args...));
    } else if constexpr (std::is_same<T, complex128>::value) {
        CUBLAS_CHECK(f4(args...));
    } else {
        CT_ERROR(true, "This BLAS function is not callable with that type");
    }
}

// If someone can think of a better solution, please tell me.
template <typename T, typename F1, typename F2, typename F3, typename F4, typename F5,
          typename... Args>
constexpr void invoke5(F1 f1, F2 f2, F3 f3, F4 f4, F5 f5, Args&&... args) {
    if constexpr (std::is_same<T, real32>::value) {
        CUBLAS_CHECK(f1(args...));
    } else if constexpr (std::is_same<T, real64>::value) {
        CUBLAS_CHECK(f2(args...));
    } else if constexpr (std::is_same<T, complex64>::value) {
        CUBLAS_CHECK(f3(args...));
    } else if constexpr (std::is_same<T, complex128>::value) {
        CUBLAS_CHECK(f4(args...));
    } else if constexpr (std::is_same<T, real16>::value) {
        CUBLAS_CHECK(f5(args...));
    } else {
        CT_ERROR(true, "This BLAS function is not callable with that type");
    }
}
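
// Illustrative sketch of the dispatch pattern used below: forwarding a cuBLAS call
// through invoke() picks the S/D/C/Z variant matching T. The handle and device
// pointers are assumed to be valid.
//
//     invoke<T>(cublasSaxpy, cublasDaxpy, cublasCaxpy, cublasZaxpy,
//               Manager::get()->cublasHandle(), n, CAST(&alpha),
//               CAST(x.dataDevice()), 1, CAST(y.dataDevice()), 1);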

/**
 * Computes the matrix-vector product: \f$ y = \alpha Ax + \beta y \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta, const Array<T>& y,
              const StreamID& stream = DEF_CUBLAS_STREAM) {
    BatchInfo bi = Check::isBroadcastable(A, x, y, "A", "x", "y");
    CT_ERROR_IF(x.shape().cols(), !=, 1, "x must be a column vector");
    CT_ERROR_IF(y.shape().cols(), !=, 1, "y must be a column vector");

    uint32_t rows = A.shape().rows();
    uint32_t cols = A.shape().cols();
    T a = alpha, b = beta;
#ifdef CUDACC
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
    if (bi.size == 1) {
        invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
                  Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
                  CAST(A.dataDevice()), rows, CAST(x.dataDevice()), 1, CAST(&b),
                  CAST(y.dataDevice()), 1);
    } else { // More than one matrix, so broadcast.
        invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched, cublasCgemvStridedBatched,
                  cublasZgemvStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows,
                  cols, CAST(&a), CAST(A.dataDevice()), rows, bi.strideA, CAST(x.dataDevice()), 1,
                  bi.strideB, CAST(&b), CAST(y.dataDevice()), 1, bi.strideC, bi.size);
    }
#else
    if (bi.size == 1) {
        y.eigenMap() = a * (A.eigenMap() * x.eigenMap()) + b * y.eigenMap();
    } else { // More than one matrix, so broadcast.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {rows, cols}, i * bi.strideA).eigenMap();
            auto xi = Array<T>(x, {cols, 1}, i * bi.strideB).eigenMap();
            auto yi = Array<T>(y, {rows, 1}, i * bi.strideC).eigenMap();
            yi = a * (Ai * xi) + b * yi;
        }
    }
#endif
    return StreamID{stream};
}
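
// Hypothetical usage sketch: multiply four stacked 3x3 matrices by one shared vector.
//
//     Array<real64> A({4, 3, 3}), x({3, 1}), y({4, 3, 1});
//     // ... fill A and x ...
//     GEMV(1.0, A, x, 0.0, y); // y_i = A_i * x for each of the four matrices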

/**
 * Computes the matrix-matrix product: \f$ C = \alpha AB + \beta C \f$. It will automatically
 * broadcast the operation if applicable.
 */
template <typename T>
StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta, const Array<T>& C,
              const StreamID& stream = DEF_CUBLAS_STREAM) {
    BatchInfo bi = Check::isBroadcastable(A, B, C, "A", "B", "C");
    // A is m x k, B is k x n.
    uint32_t m = A.shape().rows();
    uint32_t k = A.shape().cols();
    uint32_t n = B.shape().cols();

    T a = alpha, b = beta;
#ifdef CUDACC
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));

    if (bi.size == 1) {
        invoke5<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm, cublasHgemm,
                   Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
                   CAST(A.dataDevice()), m, CAST(B.dataDevice()), k, CAST(&b),
                   CAST(C.dataDevice()), m);
    } else { // More than one matrix product, so broadcast.
        invoke5<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched, cublasCgemmStridedBatched,
                   cublasZgemmStridedBatched, cublasHgemmStridedBatched,
                   Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
                   CAST(A.dataDevice()), m, bi.strideA, CAST(B.dataDevice()), k, bi.strideB,
                   CAST(&b), CAST(C.dataDevice()), m, bi.strideC, bi.size);
    }
#else
    if (bi.size == 1) {
        C.eigenMap() = a * (A.eigenMap() * B.eigenMap()) + b * C.eigenMap();
    } else { // More than one matrix product, so broadcast.
#pragma omp parallel for
        for (uint32_t i = 0; i < bi.size; ++i) {
            auto Ai = Array<T>(A, {m, k}, i * bi.strideA).eigenMap();
            auto Bi = Array<T>(B, {k, n}, i * bi.strideB).eigenMap();
            auto Ci = Array<T>(C, {m, n}, i * bi.strideC).eigenMap();
            Ci = a * (Ai * Bi) + b * Ci;
        }
    }
#endif
    return StreamID{stream};
}
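
// Hypothetical usage sketch: multiply a stack of matrices by a single matrix.
//
//     Array<real32> A({8, 4, 5}), B({5, 6}), C({8, 4, 6});
//     // ... fill A and B ...
//     GEMM(1.0f, A, B, 0.0f, C); // eight 4x5 times one shared 5x6 -> eight 4x6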

/**
 * Computes the diagonal matrix multiplication: \f$ C = A\mathrm{diag}(X) \f$, or \f$ C =
 * \mathrm{diag}(X)A \f$ if left = true.
 */
template <typename T>
StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const bool left = false,
              const StreamID& stream = DEF_CUBLAS_STREAM) {
    CT_ERROR_IF(X.shape().cols(), !=, 1, "'X' must be a column vector.");
    if (left) {
        CT_ERROR_IF(A.shape().rows(), !=, X.shape().rows(),
                    "Rows of 'A' and length of 'X' need to match.");
    } else {
        CT_ERROR_IF(A.shape().cols(), !=, X.shape().rows(),
                    "Columns of 'A' and length of 'X' need to match.");
    }
    CT_ERROR_IF(A.shape().rows(), !=, C.shape().rows(), "Rows of 'A' and 'C' need to match.");
    CT_ERROR_IF(A.shape().cols(), !=, C.shape().cols(), "Columns of 'A' and 'C' need to match.");

#ifdef CUDACC
    uint32_t m = C.shape().rows();
    uint32_t n = C.shape().cols();
    auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
    CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
    invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(),
              mode, m, n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
              CAST(C.dataDevice()), m);
#else
    if (left) {
        C.eigenMap() = X.eigenMap().asDiagonal() * A.eigenMap();
    } else {
        C.eigenMap() = A.eigenMap() * X.eigenMap().asDiagonal();
    }
#endif
    return StreamID{stream};
}
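
// Hypothetical usage sketch: scale each column of A by the corresponding entry of X.
//
//     Array<real32> A({3, 4}), X({4, 1}), C({3, 4});
//     // ... fill A and X ...
//     DGMM(A, X, C); // C = A * diag(X); pass left = true for C = diag(X) * A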

//////////////////////////////
// PLUArray Related Objects //
//////////////////////////////

//////////////////////////
// PartialPivLU Wrapper //
//////////////////////////

// This class is just a workaround to use Eigen's internals directly.
template <typename T> class PartialPivLU;
namespace internal {
template <typename T> static Array<T> empty({1, 1});
template <typename T> static EigenMapMat<T> empty_map = empty<T>.eigenMap();
}; // namespace internal

template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true> class PLUArray;

// This is a wrapper class for Eigen's class so we have more controlled access to
// the underlying data.
template <typename T> class PartialPivLU : public Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>> {
  private:
    using Base = Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>>;
    template <typename U, std::enable_if_t<is_float<U> or is_complex<U>, bool>>
    friend class PLUArray;

    EigenMapMat<T> mMapLU;
    EigenMapMat<int32_t> mMapPivots;

  public:
    PartialPivLU()
        : Base(internal::empty_map<T>), mMapLU(internal::empty_map<T>),
          mMapPivots(internal::empty_map<int32_t>){};

    void make(const Array<T>& lu, const Array<int32_t>& pivots) {
        new (&mMapLU) EigenMapMat<T>(lu.eigenMap());
        new (&mMapPivots) EigenMapMat<int32_t>(pivots.atLeast2D().eigenMap());

        new (&this->m_lu) decltype(Base::m_lu)(mMapLU.derived());
        new (&this->m_p) decltype(Base::m_p)(mMapPivots.derived());

        // new (&this->m_rowsTranspositions) decltype(Base::m_rowsTranspositions)(
        //     mMapPivots.derived());

        this->m_l1_norm = 0;
        this->m_det_p = 0;
        this->m_isInitialized = true;
    };
};

namespace internal {
// We only create one and copy-construct to avoid the re-initialization.
template <typename T> static PartialPivLU<T> BlankPPLU = PartialPivLU<T>();
}; // namespace internal

/**
 * Class for storing the PLU decomposition of an Array. This is restricted to floating point
 * types.
 */
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool>> class PLUArray {
  private:
    Array<T> mLU;
    Array<int32_t> mPivots;
    PartialPivLU<T> mPPLU = internal::BlankPPLU<T>;

  public:
    PLUArray() = delete;

    /**
     * Constructor for a PLUArray given the matrix dimension.
     */
    PLUArray(const uint32_t n) : mLU({n, n}), mPivots({n}) { mPPLU.make(mLU, mPivots); };

    /**
     * Constructor for a PLUArray given an existing array.
     */
    PLUArray(const Array<T>& arr)
        : mLU((arr.isView()) ? arr.view() : arr), mPivots({arr.shape().rows()}) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    /**
     * Constructor for a PLUArray given an existing location in memory for both the matrix and
     * the pivots.
     */
    PLUArray(const Array<T>& arr, const Array<int32_t> pivots)
        : mLU(arr.view()), mPivots(pivots.view()) {
        CT_ERROR_IF(mLU.shape().axes(), !=, 2, "Array must be a 2D matrix");
        CT_ERROR_IF(mLU.shape().rows(), !=, mLU.shape().cols(), "Matrix must be square");
        mPPLU.make(mLU, mPivots);
    };

    uint32_t rank() { return mLU.shape().rows(); }; /**< Gets the rank of the LU matrix. */
    Array<T> LU() const { return mLU.view(); }; /**< Gets the LU matrix. */
    Array<int32_t> pivots() const { return mPivots.view(); }; /**< Gets the pivots. */

    /**
     * Computes the in-place LU factorization for this array on CPU.
     */
    void computeLU() {
        mPPLU.compute();
        mPPLU.mMapPivots = mPPLU.permutationP().indices();
    };

    /**
     * Solves the system \f$ LUx = b \f$ and returns \f$ x \f$.
     */
    Array<T> solve(const Array<T>& b) {
        Array<T> x(b.shape());
        x.eigenMap() = mPPLU.solve(b.eigenMap());
        return x;
    };
};
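
// Hypothetical usage sketch: factor a square system once and reuse the factorization.
//
//     Array<real64> A({3, 3}), b({3, 1});
//     // ... fill A and b ...
//     PLUArray<real64> plu(A);
//     plu.computeLU();                // in-place LU factorization on CPU
//     Array<real64> x = plu.solve(b); // solves LUx = b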

/**
 * This is a batch version of PLUArray, to enable usage of the cuBLAS API. This is restricted to
 * floating point types.
 */
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true>
class PLUBatch : public Batch<T> {
  private:
    Array<int32_t> mPivotsBatch;
    Array<int32_t> mInfoLU;
    int32_t mInfoSolve;

    bool mInitialized = false;

  public:
    /**
     * Constructor of a PLUBatch from a given batch size.
     */
    PLUBatch(const uint32_t size) : Batch<T>(size), mInfoLU({size}){};

    /**
     * Constructor of a PLUBatch from a multi-dimensional array, batched across the upper
     * dimensions.
     */
    PLUBatch(const Array<T>& arr) : Batch<T>(arr) {
        Check::isSquare(arr, "LU Array");

        mPivotsBatch = Array<int32_t>({this->mBatchSize * this->mShape.rows()});
        mInfoLU = Array<int32_t>({this->mBatchSize});
    };

    /**
     * Indexing operator which returns the PLUArray in the PLUBatch at the given index.
     */
    PLUArray<T> operator[](const uint32_t index) const {
        CT_ERROR_IF(index, >=, this->mBatchSize, "Index exceeds batch size");
        Array<T> lu(this->mBatch[index], {this->mShape.rows(), this->mShape.cols()});
        Array<int32_t> pivots(mPivotsBatch.data() + index * this->mShape.rows(),
                              {this->mShape.rows()});
        return PLUArray<T>(lu, pivots);
    };

    /**
     * Computes the in-place PLU decomposition of the batch of arrays.
     */
    StreamID computeLU(const StreamID& stream = DEF_CUBLAS_STREAM) {
#ifdef CUDACC
        uint32_t n = this->mShape.rows();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
        invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
                  cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
                  DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
                  mInfoLU.dataDevice(), this->mBatchSize);
#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            (*this)[i].computeLU();
        }
#endif
        mInitialized = true;
        return stream;
    };

    /**
     * Solves the batched system \f$ LUx = b \f$ in place. The solution \f$ x \f$ is written back
     * into \f$ b \f$.
     */
    StreamID solve(const Batch<T>& b, const StreamID& stream = DEF_CUBLAS_STREAM) {
        CT_ERROR(not mInitialized,
                 "Cannot solve system if PLUBatch has not yet computed its LU decomposition");
        CT_ERROR_IF(b.size(), !=, this->mBatchSize,
                    "Upper dimensions of b do not match batch size");
        CT_ERROR_IF(b.shape().rows(), !=, this->mShape.rows(),
                    "The length of each column of b must match the matrix rank");

#ifdef CUDACC
        uint32_t n = b.shape().rows();
        uint32_t nrhs = b.shape().cols();
        CUBLAS_CHECK(
            cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
        invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
                  cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
                  DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
                  DCAST(b.batch().dataDevice()), n, &mInfoSolve, this->mBatchSize);
#else
#pragma omp parallel for
        for (uint32_t i = 0; i < this->mBatchSize; ++i) {
            b[i] = (*this)[i].solve(b[i]);
        }
#endif
        return stream;
    };

    /**
     * Gets the pivots data from the device to the host. Does nothing for CPU.
     */
    StreamID getPivots(const StreamID& stream = DEF_MEM_STREAM) const {
        mPivotsBatch.updateHost(stream);
        return stream;
    };

    /**
     * Gets the info array for the LU decomposition from the device to the host. Does not
     * return useful information for CPU.
     */
    Array<int32_t> getLUInfo() const {
        mInfoLU.updateHost().wait();
        return mInfoLU;
    };

    /**
     * Checks the validity of the solve operation. Does not return useful information for CPU.
     */
    int32_t validSolve() const { return mInfoSolve == 0; }
};
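
// Hypothetical usage sketch: factor and solve many small systems in one batch.
//
//     Array<real32> As({100, 4, 4}), bs({100, 4, 1});
//     // ... fill As and bs ...
//     PLUBatch<real32> plu(As); // one 4x4 system per upper index
//     Batch<real32> rhs(bs);
//     plu.computeLU();
//     plu.solve(rhs);           // solutions are written back into bs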

}; // namespace BLAS
}; // namespace CudaTools

#endif // CUDATOOLS_BLAS_H