A library and framework for developing CPU-CUDA compatible applications from a single, unified codebase.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

50 lines
2.1 KiB

#define CUDATOOLS_IMPLEMENTATION
#include <Array.h>
#include <Core.h>
/**
 * Kernel: multiplies every element of an int array by 2, in place.
 *
 * @param arr  Array handle received by value. Writes go through the handle returned by
 *             flattened(); assumes that handle shares storage with the caller's
 *             array (i.e. `arr` is a view) -- TODO confirm against CudaTools::Array.
 */
KERNEL(times2, const CudaTools::Array<int> arr) {
    // Total number of elements, taken from the array's shape.
    const auto itemCount = arr.shape().items();

    // Flattened handle, indexed below with a single index.
    CudaTools::Array<int> elems = arr.flattened();

    // iThread is not declared in this block; presumably BASIC_LOOP introduces it.
    BASIC_LOOP(itemCount) {
        elems[iThread] *= 2;
    }
}
/**
 * Kernel: multiplies every element of a double array by 2, in place.
 *
 * @param arr  Array handle received by value. Writes go through the handle returned by
 *             flattened(); assumes that handle shares storage with the caller's
 *             array (i.e. `arr` is a view) -- TODO confirm against CudaTools::Array.
 */
KERNEL(times2double, const CudaTools::Array<double> arr) {
    // Total number of elements, taken from the array's shape.
    const auto itemCount = arr.shape().items();

    // Flattened handle, indexed below with a single index.
    CudaTools::Array<double> elems = arr.flattened();

    // iThread is not declared in this block; presumably BASIC_LOOP introduces it.
    BASIC_LOOP(itemCount) {
        elems[iThread] *= 2;
    }
}
int main() {
CudaTools::Array<int> arrRange = CudaTools::Array<int>::range(0, 10);
CudaTools::Array<int> arrConst = CudaTools::Array<int>::constant({10}, 1);
CudaTools::Array<double> arrLinspace = CudaTools::Array<double>::linspace(0, 5, 10);
CudaTools::Array<int> arrComma({2, 2}); // 2x2 array.
arrComma << 1, 2, 3, 4; // Comma initializer if needed.
arrRange.updateDevice();
arrConst.updateDevice();
arrLinspace.updateDevice();
arrComma.updateDevice().wait();
std::cout << "Before Kernel:\n";
std::cout << arrRange << "\n" << arrConst << "\n" << arrLinspace << "\n" << arrComma << "\n";
// Call the kernel multiple times asynchronously. Note: since they share same
// stream, they are not run in parallel, just queued on the device.
// NOTE: Notice that a view is passed into the kernel, not the Array itself.
CudaTools::Kernel::launch(times2, CudaTools::Kernel::basic(arrRange.shape().items()),
arrRange.view());
CudaTools::Kernel::launch(times2, CudaTools::Kernel::basic(arrConst.shape().items()),
arrConst.view());
CudaTools::Kernel::launch(times2double, CudaTools::Kernel::basic(arrLinspace.shape().items()),
arrLinspace.view());
CudaTools::Kernel::launch(times2, CudaTools::Kernel::basic(arrComma.shape().items()),
arrComma.view())
.wait();
arrRange.updateHost();
arrConst.updateHost();
arrLinspace.updateHost();
arrComma.updateHost().wait(); // Same stream, so you should wait for the last call.
std::cout << "After Kernel:\n";
std::cout << arrRange << "\n" << arrConst << "\n" << arrLinspace << "\n" << arrComma << "\n";
return 0;
}