From cc56e79aee596649af83d939984456d5a528234a Mon Sep 17 00:00:00 2001 From: kjao Date: Sun, 9 Jul 2023 19:36:17 -0500 Subject: [PATCH] Added Python library support --- docs/source/core.rst | 1 + include/Array.h | 22 ++++++++++++++++++++++ include/Macros.h | 23 ++++++++++++++++------- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/docs/source/core.rst b/docs/source/core.rst index 8d34452..5c7528e 100644 --- a/docs/source/core.rst +++ b/docs/source/core.rst @@ -43,6 +43,7 @@ Compilation Options ------------------- .. doxygendefine:: CUDATOOLS_ARRAY_MAX_AXES .. doxygendefine:: CUDATOOLS_USE_EIGEN +.. doxygendefine:: CUDATOOLS_USE_PYTHON Macro Functions =============== diff --git a/include/Array.h b/include/Array.h index cd6e646..cb7518b 100644 --- a/include/Array.h +++ b/include/Array.h @@ -16,6 +16,12 @@ #include #endif +#ifdef CUDATOOLS_USE_PYTHON +#include +#include +namespace py = pybind11; +#endif + #ifdef DEVICE #define POINTER pDevice #else @@ -747,6 +753,22 @@ template class Array { CT_ERROR(mIsSlice, "Cannot update device copy on a slice"); return CudaTools::copy(pHost, pDevice, mShape.items() * sizeof(T), stream); }; + +#ifdef CUDATOOLS_USE_PYTHON + /** + * Returns a py::array for making an Array available as a Python numpy array. + */ + py::array pyArray() const { + std::vector dims, strides; + for (uint iAxis = 0; iAxis < mShape.axes(); ++iAxis) { + dims.push_back(static_cast(mShape.dim(iAxis))); + strides.push_back(sizeof(T) * static_cast(mShape.stride(iAxis))); + } + return py::array_t( + py::buffer_info((void*)pHost, sizeof(T), py::format_descriptor::format(), + static_cast(mShape.axes()), dims, strides)); + }; +#endif }; template diff --git a/include/Macros.h b/include/Macros.h index 1b6f08d..35ee63e 100644 --- a/include/Macros.h +++ b/include/Macros.h @@ -55,6 +55,12 @@ */ #define CUDATOOLS_USE_EIGEN +/** + * \def CUDATOOLS_USE_PYTHON + * Compile the CudaTools library with Python support. + */ +#define CUDATOOLS_USE_PYTHON + /** * \def KERNEL(call, settings, ...) * Used to call a CUDA kernel. @@ -224,12 +230,13 @@ #ifdef DEVICE #define CT_ERROR_IF(a, op, b, msg) \ if (a op b) { \ - printf("[ERROR] %s:%d\n | %s: (" #a ") " #op " (" #b ").\n", __FILE__, __LINE__, msg); \ + printf("\033[1;31m[CudaTools]\033[0m %s:%d\n | %s: (" #a ") " #op " (" #b ").\n", \ + __FILE__, __LINE__, msg); \ } #define CT_ERROR(a, msg) \ if (a) { \ - printf("[ERROR] %s:%d\n | %s: " #a ".\n", __FILE__, __LINE__, msg); \ + printf("\033[1;31m[CudaTools]\033[0m %s:%d\n | %s: " #a ".\n", __FILE__, __LINE__, msg); \ } #else @@ -239,14 +246,14 @@ std::ostringstream os_b; \ os_a << a; \ os_b << b; \ - printf("[ERROR] %s:%d\n | %s: (" #a ")%s " #op " (" #b ")%s.\n", __FILE__, __LINE__, msg, \ - os_a.str().c_str(), os_b.str().c_str()); \ + printf("\033[1;31m[CudaTools]\033[0m %s:%d\n | %s: (" #a ")%s " #op " (" #b ")%s.\n", \ + __FILE__, __LINE__, msg, os_a.str().c_str(), os_b.str().c_str()); \ throw std::exception(); \ } #define CT_ERROR(a, msg) \ if (a) { \ - printf("[ERROR] %s:%d\n | %s: " #a ".\n", __FILE__, __LINE__, msg); \ + printf("\033[1;31m[CudaTools]\033[0m %s:%d\n | %s: " #a ".\n", __FILE__, __LINE__, msg); \ throw std::exception(); \ } #endif @@ -259,7 +266,8 @@ do { \ cudaError_t err = (call); \ if (err != cudaSuccess) { \ - printf("[CUDA] %s:%d\n | %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + printf("\033[1;31m[CUDA]\033[0m %s:%d\n | %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ throw std::exception(); \ } \ } while (0) @@ -268,7 +276,8 @@ do { \ cublasStatus_t err = (call); \ if (err != CUBLAS_STATUS_SUCCESS) { \ - printf("[cuBLAS] %s:%d\n | %s\n", __FILE__, __LINE__, cublasGetStatusName(err)); \ + printf("\033[1;31m[cuBLAS]\033[0m %s:%d\n | %s\n", __FILE__, __LINE__, \ + cublasGetStatusName(err)); \ throw std::exception(); \ } \ } while (0)