#define CUDATOOLS_IMPLEMENTATION
#include <Array.h>
#include <Core.h>

// Doubles every element of an integer array in place.
//
// The array is first reduced to the result of flattened(), so a single
// index is enough regardless of the original shape. The writes go through
// that flattened handle; presumably it aliases the storage of `arr`
// (confirm against Array::flattened()).
// `iThread` is not declared in this file; it is presumably introduced by the
// BASIC_LOOP macro, which is given the total item count.
KERNEL(times2, const CudaTools::Array<int> arr) {
    const auto itemCount = arr.shape().items();
    CudaTools::Array<int> elems = arr.flattened();
    BASIC_LOOP(itemCount) {
        elems[iThread] *= 2;
    }
}

// Doubles every element of a double-precision array in place.
//
// Mirrors the integer kernel: the array is flattened so one index covers
// every item, and the multiplication is applied through the flattened handle
// (presumably aliasing the storage of `arr` -- confirm against
// Array::flattened()).
// `iThread` is not declared in this file; it is presumably introduced by the
// BASIC_LOOP macro, which is given the total item count.
KERNEL(times2double, const CudaTools::Array<double> arr) {
    const auto itemCount = arr.shape().items();
    CudaTools::Array<double> elems = arr.flattened();
    BASIC_LOOP(itemCount) {
        elems[iThread] *= 2;
    }
}

// Demo driver: builds four arrays, pushes them to the device, doubles each one
// with a kernel, pulls the results back and prints the arrays before and
// after the kernels ran. Always returns 0.
int main() {
    using IntArray = CudaTools::Array<int>;
    using DoubleArray = CudaTools::Array<double>;

    // Three factory-built arrays plus one filled by hand.
    IntArray rangeArr = IntArray::range(0, 10);
    IntArray constArr = IntArray::constant({10}, 1);
    DoubleArray linspaceArr = DoubleArray::linspace(0, 5, 10);
    IntArray commaArr({2, 2}); // Shape {2, 2}.
    commaArr << 1, 2, 3, 4;    // Values supplied through the comma initializer.

    // Prints a heading followed by each array on its own line.
    const auto dump = [&](const char* heading) {
        std::cout << heading;
        std::cout << rangeArr << "\n";
        std::cout << constArr << "\n";
        std::cout << linspaceArr << "\n";
        std::cout << commaArr << "\n";
    };

    // Queue the host-to-device updates; only the last one is waited on.
    rangeArr.updateDevice();
    constArr.updateDevice();
    linspaceArr.updateDevice();
    commaArr.updateDevice().wait();

    dump("Before Kernel:\n");

    // The launches below are issued back to back without waiting in between.
    // They share one stream, so they do not overlap on the device; they are
    // simply queued one after another.
    // Each kernel is handed a view of its array, not the Array object itself.
    CudaTools::Kernel::launch(times2,
                              CudaTools::Kernel::basic(rangeArr.shape().items()),
                              rangeArr.view());
    CudaTools::Kernel::launch(times2,
                              CudaTools::Kernel::basic(constArr.shape().items()),
                              constArr.view());
    CudaTools::Kernel::launch(times2double,
                              CudaTools::Kernel::basic(linspaceArr.shape().items()),
                              linspaceArr.view());
    CudaTools::Kernel::launch(times2,
                              CudaTools::Kernel::basic(commaArr.shape().items()),
                              commaArr.view())
        .wait();

    // Bring the results back. Everything is on the same stream, so waiting
    // on the final copy is enough before reading the host data.
    rangeArr.updateHost();
    constArr.updateHost();
    linspaceArr.updateHost();
    commaArr.updateHost().wait();

    dump("After Kernel:\n");
    return 0;
}