about summary refs log tree commit diff
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2023-12-04 13:50:31 -0500
committerGitHub <noreply@github.com>2023-12-04 13:50:31 -0500
commitf0124cff665c80ff983b62f1d3454e9be604fd74 (patch)
treee867fde0e3d3402e24d08795c7979d5824694429
parent92df577d8797e5b1f033654cc2109e56179707f2 (diff)
parent5bf016e1e98accfd62c5bbd5a6dea3049af72595 (diff)
downloadnixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.gz
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.bz2
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.lz
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.xz
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.zst
nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.zip
Merge pull request #272082 from ConnorBaker/fix/torch-optional-cuda-deps
python3Packages.torch: enable cuDNN & NCCL only if available
-rw-r--r--pkgs/development/python-modules/torch/default.nix19
1 files changed, 9 insertions, 10 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index 49ed033bd6a1..d18dd97df5b4 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -56,10 +56,7 @@
 
 let
   inherit (lib) attrsets lists strings trivial;
-  inherit (cudaPackages) cudaFlags cudnn;
-
-  # Some packages are not available on all platforms
-  nccl = cudaPackages.nccl or null;
+  inherit (cudaPackages) cudaFlags cudnn nccl;
 
   setBool = v: if v then "1" else "0";
 
@@ -212,10 +209,11 @@ in buildPythonPackage rec {
   # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
   preConfigure = lib.optionalString cudaSupport ''
     export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
-    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
-    export CUDNN_LIB_DIR=${cudnn.lib}/lib
     export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
     export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
+  '' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
+    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
+    export CUDNN_LIB_DIR=${cudnn.lib}/lib
   '' + lib.optionalString rocmSupport ''
     export ROCM_PATH=${rocmtoolkit_joined}
     export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -273,7 +271,7 @@ in buildPythonPackage rec {
   PYTORCH_BUILD_VERSION = version;
   PYTORCH_BUILD_NUMBER = 0;
 
-  USE_NCCL = setBool (nccl != null);
+  USE_NCCL = setBool (cudaPackages ? nccl);
   USE_SYSTEM_NCCL = setBool useSystemNccl;                  # don't build pytorch's third_party NCCL
   USE_STATIC_NCCL = setBool useSystemNccl;
 
@@ -348,8 +346,6 @@ in buildPythonPackage rec {
       cuda_nvrtc.lib
       cuda_nvtx.dev
       cuda_nvtx.lib # -llibNVToolsExt
-      cudnn.dev
-      cudnn.lib
       libcublas.dev
       libcublas.lib
       libcufft.dev
@@ -360,7 +356,10 @@ in buildPythonPackage rec {
       libcusolver.lib
       libcusparse.dev
       libcusparse.lib
-    ] ++ lists.optionals (nccl != null) [
+    ] ++ lists.optionals (cudaPackages ? cudnn) [
+      cudnn.dev
+      cudnn.lib
+    ] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [
       # Some platforms do not support NCCL (i.e., Jetson)
       nccl.dev # Provides nccl.h AND a static copy of NCCL!
     ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [