about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch')
-rw-r--r--nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch44
1 files changed, 44 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch b/nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch
new file mode 100644
index 000000000000..bc9f6c7ec64b
--- /dev/null
+++ b/nixpkgs/pkgs/development/python-modules/gpuctypes/0001-fix-dlopen-cuda.patch
@@ -0,0 +1,44 @@
+From d448321436e8314d3e2a6a09d4017c4bc10f612d Mon Sep 17 00:00:00 2001
+From: Gaetan Lepage <gaetan@glepage.com>
+Date: Sat, 17 Feb 2024 17:37:22 +0100
+Subject: [PATCH] fix-dlopen-cuda
+
+---
+ gpuctypes/cuda.py | 20 ++++++++++++++++++--
+ 1 file changed, 18 insertions(+), 2 deletions(-)
+
+diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py
+index acba81c..091f7f7 100644
+--- a/gpuctypes/cuda.py
++++ b/gpuctypes/cuda.py
+@@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
+ 
+ 
+ 
++NAME_TO_PATHS = {
++    "libcuda.so": ["@driverLink@/lib/libcuda.so"],
++    "libnvrtc.so": ["@libnvrtc@"],
++}
++def _try_dlopen(name):
++    try:
++        return ctypes.CDLL(name)
++    except OSError:
++        pass
++    for candidate in NAME_TO_PATHS.get(name, []):
++        try:
++            return ctypes.CDLL(candidate)
++        except OSError:
++	        pass
++    raise RuntimeError(f"{name} not found")
++
+ _libraries = {}
+-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
++_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
++_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
+ 
+ 
+ cuuint32_t = ctypes.c_uint32
+-- 
+2.43.0
+