#define CUDATOOLS_IMPLEMENTATION #define CUDATOOLS_ARRAY_MAX_AXES 8 #include #include #include #include #include #include #include using namespace CudaTools::Types; namespace CT = CudaTools; ///////////// // Helpers // ///////////// #define TIME_START(name) auto begin_##name = std::chrono::steady_clock::now() #define TIME_END(name) \ auto end_##name = std::chrono::steady_clock::now(); \ auto time_ms_##name = \ std::chrono::duration_cast(end_##name - begin_##name).count(); \ auto time_mus_##name = \ std::chrono::duration_cast(end_##name - begin_##name).count(); \ if (time_ms_##name == 0) { \ printf("[%s] Time Elapsed: %ld[µs]\n", #name, time_mus_##name); \ } else { \ printf("[%s] Time Elapsed: %ld[ms]\n", #name, time_ms_##name); \ } #define TIME(call, name) \ TIME_START(name); \ call; \ TIME_END(name); #define TEST(predicate, name, msg) \ failed += (predicate) ? 0 : 1; \ printf("[%s] ", (predicate) ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m"); \ printf("%s | %s.\n", name, msg); template struct Type; #define REGISTER_PARSE_TYPE(X) \ template <> struct Type { static const std::string name; }; \ const std::string Type::name = #X REGISTER_PARSE_TYPE(uint8_t); REGISTER_PARSE_TYPE(int16_t); REGISTER_PARSE_TYPE(int32_t); REGISTER_PARSE_TYPE(real32); REGISTER_PARSE_TYPE(real64); REGISTER_PARSE_TYPE(complex64); REGISTER_PARSE_TYPE(complex128); std::string box(std::string str) { std::string tops(str.size() + 6, '#'); return tops + "\n## " + str + " ##\n" + tops; } std::string box2(std::string str) { std::string tops(str.size() - 5, '-'); return tops + "\n|| " + str + " ||\n" + tops; } std::string boxSmall(std::string str) { std::string tops(6, '-'); return tops + "[ " + str + " ]" + tops; } std::string separator() { std::string line(40, '='); return "\n" + line + "\n"; } template std::string type() { return "\033[1;96m" + Type::name + "\033[0m"; } CT::Shape makeRandom2DShape() { std::random_device rd; std::mt19937 mt(rd()); std::uniform_int_distribution dist(1, 15); return CT::Shape({dist(mt), dist(mt)}); } /////////// // Tests // /////////// class TestClass { DEVICE_COPY(TestClass); public: int x; TestClass(const int x) : x(x) { allocateDevice(); updateDevice().wait(); }; }; KERNEL(times, const CT::Array arr) { BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; } } KERNEL(classTest, TestClass* const test) { test->x = 100; } KERNEL(collatz, const CT::Array arr) { BASIC_LOOP(arr.shape().length()) { if (arr[iThread] % 2) { arr[iThread] = 3 * arr[iThread] + 1; } else { arr[iThread] = arr[iThread] >> 1; } } } KERNEL(plusOne, const CT::Array arr) { BASIC_LOOP(arr.shape().length()) { arr[iThread] += 1; } } KERNEL(addArray, const CT::Array a, const CT::Array b) { BASIC_LOOP(a.shape().length()) { a[iThread] += b[iThread]; } } struct MacroTests { static uint32_t Kernel() { uint32_t failed = 0; CT::Array A = CT::Array::constant({10}, 1); A.updateDevice().wait(); CT::Kernel::launch(times, CT::Kernel::basic(A.shape().items()), A.view()).wait(); A.updateHost().wait(); uint32_t errors = 0; for (auto it = A.begin(); it != A.end(); ++it) { if (*it != 2) ++errors; } std::ostringstream msg; msg << "Errors: " << errors; TEST(errors == 0, "Kernel", msg.str().c_str()); return failed; }; static uint32_t Class() { uint32_t failed = 0; TestClass test(1); CT::Kernel::launch(classTest, CT::Kernel::basic(1), test.that()).wait(); test.updateHost().wait(); TEST(test.x == 100, "Class", "Errors: 0"); return failed; } }; template struct ArrayTests { static uint32_t Indexing() { uint32_t failed = 0; CT::Array A = CT::Array::range(0, 240); A.reshape({5, 3, 1, 4, 2, 1, 1, 2}); uint32_t errors = 0; for (uint32_t i = 0; i < 5; ++i) { for (uint32_t j = 0; j < 3; ++j) { for (uint32_t k = 0; k < 4; ++k) { for (uint32_t l = 0; l < 2; ++l) { for (uint32_t m = 0; m < 2; ++m) { if ((T)A[i][j][0][k][l][0][0][m] != (T)A[{i, j, 0, k, l, 0, 0, m}]) { ++errors; } } } } } } std::ostringstream msg; msg << "Errors: " << errors; TEST(errors == 0, "Element", msg.str().c_str()); errors = 0; CT::Array ApartGroup_1 = A[{2, 2}]; CT::Array ApartIndiv_1 = A[2][2]; for (uint32_t k = 0; k < 4; ++k) { for (uint32_t l = 0; l < 2; ++l) { for (uint32_t m = 0; m < 2; ++m) { if ((T)ApartIndiv_1[0][k][l][0][0][m] != (T)ApartGroup_1[{0, k, l, 0, 0, m}]) { ++errors; } } } } msg.str(""); msg << "Errors: " << errors; TEST(errors == 0, "Axis (1/2)", msg.str().c_str()); errors = 0; CT::Array ApartGroup_2 = A[{3, 2, 0, 3}]; CT::Array ApartIndiv_2 = A[3][2][0][3]; for (uint32_t l = 0; l < 2; ++l) { for (uint32_t m = 0; m < 2; ++m) { if ((T)ApartIndiv_2[l][0][0][m] != (T)ApartGroup_2[{l, 0, 0, m}]) { ++errors; } } } msg.str(""); msg << "Errors: " << errors; TEST(errors == 0, "Axis (2/2)", msg.str().c_str()); return failed; }; static uint32_t Slicing() { uint32_t failed = 0; CT::Array A = CT::Array::constant({4, 5, 5}, 0); CT::Array Aslice = A.slice({{0, 3}, {1, 4}, {1, 4}}); T num = (T)1; for (auto it = Aslice.begin(); it != Aslice.end(); ++it) { *it = num; ++num; } CT::Array Aslice2 = A[3].slice({{0, 5}, {0, 1}}); num = (T)-1; for (auto it = Aslice2.begin(); it != Aslice2.end(); ++it) { *it = num; --num; } uint32_t errors = 0; for (int i = 0; i < 3; ++i) { for (int j = 0; j < 3; ++j) { for (int k = 0; k < 3; ++k) { if ((T)A[i][1 + j][1 + k] != (T)(9 * i + 3 * j + k + 1)) { ++errors; } } } } std::ostringstream msg; msg << "Errors: " << errors; TEST(errors == 0, "Block", msg.str().c_str()); errors = 0; for (int i = 0; i < 5; ++i) { if ((T)A[3][i][0] != (T)(-(i + 1))) { ++errors; } } msg.str(""); msg << "Errors: " << errors; TEST(errors == 0, "Column", msg.str().c_str()); return failed; } }; template struct BLASTests { static double thres; static uint32_t GEMV(int attempts) { uint32_t failed = 0; for (int i = 0; i < attempts; i++) { CT::Shape Ashape = makeRandom2DShape(); CT::Shape xshape = CT::Shape({Ashape.cols(), 1}); CT::Shape yshape = CT::Shape({Ashape.rows(), 1}); CT::Array A(Ashape); CT::Array x(xshape); CT::Array y(yshape); A.setRandom(-100, 100); x.setRandom(-100, 100); A.updateDevice(); x.updateDevice().wait(); CT::BLAS::GEMV(1.0, A, x, 0.0, y).wait(); y.updateHost().wait(); CT::Array yTest(yshape, true); yTest.eigenMap() = A.eigenMap() * x.eigenMap(); double norm = (y.eigenMap() - y.eigenMap()).norm(); std::ostringstream name; name << "GEMV (" << i + 1 << "/" << attempts << ")"; std::ostringstream msg; msg << "Matrix Shape: " << Ashape << ", " << "Residual: " << norm; TEST(norm < thres, name.str().c_str(), msg.str().c_str()); } return failed; }; static uint32_t GEMVBroadcast() { uint32_t failed = 0; CT::Shape Ashape = makeRandom2DShape(); CT::Shape xshape = CT::Shape({Ashape.cols(), 1}); CT::Shape yshape = CT::Shape({Ashape.rows(), 1}); CT::Array A({2, 3, Ashape.rows(), Ashape.cols()}); CT::Array x({2, 3, xshape.rows(), xshape.cols()}); CT::Array y({2, 3, yshape.rows(), yshape.cols()}); A.setRandom(-100, 100); x.setRandom(-100, 100); A.updateDevice(); x.updateDevice().wait(); CT::BLAS::GEMV(1.0, A, x, 0.0, y).wait(); y.updateHost().wait(); double norm = 0; CT::Array yTest(yshape, true); for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { yTest.eigenMap() = A[i][j].eigenMap() * x[i][j].eigenMap(); norm += (yTest.eigenMap() - y[i][j].eigenMap()).norm(); } } std::ostringstream msg; msg << "Matrix Shape: " << Ashape << ", " << "Residual: " << norm; TEST(norm < thres, "GEMV Broadcast", msg.str().c_str()); return failed; }; static uint32_t GEMM(int attempts) { uint32_t failed = 0; for (int i = 0; i < attempts; i++) { CT::Shape Ashape = makeRandom2DShape(); CT::Shape Bshape = makeRandom2DShape(); Bshape = CT::Shape({Ashape.cols(), Bshape.cols()}); CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()}); CT::Array A(Ashape); CT::Array B(Bshape); CT::Array C(Cshape); A.setRandom(-100, 100); B.setRandom(-100, 100); C.setRandom(-100, 100); A.updateDevice(); B.updateDevice(); C.updateDevice().wait(); CT::BLAS::GEMM(1.0, A, B, 0.0, C).wait(); C.updateHost().wait(); CT::Array CTest(Cshape, true); CTest.eigenMap() = A.eigenMap() * B.eigenMap(); double norm = (CTest.eigenMap() - C.eigenMap()).norm(); std::ostringstream name; name << "GEMM (" << i + 1 << "/" << attempts << ")"; std::ostringstream msg; msg << "Matrix Shapes: " << Ashape << Bshape << ", " << "Residual: " << norm; TEST(norm < thres, name.str().c_str(), msg.str().c_str()); } return failed; }; static uint32_t GEMMBroadcast() { uint32_t failed = 0; CT::Shape Ashape = makeRandom2DShape(); CT::Shape Bshape = makeRandom2DShape(); Bshape = CT::Shape({Ashape.cols(), Bshape.cols()}); CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()}); CT::Array A({2, 3, Ashape.rows(), Ashape.cols()}); CT::Array B({2, 3, Bshape.rows(), Bshape.cols()}); CT::Array C({2, 3, Cshape.rows(), Cshape.cols()}); A.setRandom(-100, 100); B.setRandom(-100, 100); A.updateDevice(); B.updateDevice(); C.updateDevice().wait(); CT::BLAS::GEMM(1.0, A, B, 0.0, C).wait(); C.updateHost().wait(); double norm = 0; CT::Array CTest(Cshape, true); for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { CTest.eigenMap() = A[i][j].eigenMap() * B[i][j].eigenMap(); norm += (CTest.eigenMap() - C[i][j].eigenMap()).norm(); } } std::ostringstream msg; msg << "Matrix Shapes: " << Ashape << Bshape << ", " << "Residual: " << norm; TEST(norm < thres, "GEMM Broadcast", msg.str().c_str()); return failed; }; static uint32_t PLU() { uint32_t failed = 0; CT::Shape Ashape = makeRandom2DShape(); CT::Shape xshape = makeRandom2DShape(); Ashape = CT::Shape({Ashape.rows(), Ashape.rows()}); xshape = CT::Shape({Ashape.rows(), xshape.cols()}); CT::Array A({2, 3, Ashape.rows(), Ashape.rows()}); CT::Array x({2, 3, xshape.rows(), xshape.cols()}); CT::Array b({2, 3, xshape.rows(), xshape.cols()}); CT::Array Ax({2, 3, xshape.rows(), xshape.cols()}); A.setRandom(-100, 100); b.setRandom(-100, 100); CT::Array LU(A.copy()); x = b; A.updateDevice(); LU.updateDevice(); x.updateDevice().wait(); CT::BLAS::PLUBatch luBatch(LU); CT::BLAS::Batch xBatch(x); luBatch.computeLU().wait(); luBatch.solve(xBatch).wait(); // Compute Ax and compare difference. CT::BLAS::GEMM(1.0, A, x, 0.0, Ax).wait(); Ax.updateHost(); double norm = 0; for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { norm += (Ax[i][j].eigenMap() - b[i][j].eigenMap()).norm(); } } std::ostringstream msg; msg << "Matrix Shape: " << Ashape << xshape << ", " << "Residual: " << norm; TEST(norm < thres, "PLU/Solve", msg.str().c_str()); return failed; } }; template <> double BLASTests::thres = 10e-1; template <> double BLASTests::thres = 10e-8; template <> double BLASTests::thres = 10e-1; template <> double BLASTests::thres = 10e-8; uint32_t doMacroTests() { uint32_t failed = 0; failed += MacroTests::Kernel(); failed += MacroTests::Class(); std::cout << "\n"; return failed; } template uint32_t doArrayTests() { uint32_t failed = 0; std::cout << boxSmall("Index Tests : " + type()) << "\n"; failed += ArrayTests::Indexing(); std::cout << "\n" << boxSmall("Slice Tests : " + type()) << "\n"; failed += ArrayTests::Slicing(); std::cout << "\n"; return failed; } template uint32_t doBLASTests() { uint32_t failed = 0; std::cout << boxSmall("GEMV Tests : " + type()) << "\n"; failed += BLASTests::GEMV(5); failed += BLASTests::GEMVBroadcast(); std::cout << "\n" << boxSmall("GEMM Tests : " + type()) << "\n"; failed += BLASTests::GEMM(5); failed += BLASTests::GEMMBroadcast(); std::cout << "\n" << boxSmall("PLU Tests : " + type()) << "\n"; failed += BLASTests::PLU(); std::cout << "\n"; return failed; } void addNum(const CT::Array A, uint32_t num) { auto Aeig = A.atLeast2D().eigenMap(); Aeig = Aeig.array() + num; } void myGraph(CT::GraphManager* gm, const CT::Array A, const CT::Array B) { // tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5); A.updateDevice("graphStream"); gm->makeBranch("graphStream", "graphStreamBranch"); B.updateDevice("graphStreamBranch"); for (uint32_t iTimes = 0; iTimes < 30; ++iTimes) { CT::Kernel::launch(collatz, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view()); CT::Kernel::launch(plusOne, CT::Kernel::basic(A.shape().items(), "graphStreamBranch"), B.view()); } gm->joinBranch("graphStream", "graphStreamBranch"); CT::Kernel::launch(addArray, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view(), B.view()); A.updateHost("graphStream"); B.updateHost("graphStream"); gm->launchHostFunction("graphStream", addNum, A.view(), 5); } uint32_t doGraphTest() { uint32_t failed = 0; CT::Array A = CT::Array::constant({1000000}, 50); CT::Array B = CT::Array::constant({1000000}, 0); CT::Manager::get()->addStream("graphStream"); CT::Manager::get()->addStream("graphStreamBranch"); CT::GraphManager gm; CT::Graph graph("graphStream", myGraph, &gm, A.view(), B.view()); graph.execute().wait(); uint32_t errors = 0; for (auto it = A.begin(); it != A.end(); ++it) { if (*it != 36) ++errors; } std::ostringstream msg; msg << "Errors: " << errors; TEST(errors == 0, "Graph", msg.str().c_str()); return failed; } int main() { uint32_t failed = 0; std::cout << box("Macro Tests") << "\n"; failed += doMacroTests(); std::cout << box("Array Tests") << "\n"; // Test different sizes. failed += doArrayTests(); failed += doArrayTests(); failed += doArrayTests(); failed += doArrayTests(); std::cout << box("BLAS Tests") << "\n"; failed += doBLASTests(); failed += doBLASTests(); failed += doBLASTests(); failed += doBLASTests(); std::cout << box("Stream/Graph Tests") << "\n"; failed += doGraphTest(); constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4 + 1; std::ostringstream msg; msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(") << (tests - failed) << "/" << tests << ")"; std::cout << box2(msg.str()) << "\n"; return 0; }