You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
2.1 KiB
50 lines
2.1 KiB
#define CUDATOOLS_IMPLEMENTATION
|
|
#include <Array.h>
|
|
#include <Core.h>
|
|
|
|
// Kernel: multiplies by two every int entry reachable through the flattened
// form of `arr`.
//
// `arr` is received as a const value; the writes go through the object
// returned by arr.flattened(), which presumably aliases the same storage
// (TODO confirm against CudaTools::Array).
KERNEL(times2, const CudaTools::Array<int> arr) {
    // Number of items to visit, taken from the array's shape.
    const auto itemCount = arr.shape().items();
    CudaTools::Array<int> elements = arr.flattened();
    // `iThread` is not declared in this block; presumably BASIC_LOOP provides it.
    BASIC_LOOP(itemCount) { elements[iThread] *= 2; }
}
|
|
|
|
// Kernel: multiplies by two every double entry reachable through the
// flattened form of `arr`.
//
// `arr` is received as a const value; the writes go through the object
// returned by arr.flattened(), which presumably aliases the same storage
// (TODO confirm against CudaTools::Array).
KERNEL(times2double, const CudaTools::Array<double> arr) {
    // Number of items to visit, taken from the array's shape.
    const auto itemCount = arr.shape().items();
    CudaTools::Array<double> elements = arr.flattened();
    // `iThread` is not declared in this block; presumably BASIC_LOOP provides it.
    BASIC_LOOP(itemCount) { elements[iThread] *= 2; }
}
|
|
|
|
// Example driver: creates four arrays, pushes them to the device, prints them,
// runs the doubling kernels on each one, pulls the results back to the host,
// and prints them again. Always returns 0.
int main() {
    using IntArray = CudaTools::Array<int>;
    using DoubleArray = CudaTools::Array<double>;

    // Four different ways of constructing an array.
    IntArray arrRange = IntArray::range(0, 10);
    IntArray arrConst = IntArray::constant({10}, 1);
    DoubleArray arrLinspace = DoubleArray::linspace(0, 5, 10);
    IntArray arrComma({2, 2}); // 2x2 array.
    arrComma << 1, 2, 3, 4;    // Filled with the comma initializer.

    // Push every array to the device; only the final call is waited on.
    arrRange.updateDevice();
    arrConst.updateDevice();
    arrLinspace.updateDevice();
    arrComma.updateDevice().wait();

    std::cout << "Before Kernel:\n";
    std::cout << arrRange << "\n";
    std::cout << arrConst << "\n";
    std::cout << arrLinspace << "\n";
    std::cout << arrComma << "\n";

    // The launches below are asynchronous. Because they all share one stream
    // they do not run in parallel; they are simply queued on the device.
    // Each kernel receives a view of the array rather than the Array itself.
    CudaTools::Kernel::launch(
        times2,
        CudaTools::Kernel::basic(arrRange.shape().items()),
        arrRange.view());
    CudaTools::Kernel::launch(
        times2,
        CudaTools::Kernel::basic(arrConst.shape().items()),
        arrConst.view());
    CudaTools::Kernel::launch(
        times2double,
        CudaTools::Kernel::basic(arrLinspace.shape().items()),
        arrLinspace.view());
    CudaTools::Kernel::launch(
        times2,
        CudaTools::Kernel::basic(arrComma.shape().items()),
        arrComma.view())
        .wait();

    // Pull the results back. Everything is on the same stream, so waiting on
    // the last call is what is needed before reading on the host.
    arrRange.updateHost();
    arrConst.updateHost();
    arrLinspace.updateHost();
    arrComma.updateHost().wait();

    std::cout << "After Kernel:\n";
    std::cout << arrRange << "\n";
    std::cout << arrConst << "\n";
    std::cout << arrLinspace << "\n";
    std::cout << arrComma << "\n";

    return 0;
}
|
|
|