diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py index acba81c..aac5fc7 100644 --- a/gpuctypes/cuda.py +++ b/gpuctypes/cuda.py @@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): +NAME_TO_PATHS = { + "libcuda.so": ["@driverLink@/lib/libcuda.so"], + "libnvrtc.so": ["@libnvrtc@"], +} +def _try_dlopen(name): + try: + return ctypes.CDLL(name) + except OSError: + pass + for candidate in NAME_TO_PATHS.get(name, []): + try: + return ctypes.CDLL(candidate) + except OSError: + pass + raise RuntimeError(f"{name} not found") + _libraries = {} -_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) -_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) +_libraries['libcuda.so'] = _try_dlopen('libcuda.so') +_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so') cuuint32_t = ctypes.c_uint32