about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch')
-rw-r--r--nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch32
1 files changed, 32 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch b/nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch
new file mode 100644
index 000000000000..8d3b69e35e11
--- /dev/null
+++ b/nixpkgs/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch
@@ -0,0 +1,32 @@
+diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py
+index acba81c..aac5fc7 100644
+--- a/gpuctypes/cuda.py
++++ b/gpuctypes/cuda.py
+@@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
+ 
+ 
+ 
++NAME_TO_PATHS = {
++    "libcuda.so": ["@driverLink@/lib/libcuda.so"],
++    "libnvrtc.so": ["@libnvrtc@"],
++}
++def _try_dlopen(name):
++    try:
++        return ctypes.CDLL(name)
++    except OSError:
++        pass
++    for candidate in NAME_TO_PATHS.get(name, []):
++        try:
++            return ctypes.CDLL(candidate)
++        except OSError:
++	     pass
++    raise RuntimeError(f"{name} not found")
++
+ _libraries = {}
+-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
++_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
++_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
+ 
+ 
+ cuuint32_t = ctypes.c_uint32