From d448321436e8314d3e2a6a09d4017c4bc10f612d Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Sat, 17 Feb 2024 17:37:22 +0100 Subject: [PATCH] fix-dlopen-cuda --- gpuctypes/cuda.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py index acba81c..091f7f7 100644 --- a/gpuctypes/cuda.py +++ b/gpuctypes/cuda.py @@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): +NAME_TO_PATHS = { + "libcuda.so": ["@driverLink@/lib/libcuda.so"], + "libnvrtc.so": ["@libnvrtc@"], +} +def _try_dlopen(name): + try: + return ctypes.CDLL(name) + except OSError: + pass + for candidate in NAME_TO_PATHS.get(name, []): + try: + return ctypes.CDLL(candidate) + except OSError: + pass + raise RuntimeError(f"{name} not found") + _libraries = {} -_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) -_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) +_libraries['libcuda.so'] = _try_dlopen('libcuda.so') +_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so') cuuint32_t = ctypes.c_uint32 -- 2.43.0