diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix b/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix index f6f8f5e2b1b6..5e27c0f605b8 100644 --- a/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix +++ b/nixpkgs/pkgs/development/python-modules/jaxlib/bin.nix @@ -33,7 +33,7 @@ }: let - inherit (cudaPackagesGoogle) cudatoolkit cudnn; + inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion; version = "0.4.23"; @@ -118,26 +118,44 @@ let }; }; - # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. + # Note that the prebuilt jaxlib binary requires specific version of CUDA to + # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11 + # jaxlib binaries only works with CUDA 11.8. This is why we need to find a + # binary that matches the provided cudaVersion. + gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}"; + + # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrcs = { - "3.9" = fetchurl { + "cuda12.2-3.9" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI="; }; - "3.10" = fetchurl { + "cuda12.2-3.10" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg="; }; - "3.11" = fetchurl { + "cuda12.2-3.11" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow="; }; - "3.12" = fetchurl { + "cuda12.2-3.12" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo="; }; + "cuda11.8-3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; + hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60="; + }; + "cuda11.8-3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0="; + }; + "cuda11.8-3.11" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4="; + }; }; in @@ -154,7 +172,7 @@ buildPythonPackage { ( cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}" or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}") - ) else gpuSrcs."${pythonVersion}"; + ) else gpuSrcs."${gpuSrcVersionString}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. @@ -212,6 +230,7 @@ buildPythonPackage { broken = !(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1") || !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2") - || !(cudaSupport -> stdenv.isLinux); + || !(cudaSupport -> stdenv.isLinux) + || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")); }; } |