about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix35
1 files changed, 27 insertions, 8 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix b/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix
index f6f8f5e2b1b6..5e27c0f605b8 100644
--- a/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix
@@ -33,7 +33,7 @@
 }:
 
 let
-  inherit (cudaPackagesGoogle) cudatoolkit cudnn;
+  inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion;
 
   version = "0.4.23";
 
@@ -118,26 +118,44 @@ let
       };
     };
 
-  # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
+  # Note that the prebuilt jaxlib binary requires specific version of CUDA to
+  # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11
+  # jaxlib binaries only works with CUDA 11.8. This is why we need to find a
+  # binary that matches the provided cudaVersion.
+  gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";
+
+  # Find new releases at https://storage.googleapis.com/jax-releases
   # When upgrading, you can get these hashes from prefetch.sh. See
   # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
   gpuSrcs = {
-    "3.9" = fetchurl {
+    "cuda12.2-3.9" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
       hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
     };
-    "3.10" = fetchurl {
+    "cuda12.2-3.10" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
       hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
     };
-    "3.11" = fetchurl {
+    "cuda12.2-3.11" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
       hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
     };
-    "3.12" = fetchurl {
+    "cuda12.2-3.12" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
       hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
     };
+    "cuda11.8-3.9" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
+      hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60=";
+    };
+    "cuda11.8-3.10" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
+      hash = "sha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0=";
+    };
+    "cuda11.8-3.11" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
+      hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4=";
+    };
   };
 
 in
@@ -154,7 +172,7 @@ buildPythonPackage {
       (
         cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
           or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
-      ) else gpuSrcs."${pythonVersion}";
+      ) else gpuSrcs."${gpuSrcVersionString}";
 
   # Prebuilt wheels are dynamically linked against things that nix can't find.
   # Run `autoPatchelfHook` to automagically fix them.
@@ -212,6 +230,7 @@ buildPythonPackage {
     broken =
       !(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
       || !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
-      || !(cudaSupport -> stdenv.isLinux);
+      || !(cudaSupport -> stdenv.isLinux)
+      || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"));
   };
 }