You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
569 lines
18 KiB
569 lines
18 KiB
#define CUDATOOLS_IMPLEMENTATION
|
|
#define CUDATOOLS_ARRAY_MAX_AXES 8
|
|
#include <Array.h>
|
|
#include <BLAS.h>
|
|
#include <Core.h>
|
|
#include <Types.h>
|
|
|
|
#include <Eigen/Core>
|
|
#include <chrono>
|
|
#include <complex>
|
|
|
|
using namespace CudaTools::Types;
|
|
namespace CT = CudaTools;
|
|
|
|
/////////////
|
|
// Helpers //
|
|
/////////////
|
|
|
|
/// Starts a named wall-clock timer; pair with TIME_END(name) in the same scope.
#define TIME_START(name) auto begin_##name = std::chrono::steady_clock::now()

/// Stops the timer started by TIME_START(name) and prints the elapsed time: in microseconds when
/// it is below one millisecond, otherwise in milliseconds. Declares end_<name>, time_ms_<name>
/// and time_mus_<name> in the enclosing scope.
/// The counts are cast to long long and printed with %lld because the type returned by
/// std::chrono::duration::count() is implementation-defined (e.g. long vs long long), which made
/// the previous %ld specifier non-portable.
#define TIME_END(name)                                                                             \
    auto end_##name = std::chrono::steady_clock::now();                                            \
    auto time_ms_##name =                                                                          \
        std::chrono::duration_cast<std::chrono::milliseconds>(end_##name - begin_##name).count();  \
    auto time_mus_##name =                                                                         \
        std::chrono::duration_cast<std::chrono::microseconds>(end_##name - begin_##name).count();  \
    if (time_ms_##name == 0) {                                                                     \
        printf("[%s] Time Elapsed: %lld[µs]\n", #name, static_cast<long long>(time_mus_##name));  \
    } else {                                                                                       \
        printf("[%s] Time Elapsed: %lld[ms]\n", #name, static_cast<long long>(time_ms_##name));   \
    }

/// Times the single statement `call`, labelling the printed result with `name`.
#define TIME(call, name)                                                                           \
    TIME_START(name);                                                                              \
    call;                                                                                          \
    TIME_END(name);

/// Records one test result: adds 1 to the caller's `failed` counter when `predicate` is false and
/// prints a coloured PASS/FAIL line followed by "name | msg.".
/// `predicate` is evaluated exactly once, and the do/while(0) wrapper makes the expansion a single
/// statement, so it is safe inside an unbraced if/else.
#define TEST(predicate, name, msg)                                                                 \
    do {                                                                                           \
        const bool ct_test_passed_ = static_cast<bool>(predicate);                                 \
        failed += ct_test_passed_ ? 0 : 1;                                                         \
        printf("[%s] ", ct_test_passed_ ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");      \
        printf("%s | %s.\n", name, msg);                                                           \
    } while (0)
|
|
|
|
template <typename T> struct Type;
|
|
|
|
#define REGISTER_PARSE_TYPE(X) \
|
|
template <> struct Type<X> { static const std::string name; }; \
|
|
const std::string Type<X>::name = #X
|
|
|
|
REGISTER_PARSE_TYPE(uint8_t);
|
|
REGISTER_PARSE_TYPE(int16_t);
|
|
REGISTER_PARSE_TYPE(int32_t);
|
|
REGISTER_PARSE_TYPE(real32);
|
|
REGISTER_PARSE_TYPE(real64);
|
|
REGISTER_PARSE_TYPE(complex64);
|
|
REGISTER_PARSE_TYPE(complex128);
|
|
|
|
/// Frames `str` in '#' characters: a rule of str.size() + 6 hashes, the line "## str ##",
/// and a matching rule underneath.
std::string box(std::string str) {
    const std::string rule(str.size() + 6, '#');
    std::string framed = rule;
    framed += "\n## ";
    framed += str;
    framed += " ##\n";
    framed += rule;
    return framed;
}
|
|
|
|
/// Frames `str` between "|| " and " ||" with a dashed line above and below.
///
/// The dashed lines are str.size() - 5 characters long. main()'s PASS/FAIL message carries
/// 11 bytes of ANSI colour escapes that occupy no terminal columns, and the two borders add
/// 6 visible columns, so the visible width is str.size() - 11 + 6.
/// For strings shorter than 5 bytes the width is clamped to 0. Without the clamp, the unsigned
/// subtraction wraps around and std::string throws std::length_error.
std::string box2(std::string str) {
    const std::string::size_type width = str.size() > 5 ? str.size() - 5 : 0;
    const std::string tops(width, '-');
    return tops + "\n|| " + str + " ||\n" + tops;
}
|
|
|
|
/// Renders `str` as an inline banner: six dashes, "[ str ]", six dashes.
std::string boxSmall(std::string str) {
    const std::string dashes(6, '-');
    std::string banner = dashes;
    banner += "[ ";
    banner += str;
    banner += " ]";
    banner += dashes;
    return banner;
}
|
|
|
|
/// Returns a 40-character '=' rule surrounded by newlines.
std::string separator() {
    std::string out(1, '\n');
    out.append(40, '=');
    out.push_back('\n');
    return out;
}
|
|
|
|
/// Returns the registered name of T (see Type<T>) wrapped in bold bright-cyan ANSI escapes.
template <typename T> std::string type() {
    std::string colored = "\033[1;96m";
    colored += Type<T>::name;
    colored += "\033[0m";
    return colored;
}
|
|
|
|
/// Returns a 2-D shape whose row and column counts are each drawn uniformly from [1, 15].
/// A fresh Mersenne Twister seeded from std::random_device is used on every call.
CT::Shape makeRandom2DShape() {
    std::random_device seedSource;
    std::mt19937 engine(seedSource());
    std::uniform_int_distribution<uint32_t> extent(1, 15);

    // Draw rows first, then columns (same order as the braced list it replaces).
    const uint32_t rows = extent(engine);
    const uint32_t cols = extent(engine);
    return CT::Shape({rows, cols});
}
|
|
|
|
///////////
|
|
// Tests //
|
|
///////////
|
|
|
|
class TestClass {
|
|
DEVICE_COPY(TestClass);
|
|
|
|
public:
|
|
int x;
|
|
TestClass(const int x) : x(x) {
|
|
allocateDevice();
|
|
updateDevice().wait();
|
|
};
|
|
};
|
|
|
|
/// Doubles every element of `arr` in place.
/// NOTE(review): `arr` is a by-value const handle whose elements are still written, so
/// CT::Array copies presumably share storage (view semantics) — confirm against Array.h.
/// Assumes BASIC_LOOP(n) visits each flat index iThread in [0, n) — confirm against Core.h.
KERNEL(times, const CT::Array<int> arr) {
    BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; }
}
|
|
|
|
/// Sets `test->x` to 100. MacroTests::Class launches it with a single thread and expects to read
/// 100 back on the host after updateHost().
KERNEL(classTest, TestClass* const test) { test->x = 100; }
|
|
|
|
/// Applies one Collatz step to every element of `arr` in place:
/// odd values become 3 * n + 1, even values are halved (via a right shift).
/// NOTE(review): no overflow check — 3 * n + 1 wraps for n > (2^32 - 2) / 3.
KERNEL(collatz, const CT::Array<uint32_t> arr) {
    BASIC_LOOP(arr.shape().length()) {
        if (arr[iThread] % 2) {
            arr[iThread] = 3 * arr[iThread] + 1;
        } else {
            arr[iThread] = arr[iThread] >> 1;
        }
    }
}
|
|
|
|
/// Increments every element of `arr` by one, in place.
KERNEL(plusOne, const CT::Array<uint32_t> arr) {
    BASIC_LOOP(arr.shape().length()) { arr[iThread] += 1; }
}
|
|
|
|
/// Element-wise `a += b`, in place on `a`.
/// The loop bound is taken from `a` only; assumes `b` has at least as many elements as `a`
/// (true for the equally sized arrays in doGraphTest) — no check is performed.
KERNEL(addArray, const CT::Array<uint32_t> a, const CT::Array<uint32_t> b) {
    BASIC_LOOP(a.shape().length()) { a[iThread] += b[iThread]; }
}
|
|
|
|
struct MacroTests {
|
|
static uint32_t Kernel() {
|
|
uint32_t failed = 0;
|
|
CT::Array<int> A = CT::Array<int>::constant({10}, 1);
|
|
A.updateDevice().wait();
|
|
CT::Kernel::launch(times, CT::Kernel::basic(A.shape().items()), A.view()).wait();
|
|
A.updateHost().wait();
|
|
|
|
uint32_t errors = 0;
|
|
for (auto it = A.begin(); it != A.end(); ++it) {
|
|
if (*it != 2) ++errors;
|
|
}
|
|
|
|
std::ostringstream msg;
|
|
msg << "Errors: " << errors;
|
|
TEST(errors == 0, "Kernel", msg.str().c_str());
|
|
return failed;
|
|
};
|
|
|
|
static uint32_t Class() {
|
|
uint32_t failed = 0;
|
|
TestClass test(1);
|
|
CT::Kernel::launch(classTest, CT::Kernel::basic(1), test.that()).wait();
|
|
test.updateHost().wait();
|
|
|
|
TEST(test.x == 100, "Class", "Errors: 0");
|
|
return failed;
|
|
}
|
|
};
|
|
|
|
/// Host-side tests of CT::Array indexing and slicing, instantiated per element type T.
/// Each test returns the number of failed TEST checks it performed.
template <typename T> struct ArrayTests {
    /// Compares chained single-axis indexing (A[i][j]...) against initializer-list indexing
    /// (A[{i, j, ...}]). It does this once at full depth on an 8-axis array and once for each of
    /// two partially indexed sub-arrays.
    /// @return Number of failed checks (0-3).
    static uint32_t Indexing() {
        uint32_t failed = 0;
        // 240 values reshaped to 8 axes (5*3*1*4*2*1*1*2 = 240), including singleton axes.
        CT::Array<T> A = CT::Array<T>::range(0, 240);
        A.reshape({5, 3, 1, 4, 2, 1, 1, 2});

        // Full-depth element access: both indexing forms must agree on every element.
        uint32_t errors = 0;
        for (uint32_t i = 0; i < 5; ++i) {
            for (uint32_t j = 0; j < 3; ++j) {
                for (uint32_t k = 0; k < 4; ++k) {
                    for (uint32_t l = 0; l < 2; ++l) {
                        for (uint32_t m = 0; m < 2; ++m) {
                            if ((T)A[i][j][0][k][l][0][0][m] != (T)A[{i, j, 0, k, l, 0, 0, m}]) {
                                ++errors;
                            }
                        }
                    }
                }
            }
        }

        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Element", msg.str().c_str());

        // Partial indexing over the first two axes: both forms should yield the same
        // 6-axis sub-array.
        errors = 0;
        CT::Array<T> ApartGroup_1 = A[{2, 2}];
        CT::Array<T> ApartIndiv_1 = A[2][2];
        for (uint32_t k = 0; k < 4; ++k) {
            for (uint32_t l = 0; l < 2; ++l) {
                for (uint32_t m = 0; m < 2; ++m) {
                    if ((T)ApartIndiv_1[0][k][l][0][0][m] != (T)ApartGroup_1[{0, k, l, 0, 0, m}]) {
                        ++errors;
                    }
                }
            }
        }

        // Reuse the stream: clear its contents before writing the next message.
        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Axis (1/2)", msg.str().c_str());

        // Partial indexing over the first four axes: both forms should yield the same
        // 4-axis sub-array.
        errors = 0;
        CT::Array<T> ApartGroup_2 = A[{3, 2, 0, 3}];
        CT::Array<T> ApartIndiv_2 = A[3][2][0][3];

        for (uint32_t l = 0; l < 2; ++l) {
            for (uint32_t m = 0; m < 2; ++m) {
                if ((T)ApartIndiv_2[l][0][0][m] != (T)ApartGroup_2[{l, 0, 0, m}]) {
                    ++errors;
                }
            }
        }

        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Axis (2/2)", msg.str().c_str());
        return failed;
    };

    /// Writes through two slices of a zero-filled 4x5x5 array and checks that the writes are
    /// visible at the expected positions of the parent array `A`.
    /// @return Number of failed checks (0-2).
    static uint32_t Slicing() {
        uint32_t failed = 0;
        CT::Array<T> A = CT::Array<T>::constant({4, 5, 5}, 0);

        // Fill the 3x3x3 block covering axis-0 indices 0..2 and axis-1/axis-2 indices 1..3
        // with 1, 2, ..., 27 in the slice's iteration order.
        CT::Array<T> Aslice = A.slice({{0, 3}, {1, 4}, {1, 4}});
        T num = (T)1;
        for (auto it = Aslice.begin(); it != Aslice.end(); ++it) {
            *it = num;
            ++num;
        }

        // Fill column 0 of A[3] (5 rows, 1 column) with -1, -2, ..., -5.
        // For unsigned T these wrap modulo 2^bits; the check below casts the same way.
        CT::Array<T> Aslice2 = A[3].slice({{0, 5}, {0, 1}});
        num = (T)-1;
        for (auto it = Aslice2.begin(); it != Aslice2.end(); ++it) {
            *it = num;
            --num;
        }

        // The block must read back in row-major order: value = 9*i + 3*j + k + 1.
        uint32_t errors = 0;
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                    if ((T)A[i][1 + j][1 + k] != (T)(9 * i + 3 * j + k + 1)) {
                        ++errors;
                    }
                }
            }
        }
        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Block", msg.str().c_str());

        // The column written through Aslice2 must read back as -1 .. -5.
        errors = 0;
        for (int i = 0; i < 5; ++i) {
            if ((T)A[3][i][0] != (T)(-(i + 1))) {
                ++errors;
            }
        }

        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Column", msg.str().c_str());
        return failed;
    }
};
|
|
|
|
template <typename T> struct BLASTests {
|
|
static double thres;
|
|
static uint32_t GEMV(int attempts) {
|
|
uint32_t failed = 0;
|
|
for (int i = 0; i < attempts; i++) {
|
|
CT::Shape Ashape = makeRandom2DShape();
|
|
CT::Shape xshape = CT::Shape({Ashape.cols(), 1});
|
|
CT::Shape yshape = CT::Shape({Ashape.rows(), 1});
|
|
|
|
CT::Array<T> A(Ashape);
|
|
CT::Array<T> x(xshape);
|
|
CT::Array<T> y(yshape);
|
|
|
|
A.setRandom(-100, 100);
|
|
x.setRandom(-100, 100);
|
|
|
|
A.updateDevice();
|
|
x.updateDevice().wait();
|
|
|
|
CT::BLAS::GEMV<T>(1.0, A, x, 0.0, y).wait();
|
|
y.updateHost().wait();
|
|
|
|
CT::Array<T> yTest(yshape, true);
|
|
yTest.eigenMap() = A.eigenMap() * x.eigenMap();
|
|
|
|
double norm = (y.eigenMap() - y.eigenMap()).norm();
|
|
|
|
std::ostringstream name;
|
|
name << "GEMV (" << i + 1 << "/" << attempts << ")";
|
|
std::ostringstream msg;
|
|
msg << "Matrix Shape: " << Ashape << ", "
|
|
<< "Residual: " << norm;
|
|
TEST(norm < thres, name.str().c_str(), msg.str().c_str());
|
|
}
|
|
return failed;
|
|
};
|
|
|
|
static uint32_t GEMVBroadcast() {
|
|
uint32_t failed = 0;
|
|
CT::Shape Ashape = makeRandom2DShape();
|
|
CT::Shape xshape = CT::Shape({Ashape.cols(), 1});
|
|
CT::Shape yshape = CT::Shape({Ashape.rows(), 1});
|
|
|
|
CT::Array<T> A({2, 3, Ashape.rows(), Ashape.cols()});
|
|
CT::Array<T> x({2, 3, xshape.rows(), xshape.cols()});
|
|
CT::Array<T> y({2, 3, yshape.rows(), yshape.cols()});
|
|
|
|
A.setRandom(-100, 100);
|
|
x.setRandom(-100, 100);
|
|
|
|
A.updateDevice();
|
|
x.updateDevice().wait();
|
|
|
|
CT::BLAS::GEMV<T>(1.0, A, x, 0.0, y).wait();
|
|
y.updateHost().wait();
|
|
|
|
double norm = 0;
|
|
CT::Array<T> yTest(yshape, true);
|
|
for (int i = 0; i < 2; ++i) {
|
|
for (int j = 0; j < 3; ++j) {
|
|
yTest.eigenMap() = A[i][j].eigenMap() * x[i][j].eigenMap();
|
|
norm += (yTest.eigenMap() - y[i][j].eigenMap()).norm();
|
|
}
|
|
}
|
|
|
|
std::ostringstream msg;
|
|
msg << "Matrix Shape: " << Ashape << ", "
|
|
<< "Residual: " << norm;
|
|
TEST(norm < thres, "GEMV Broadcast", msg.str().c_str());
|
|
return failed;
|
|
};
|
|
|
|
static uint32_t GEMM(int attempts) {
|
|
uint32_t failed = 0;
|
|
for (int i = 0; i < attempts; i++) {
|
|
CT::Shape Ashape = makeRandom2DShape();
|
|
CT::Shape Bshape = makeRandom2DShape();
|
|
Bshape = CT::Shape({Ashape.cols(), Bshape.cols()});
|
|
|
|
CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()});
|
|
|
|
CT::Array<T> A(Ashape);
|
|
CT::Array<T> B(Bshape);
|
|
CT::Array<T> C(Cshape);
|
|
|
|
A.setRandom(-100, 100);
|
|
B.setRandom(-100, 100);
|
|
C.setRandom(-100, 100);
|
|
|
|
A.updateDevice();
|
|
B.updateDevice();
|
|
C.updateDevice().wait();
|
|
|
|
CT::BLAS::GEMM<T>(1.0, A, B, 0.0, C).wait();
|
|
C.updateHost().wait();
|
|
|
|
CT::Array<T> CTest(Cshape, true);
|
|
CTest.eigenMap() = A.eigenMap() * B.eigenMap();
|
|
|
|
double norm = (CTest.eigenMap() - C.eigenMap()).norm();
|
|
|
|
std::ostringstream name;
|
|
name << "GEMM (" << i + 1 << "/" << attempts << ")";
|
|
std::ostringstream msg;
|
|
msg << "Matrix Shapes: " << Ashape << Bshape << ", "
|
|
<< "Residual: " << norm;
|
|
TEST(norm < thres, name.str().c_str(), msg.str().c_str());
|
|
}
|
|
return failed;
|
|
};
|
|
|
|
static uint32_t GEMMBroadcast() {
|
|
uint32_t failed = 0;
|
|
CT::Shape Ashape = makeRandom2DShape();
|
|
CT::Shape Bshape = makeRandom2DShape();
|
|
Bshape = CT::Shape({Ashape.cols(), Bshape.cols()});
|
|
|
|
CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()});
|
|
|
|
CT::Array<T> A({2, 3, Ashape.rows(), Ashape.cols()});
|
|
CT::Array<T> B({2, 3, Bshape.rows(), Bshape.cols()});
|
|
CT::Array<T> C({2, 3, Cshape.rows(), Cshape.cols()});
|
|
|
|
A.setRandom(-100, 100);
|
|
B.setRandom(-100, 100);
|
|
|
|
A.updateDevice();
|
|
B.updateDevice();
|
|
C.updateDevice().wait();
|
|
|
|
CT::BLAS::GEMM<T>(1.0, A, B, 0.0, C).wait();
|
|
C.updateHost().wait();
|
|
|
|
double norm = 0;
|
|
CT::Array<T> CTest(Cshape, true);
|
|
for (int i = 0; i < 2; ++i) {
|
|
for (int j = 0; j < 3; ++j) {
|
|
CTest.eigenMap() = A[i][j].eigenMap() * B[i][j].eigenMap();
|
|
norm += (CTest.eigenMap() - C[i][j].eigenMap()).norm();
|
|
}
|
|
}
|
|
|
|
std::ostringstream msg;
|
|
msg << "Matrix Shapes: " << Ashape << Bshape << ", "
|
|
<< "Residual: " << norm;
|
|
TEST(norm < thres, "GEMM Broadcast", msg.str().c_str());
|
|
return failed;
|
|
};
|
|
|
|
static uint32_t PLU() {
|
|
uint32_t failed = 0;
|
|
CT::Shape Ashape = makeRandom2DShape();
|
|
CT::Shape xshape = makeRandom2DShape();
|
|
Ashape = CT::Shape({Ashape.rows(), Ashape.rows()});
|
|
xshape = CT::Shape({Ashape.rows(), xshape.cols()});
|
|
|
|
CT::Array<T> A({2, 3, Ashape.rows(), Ashape.rows()});
|
|
CT::Array<T> x({2, 3, xshape.rows(), xshape.cols()});
|
|
CT::Array<T> b({2, 3, xshape.rows(), xshape.cols()});
|
|
CT::Array<T> Ax({2, 3, xshape.rows(), xshape.cols()});
|
|
|
|
A.setRandom(-100, 100);
|
|
b.setRandom(-100, 100);
|
|
|
|
CT::Array<T> LU(A.copy());
|
|
x = b;
|
|
|
|
A.updateDevice();
|
|
LU.updateDevice();
|
|
x.updateDevice().wait();
|
|
|
|
CT::BLAS::PLUBatch<T> luBatch(LU);
|
|
CT::BLAS::Batch<T> xBatch(x);
|
|
luBatch.computeLU().wait();
|
|
luBatch.solve(xBatch).wait();
|
|
|
|
// Compute Ax and compare difference.
|
|
CT::BLAS::GEMM<T>(1.0, A, x, 0.0, Ax).wait();
|
|
Ax.updateHost();
|
|
|
|
double norm = 0;
|
|
for (int i = 0; i < 2; ++i) {
|
|
for (int j = 0; j < 3; ++j) {
|
|
norm += (Ax[i][j].eigenMap() - b[i][j].eigenMap()).norm();
|
|
}
|
|
}
|
|
|
|
std::ostringstream msg;
|
|
msg << "Matrix Shape: " << Ashape << xshape << ", "
|
|
<< "Residual: " << norm;
|
|
TEST(norm < thres, "PLU/Solve", msg.str().c_str());
|
|
return failed;
|
|
}
|
|
};
|
|
|
|
// Residual tolerances used by BLASTests, per element type.
// Single precision (real and complex) accepts residuals below 1.0; double precision below 1e-7.
template <> double BLASTests<float>::thres = 1.0;
template <> double BLASTests<double>::thres = 1e-7;
template <> double BLASTests<complex64>::thres = 1.0;
template <> double BLASTests<complex128>::thres = 1e-7;
|
|
|
|
/// Runs the macro smoke tests in order (Kernel, then Class) and prints a trailing blank line.
/// @return Total number of failed checks.
uint32_t doMacroTests() {
    uint32_t failures = MacroTests::Kernel();
    failures += MacroTests::Class();
    std::cout << "\n";
    return failures;
}
|
|
|
|
/// Runs the indexing and slicing tests for element type T, each under its own banner.
/// @return Total number of failed checks.
template <typename T> uint32_t doArrayTests() {
    const std::string label = type<T>();
    uint32_t failures = 0;

    std::cout << boxSmall("Index Tests : " + label) << "\n";
    failures += ArrayTests<T>::Indexing();

    std::cout << "\n" << boxSmall("Slice Tests : " + label) << "\n";
    failures += ArrayTests<T>::Slicing();

    std::cout << "\n";
    return failures;
}
|
|
|
|
/// Runs the GEMV, GEMM and PLU test groups for element type T, each under its own banner.
/// GEMV and GEMM each run 5 randomized attempts plus one batched (broadcast) check.
/// @return Total number of failed checks.
template <typename T> uint32_t doBLASTests() {
    const std::string label = type<T>();
    uint32_t failures = 0;

    std::cout << boxSmall("GEMV Tests : " + label) << "\n";
    failures += BLASTests<T>::GEMV(5);
    failures += BLASTests<T>::GEMVBroadcast();

    std::cout << "\n" << boxSmall("GEMM Tests : " + label) << "\n";
    failures += BLASTests<T>::GEMM(5);
    failures += BLASTests<T>::GEMMBroadcast();

    std::cout << "\n" << boxSmall("PLU Tests : " + label) << "\n";
    failures += BLASTests<T>::PLU();

    std::cout << "\n";
    return failures;
}
|
|
|
|
/// Host callback: adds `num` to every element of `A` in place, through an Eigen view of
/// the array (promoted to at least 2-D so it can be mapped as a matrix).
void addNum(const CT::Array<uint32_t> A, uint32_t num) {
    auto mapped = A.atLeast2D().eigenMap();
    mapped.array() += num;
}
|
|
|
|
/// Records the stream workload that doGraphTest captures into a CT::Graph.
///
/// On "graphStream" it enqueues, in order:
///   - upload A;
///   - 30 collatz launches on A;
///   - after joining the branch, addArray (A += B);
///   - download A and B;
///   - a host callback adding 5 to A.
/// On "graphStreamBranch" (forked from "graphStream") it enqueues an upload of B and 30
/// plusOne launches on B.
/// @param gm Graph manager used to fork/join the branch and to enqueue the host function.
/// @param A  Array stepped through the Collatz map.
/// @param B  Array incremented once per iteration on the branch stream.
void myGraph(CT::GraphManager* gm, const CT::Array<uint32_t> A, const CT::Array<uint32_t> B) {
    // tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
    A.updateDevice("graphStream");
    // Fork a branch off graphStream; presumably lets B's work overlap with A's — confirm.
    gm->makeBranch("graphStream", "graphStreamBranch");
    B.updateDevice("graphStreamBranch");
    for (uint32_t iTimes = 0; iTimes < 30; ++iTimes) {
        CT::Kernel::launch(collatz, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view());
        // NOTE(review): B's launch is sized from A's item count; correct only because A and B
        // have the same size in doGraphTest.
        CT::Kernel::launch(plusOne, CT::Kernel::basic(A.shape().items(), "graphStreamBranch"),
                           B.view());
    }

    // Join the branch back before addArray reads B (presumably a stream dependency — confirm).
    gm->joinBranch("graphStream", "graphStreamBranch");
    CT::Kernel::launch(addArray, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view(),
                       B.view());
    A.updateHost("graphStream");
    B.updateHost("graphStream");
    // Runs on the host after the downloads above in graphStream order.
    gm->launchHostFunction("graphStream", addNum, A.view(), 5);
}
|
|
|
|
/// Builds and executes the myGraph workload on two 1,000,000-element arrays and checks A.
///
/// Expected value 36: starting from 50, 30 Collatz steps end at 1 (the 4-2-1 cycle),
/// B ends at 30 after 30 increments, A += B gives 31, and the host callback adds 5.
/// @return Number of failed checks (0 or 1).
uint32_t doGraphTest() {
    uint32_t failed = 0;

    CT::Array<uint32_t> A = CT::Array<uint32_t>::constant({1000000}, 50);
    CT::Array<uint32_t> B = CT::Array<uint32_t>::constant({1000000}, 0);
    CT::Manager::get()->addStream("graphStream");
    CT::Manager::get()->addStream("graphStreamBranch");

    CT::GraphManager gm;
    CT::Graph graph("graphStream", myGraph, &gm, A.view(), B.view());
    graph.execute().wait();

    uint32_t errors = 0;
    for (auto&& value : A) {
        errors += (value != 36) ? 1 : 0;
    }

    std::ostringstream report;
    report << "Errors: " << errors;
    TEST(errors == 0, "Graph", report.str().c_str());
    return failed;
}
|
|
|
|
int main() {
|
|
uint32_t failed = 0;
|
|
std::cout << box("Macro Tests") << "\n";
|
|
failed += doMacroTests();
|
|
|
|
std::cout << box("Array Tests") << "\n";
|
|
// Test different sizes.
|
|
failed += doArrayTests<uint8_t>();
|
|
failed += doArrayTests<int16_t>();
|
|
failed += doArrayTests<int32_t>();
|
|
failed += doArrayTests<real64>();
|
|
|
|
std::cout << box("BLAS Tests") << "\n";
|
|
failed += doBLASTests<real32>();
|
|
failed += doBLASTests<real64>();
|
|
failed += doBLASTests<complex64>();
|
|
failed += doBLASTests<complex128>();
|
|
|
|
std::cout << box("Stream/Graph Tests") << "\n";
|
|
failed += doGraphTest();
|
|
|
|
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4 + 1;
|
|
std::ostringstream msg;
|
|
msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(")
|
|
<< (tests - failed) << "/" << tests << ")";
|
|
std::cout << box2(msg.str()) << "\n";
|
|
|
|
return 0;
|
|
}
|
|
|