Added CUDA Graphs support

main
Kenneth Jao 2 years ago
parent 00a27b66c3
commit a393ff92d2
  1. 11
      Array.h
  2. 13
      BLAS.h
  3. 235
      Core.h
  4. 22
      Macros.h
  5. 2
      Makefile
  6. 2
      README.rst
  7. 78
      tests.cu.cpp

@ -7,6 +7,7 @@
#include <Eigen/Dense> #include <Eigen/Dense>
#include <cmath> #include <cmath>
#include <complex> #include <complex>
#include <cstdlib>
#include <iomanip> #include <iomanip>
#include <random> #include <random>
#include <type_traits> #include <type_traits>
@ -788,12 +789,16 @@ void printAxis(std::ostream& out, const Array<T>& arr, const uint32_t axis, size
template <typename T> std::ostream& operator<<(std::ostream& out, const Array<T>& arr) { template <typename T> std::ostream& operator<<(std::ostream& out, const Array<T>& arr) {
size_t width = 0; size_t width = 0;
if constexpr (is_num<T>) { if constexpr (is_int<T>) {
T max_val = 0; T max_val = 0;
bool negative = false; bool negative = false;
for (auto it = arr.begin(); it != arr.end(); ++it) { for (auto it = arr.begin(); it != arr.end(); ++it) {
if (*it < 0) negative = true; T val = *it;
max_val = (abs(*it) > max_val) ? abs(*it) : max_val; if (*it < 0) {
negative = true;
val *= -1;
}
max_val = (val > max_val) ? val : max_val;
} }
width = std::to_string(max_val).size() + 1; width = std::to_string(max_val).size() + 1;
width += (negative) ? 1 : 0; width += (negative) ? 1 : 0;

@ -235,8 +235,7 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
uint32_t cols = A.shape().cols(); uint32_t cols = A.shape().cols();
T a = alpha, b = beta; T a = alpha, b = beta;
#ifdef CUDA #ifdef CUDA
CUBLAS_CHECK( CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
if (bi.size == 1) { if (bi.size == 1) {
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv, invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a), Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
@ -282,8 +281,7 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
T a = alpha, b = beta; T a = alpha, b = beta;
#ifdef CUDA #ifdef CUDA
CUBLAS_CHECK( CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
if (bi.size == 1) { if (bi.size == 1) {
invoke<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm, invoke<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm,
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a), Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
@ -338,8 +336,7 @@ StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const boo
uint32_t m = C.shape().rows(); uint32_t m = C.shape().rows();
uint32_t n = C.shape().cols(); uint32_t n = C.shape().cols();
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
CUBLAS_CHECK( CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(), m, invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(), m,
n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1, n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
CAST(C.dataDevice()), m); CAST(C.dataDevice()), m);
@ -514,7 +511,7 @@ class PLUBatch : public Batch<T> {
#ifdef CUDA #ifdef CUDA
uint32_t n = this->mShape.rows(); uint32_t n = this->mShape.rows();
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id))); cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched, invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
cublasZgetrfBatched, Manager::get()->cublasHandle(), n, cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(), DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
@ -546,7 +543,7 @@ class PLUBatch : public Batch<T> {
uint32_t n = b.shape().rows(); uint32_t n = b.shape().rows();
uint32_t nrhs = b.shape().cols(); uint32_t nrhs = b.shape().cols();
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id))); cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched, invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs, cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(), DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),

235
Core.h

@ -2,13 +2,16 @@
#define CUDATOOLS_H #define CUDATOOLS_H
#include "Macros.h" #include "Macros.h"
#include <functional>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <tuple>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
namespace CudaTools { namespace CudaTools {
struct Event;
/** /**
* Simple wrapper for the name of a stream. Its purposes is to allow for * Simple wrapper for the name of a stream. Its purposes is to allow for
* 'streams' to be passed on host code, and allowing for simple syntax * 'streams' to be passed on host code, and allowing for simple syntax
@ -16,18 +19,19 @@ namespace CudaTools {
*/ */
struct StreamID { struct StreamID {
public: public:
std::string id; std::string mId;
StreamID() : id(""){}; StreamID() : mId(""){};
/** /**
* The constructor for a StreamID. * The constructor for a StreamID.
*/ */
StreamID(const std::string& id_) : id(id_){}; StreamID(const std::string& id_) : mId(id_){};
StreamID(const char* id_) : id(id_){}; StreamID(const char* id_) : mId(id_){};
void wait() const; /**< Makes host wait for this stream. */
/** /**
* Waits for the stream with this stream ID. * Makes this stream wait for this event. Does not block the host.
*/ */
void wait() const; void wait(const Event& event) const;
}; };
static const StreamID DEF_MEM_STREAM = StreamID{"defaultMemory"}; static const StreamID DEF_MEM_STREAM = StreamID{"defaultMemory"};
@ -137,6 +141,20 @@ struct Settings {
*/ */
Settings basic(const size_t threads, const StreamID& stream = DEF_KERNEL_STREAM); Settings basic(const size_t threads, const StreamID& stream = DEF_KERNEL_STREAM);
/**
* Launches a kernel with the provided function, settings and its arguments.
*/
template <typename F, typename... Args>
StreamID launch(F func, const Kernel::Settings& sett, Args... args) {
#ifdef CUDA
func<<<sett.blockGrid, sett.threadBlock, sett.sharedMemoryBytes,
Manager::get()->stream(sett.stream.mId)>>>(args...);
#else
func(args...);
#endif
return sett.stream;
}
}; // namespace Kernel }; // namespace Kernel
template <typename T> class Array; template <typename T> class Array;
@ -186,29 +204,143 @@ class Shape {
std::ostream& operator<<(std::ostream& out, const Shape& s); std::ostream& operator<<(std::ostream& out, const Shape& s);
/**
* A simple class that manages a CUDA Event.
*/
struct Event {
#ifdef CUDACC
cudaEvent_t mEvent;
#endif
Event();
~Event();
void record(const StreamID& stream); /**< Records a event from a stream. */
};
template <typename F, typename... Args> struct FuncHolder {
F mFunc;
std::tuple<Args...> mArgs;
FuncHolder() = delete;
FuncHolder(F func, Args... args) : mFunc(func), mArgs(std::make_tuple(args...)){};
static void run(void* data) {
FuncHolder<F, Args...>* fh = (FuncHolder<F, Args...>*)(data);
std::apply([fh](auto&&... args) { fh->mFunc(args...); }, fh->mArgs);
};
};
/**
* Accessory struct to deal with host callbacks for CUDA Graphs in a nice fashion.
*/
struct GraphTools {
std::vector<void*> mHostData;
std::vector<Event*> mEvents;
~GraphTools();
/**
* Within a function that is being stream captured, launch a host function that can
* be captured into the graph.
*/
template <typename F, typename... Args>
void launchHostFunction(const StreamID& stream, F func, Args&&... args) {
#ifdef CUDACC
FuncHolder<F, Args...>* fh = new FuncHolder<F, Args...>(func, args...);
mHostData.push_back((void*)fh);
cudaHostFn_t run_func = fh->run;
CUDA_CHECK(cudaLaunchHostFunc(Manager::get()->stream(stream), run_func, fh));
#else
func(args...);
#endif
}
/**
* Makes a new branch in the graph to be run in parallel by a new stream.
* \param orig_stream the original stream to branch from.
* \param branch_stream the stream of the new branch.
*/
void makeBranch(const StreamID& orig_stream, const StreamID& branch_stream);
/**
* Joins a existing branch in the graph to collapse a parallel block.
* \param orig_stream the original stream to join the branch to.
* \param branch_stream the stream of the branch to join.
*/
void joinBranch(const StreamID& orig_stream, const StreamID& branch_stream);
};
/**
* A class that manages CUDA Graphs.
*/
template <typename F, typename... Args> class Graph {
private:
#ifdef CUDACC
cudaGraph_t mGraph;
cudaGraphExec_t mInstance;
#endif
FuncHolder<F, Args...> mFuncHolder;
StreamID mStream;
public:
Graph() = delete;
/**
* The constructor for a Graph, which captures the function.
* \param func the function to capture.
* \param stream the origin stream to use.
* \param args the arguments of the function.
*/
Graph(const StreamID& stream, F func, Args... args)
: mFuncHolder(func, args...), mStream(stream) {
#ifdef CUDACC
CUDA_CHECK(
cudaStreamBeginCapture(Manager::get()->stream(mStream), cudaStreamCaptureModeGlobal));
mFuncHolder.run((void*)&mFuncHolder);
CUDA_CHECK(cudaStreamEndCapture(Manager::get()->stream(mStream), &mGraph));
CUDA_CHECK(cudaGraphInstantiate(&mInstance, mGraph, NULL, NULL, 0));
#endif
};
~Graph() {
#ifdef CUDACC
CUDA_CHECK(cudaGraphDestroy(mGraph));
CUDA_CHECK(cudaGraphExecDestroy(mInstance));
#endif
};
/**
* Executes the instantiated graph, or simply runs the function with provided
* arguments if compiling for CPU.
*/
StreamID execute() const {
#ifdef CUDACC
cudaGraphLaunch(mInstance, Manager::get()->stream(mStream));
#else
mFuncHolder.run((void*)&mFuncHolder);
#endif
return mStream;
}
};
}; // namespace CudaTools }; // namespace CudaTools
#ifdef CUDATOOLS_IMPLEMENTATION #ifdef CUDATOOLS_IMPLEMENTATION
namespace CudaTools { namespace CudaTools {
template <typename T, typename... Args> //////////////////////
StreamID runKernel(T func, const Kernel::Settings& sett, Args... args) { // StreamID Methods //
#ifdef CUDA //////////////////////
func<<<sett.blockGrid, sett.threadBlock, sett.sharedMemoryBytes,
Manager::get()->stream(sett.stream.id)>>>(args...); void StreamID::wait() const { Manager::get()->waitFor(mId); }
#else
func(args...); void StreamID::wait(const Event& event) const {
#ifdef CUDACC
CUDA_CHECK(cudaStreamWaitEvent(Manager::get()->stream(mId), event.mEvent, 0));
#endif #endif
return sett.stream;
} }
//////////////////// ////////////////////
// Memory Methods // // Memory Methods //
//////////////////// ////////////////////
void StreamID::wait() const { Manager::get()->waitFor(id); }
void* malloc(const size_t size) { void* malloc(const size_t size) {
#ifdef CUDACC #ifdef CUDACC
void* pDevice; void* pDevice;
@ -228,7 +360,7 @@ void free(void* const pDevice) {
StreamID push(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) { StreamID push(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) {
#ifdef CUDACC #ifdef CUDACC
CUDA_CHECK(cudaMemcpyAsync(pDevice, pHost, size, cudaMemcpyHostToDevice, CUDA_CHECK(cudaMemcpyAsync(pDevice, pHost, size, cudaMemcpyHostToDevice,
Manager::get()->stream(stream.id))); Manager::get()->stream(stream)));
#endif #endif
return stream; return stream;
} }
@ -236,7 +368,7 @@ StreamID push(void* const pHost, void* const pDevice, const size_t size, const S
StreamID pull(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) { StreamID pull(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) {
#ifdef CUDACC #ifdef CUDACC
CUDA_CHECK(cudaMemcpyAsync(pHost, pDevice, size, cudaMemcpyDeviceToHost, CUDA_CHECK(cudaMemcpyAsync(pHost, pDevice, size, cudaMemcpyDeviceToHost,
Manager::get()->stream(stream.id))); Manager::get()->stream(stream)));
#endif #endif
return stream; return stream;
} }
@ -245,7 +377,7 @@ StreamID deviceCopy(void* const pSrc, void* const pDest, const size_t size,
const StreamID& stream) { const StreamID& stream) {
#ifdef CUDACC #ifdef CUDACC
CUDA_CHECK(cudaMemcpyAsync(pDest, pSrc, size, cudaMemcpyDeviceToDevice, CUDA_CHECK(cudaMemcpyAsync(pDest, pSrc, size, cudaMemcpyDeviceToDevice,
Manager::get()->stream(stream.id))); Manager::get()->stream(stream)));
#endif #endif
return stream; return stream;
} }
@ -289,11 +421,11 @@ Manager::~Manager() {
void Manager::waitFor(const StreamID& stream) const { void Manager::waitFor(const StreamID& stream) const {
#ifdef CUDACC #ifdef CUDACC
auto it = mStreams.find(stream.id); auto it = mStreams.find(stream.mId);
if (it != mStreams.end()) { if (it != mStreams.end()) {
CUDA_CHECK(cudaStreamSynchronize(it->second)); CUDA_CHECK(cudaStreamSynchronize(it->second));
} else { } else {
CT_ERROR(true, ("Invalid stream " + stream.id).c_str()); CT_ERROR(true, ("Invalid stream " + stream.mId).c_str());
} }
#endif #endif
} }
@ -314,11 +446,11 @@ void Manager::addStream(const std::string& name) {
#ifdef CUDACC #ifdef CUDACC
cudaStream_t Manager::stream(const StreamID& stream) const { cudaStream_t Manager::stream(const StreamID& stream) const {
auto it = mStreams.find(stream.id); auto it = mStreams.find(stream.mId);
if (it != mStreams.end()) { if (it != mStreams.end()) {
return it->second; return it->second;
} else { } else {
CT_ERROR(true, ("Invalid stream " + stream.id).c_str()); CT_ERROR(true, ("Invalid stream " + stream.mId).c_str());
} }
} }
@ -407,7 +539,7 @@ void Settings::setSharedMemSize(const size_t bytes) {
void Settings::setStream(const StreamID& stream_) { void Settings::setStream(const StreamID& stream_) {
#ifdef CUDACC #ifdef CUDACC
stream.id = stream_.id; stream = stream_;
#endif #endif
} }
@ -425,7 +557,8 @@ Settings basic(const size_t threads, const StreamID& stream) {
#endif #endif
return sett; return sett;
} }
} // namespace Kernel
}; // namespace Kernel
///////////////////// /////////////////////
// Shape Functions // // Shape Functions //
@ -506,6 +639,57 @@ std::ostream& operator<<(std::ostream& out, const Shape& s) {
return out << s.dim(s.axes() - 1) << ")"; return out << s.dim(s.axes() - 1) << ")";
} }
///////////////////
// Event Methods //
///////////////////
Event::Event() {
#ifdef CUDACC
CUDA_CHECK(cudaEventCreate(&mEvent));
#endif
}
Event::~Event() {
#ifdef CUDACC
CUDA_CHECK(cudaEventDestroy(mEvent));
#endif
}
void Event::record(const StreamID& stream) {
#ifdef CUDACC
CUDA_CHECK(cudaEventRecord(mEvent, Manager::get()->stream(stream)));
#endif
}
////////////////////////
// GraphTools Methods //
////////////////////////
GraphTools::~GraphTools() {
#ifdef CUDACC
for (void* func : mHostData) {
delete func;
}
for (Event* event : mEvents) {
delete event;
}
#endif
}
void GraphTools::makeBranch(const StreamID& orig_stream, const StreamID& branch_stream) {
Event* event = new Event();
event->record(orig_stream);
mEvents.push_back(event);
branch_stream.wait(*event);
}
void GraphTools::joinBranch(const StreamID& orig_stream, const StreamID& branch_stream) {
Event* event = new Event();
event->record(branch_stream);
mEvents.push_back(event);
orig_stream.wait(*event);
}
#ifdef CUDACC #ifdef CUDACC
const char* cublasGetErrorString(cublasStatus_t error) { const char* cublasGetErrorString(cublasStatus_t error) {
switch (error) { switch (error) {
@ -537,7 +721,6 @@ const char* cublasGetErrorString(cublasStatus_t error) {
return "<unknown>"; return "<unknown>";
} }
#endif #endif
}; // namespace CudaTools }; // namespace CudaTools
#endif // CUDATOOLS_IMPLEMENTATION #endif // CUDATOOLS_IMPLEMENTATION

@ -145,27 +145,17 @@ using real64 = double; /**< Type alias for 64-bit floating point datatype. */
#define HD __host__ __device__ #define HD __host__ __device__
#define SHARED __shared__ #define SHARED __shared__
#define DECLARE_KERNEL(call, ...) __global__ void call(__VA_ARGS__) #define KERNEL(call, ...) __global__ void call(__VA_ARGS__)
#define DEFINE_KERNEL(call, ...) \
template CudaTools::StreamID CudaTools::runKernel( \
void (*)(__VA_ARGS__), const CudaTools::Kernel::Settings&, __VA_ARGS__); \
__global__ void call(__VA_ARGS__)
#else #else
#define HD #define HD
#define SHARED #define SHARED
#define DECLARE_KERNEL(call, ...) void call(__VA_ARGS__) #define KERNEL(call, ...) void call(__VA_ARGS__)
#define DEFINE_KERNEL(call, ...) \
template CudaTools::StreamID CudaTools::runKernel( \
void (*)(__VA_ARGS__), const CudaTools::Kernel::Settings&, __VA_ARGS__); \
void call(__VA_ARGS__)
#endif // CUDACC #endif // CUDACC
#define KERNEL(call, settings, ...) CudaTools::runKernel(call, settings, __VA_ARGS__) //#define KERNEL(call, settings, ...) CudaTools::runKernel(call, settings, __VA_ARGS__)
/////////////////// ///////////////////
// DEVICE MACROS // // DEVICE MACROS //
@ -218,8 +208,10 @@ using real64 = double; /**< Type alias for 64-bit floating point datatype. */
#ifndef CUDATOOLS_ARRAY_MAX_AXES #ifndef CUDATOOLS_ARRAY_MAX_AXES
/** /**
* \def CUDATOOLS_ARRAY_MAX_AXES * \def CUDATOOLS_ARRAY_MAX_AXES
* The maximum number of axes/dimensions an CudaTools::Array can have. The default is * The maximum number of axes/dimensions an
* set to 4, but can be manully set fit the program needs. * CudaTools::Array can have. The default is set
* to 4, but can be manully set fit the program
* needs.
*/ */
#define CUDATOOLS_ARRAY_MAX_AXES 4 #define CUDATOOLS_ARRAY_MAX_AXES 4
#endif #endif

@ -1,7 +1,7 @@
CC := g++-10 CC := g++-10
NVCC := nvcc NVCC := nvcc
CFLAGS := -Wall -std=c++17 -fopenmp -MMD CFLAGS := -Wall -std=c++17 -fopenmp -MMD
NVCC_FLAGS := -MMD -w -Xcompiler NVCC_FLAGS := -MMD -std=c++17 -w -Xcompiler
INCLUDE := INCLUDE :=
LIBS_DIR := LIBS_DIR :=

@ -31,7 +31,7 @@ After installing the required Python packages
.. code-block:: bash .. code-block:: bash
$ pip install -r requirements $ pip install -r requirements.txt
you can now run the script you can now run the script

@ -97,18 +97,36 @@ class TestClass {
}; };
}; };
DEFINE_KERNEL(times, const CT::Array<int> arr) { KERNEL(times, const CT::Array<int> arr) {
BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; } BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; }
} }
DEFINE_KERNEL(classTest, TestClass* const test) { test->x = 100; } KERNEL(classTest, TestClass* const test) { test->x = 100; }
KERNEL(collatz, const CT::Array<uint32_t> arr) {
BASIC_LOOP(arr.shape().length()) {
if (arr[iThread] % 2) {
arr[iThread] = 3 * arr[iThread] + 1;
} else {
arr[iThread] = arr[iThread] >> 1;
}
}
}
KERNEL(plusOne, const CT::Array<uint32_t> arr) {
BASIC_LOOP(arr.shape().length()) { arr[iThread] += 1; }
}
KERNEL(addBoth, const CT::Array<uint32_t> a, const CT::Array<uint32_t> b) {
BASIC_LOOP(a.shape().length()) { a[iThread] += b[iThread]; }
}
struct MacroTests { struct MacroTests {
static uint32_t Kernel() { static uint32_t Kernel() {
uint32_t failed = 0; uint32_t failed = 0;
CT::Array<int> A = CT::Array<int>::constant({10}, 1); CT::Array<int> A = CT::Array<int>::constant({10}, 1);
A.updateDevice().wait(); A.updateDevice().wait();
KERNEL(times, CT::Kernel::basic(A.shape().items()), A.view()).wait(); CT::Kernel::launch(times, CT::Kernel::basic(A.shape().items()), A.view()).wait();
A.updateHost().wait(); A.updateHost().wait();
uint32_t errors = 0; uint32_t errors = 0;
@ -125,7 +143,7 @@ struct MacroTests {
static uint32_t Class() { static uint32_t Class() {
uint32_t failed = 0; uint32_t failed = 0;
TestClass test(1); TestClass test(1);
KERNEL(classTest, CT::Kernel::basic(1), test.that()).wait(); CT::Kernel::launch(classTest, CT::Kernel::basic(1), test.that()).wait();
test.updateHost().wait(); test.updateHost().wait();
TEST(test.x == 100, "Class", "Errors: 0"); TEST(test.x == 100, "Class", "Errors: 0");
@ -473,6 +491,53 @@ template <typename T> uint32_t doBLASTests() {
return failed; return failed;
} }
void myHostFunc(const CT::Array<uint32_t> A, uint32_t num) {
auto Aeig = A.atLeast2D().eigenMap();
Aeig = Aeig.array() + num;
}
void myBasicGraph(CT::GraphTools* tools, CT::Array<uint32_t>* A, CT::Array<uint32_t>* B) {
// tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
A->updateDevice("graphStream");
tools->makeBranch("graphStream", "graphStreamBranch");
B->updateDevice("graphStreamBranch");
for (uint32_t iTimes = 0; iTimes < 30; ++iTimes) {
CT::Kernel::launch(collatz, CT::Kernel::basic(A->shape().items(), "graphStream"),
A->view());
CT::Kernel::launch(plusOne, CT::Kernel::basic(A->shape().items(), "graphStreamBranch"),
B->view());
}
tools->joinBranch("graphStream", "graphStreamBranch");
CT::Kernel::launch(addBoth, CT::Kernel::basic(A->shape().items(), "graphStream"), A->view(),
B->view());
A->updateHost("graphStream");
B->updateHost("graphStream");
tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
}
uint32_t doGraphTest() {
uint32_t failed = 0;
CT::Array<uint32_t> A = CT::Array<uint32_t>::constant({1000000}, 50);
CT::Array<uint32_t> B = CT::Array<uint32_t>::constant({1000000}, 0);
CT::Manager::get()->addStream("graphStream");
CT::Manager::get()->addStream("graphStreamBranch");
CT::GraphTools tools;
CT::Graph graph("graphStream", myBasicGraph, &tools, &A, &B);
graph.execute().wait();
uint32_t errors = 0;
for (auto it = A.begin(); it != A.end(); ++it) {
if (*it != 36) ++errors;
}
std::ostringstream msg;
msg << "Errors: " << errors;
TEST(errors == 0, "Graph", msg.str().c_str());
return failed;
}
int main() { int main() {
uint32_t failed = 0; uint32_t failed = 0;
std::cout << box("Macro Tests") << "\n"; std::cout << box("Macro Tests") << "\n";
@ -491,7 +556,10 @@ int main() {
failed += doBLASTests<complex64>(); failed += doBLASTests<complex64>();
failed += doBLASTests<complex128>(); failed += doBLASTests<complex128>();
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4; std::cout << box("Stream/Graph Tests") << "\n";
failed += doGraphTest();
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4 + 1;
std::ostringstream msg; std::ostringstream msg;
msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(") msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(")
<< (tests - failed) << "/" << tests << ")"; << (tests - failed) << "/" << tests << ")";

Loading…
Cancel
Save