diff options
author | Connor Baker <connor.baker@tweag.io> | 2023-12-04 13:50:31 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-04 13:50:31 -0500 |
commit | f0124cff665c80ff983b62f1d3454e9be604fd74 (patch) | |
tree | e867fde0e3d3402e24d08795c7979d5824694429 | |
parent | 92df577d8797e5b1f033654cc2109e56179707f2 (diff) | |
parent | 5bf016e1e98accfd62c5bbd5a6dea3049af72595 (diff) | |
download | nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.gz nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.bz2 nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.lz nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.xz nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.tar.zst nixlib-f0124cff665c80ff983b62f1d3454e9be604fd74.zip |
Merge pull request #272082 from ConnorBaker/fix/torch-optional-cuda-deps
python3Packages.torch: enable cuDNN & NCCL only if available
-rw-r--r-- | pkgs/development/python-modules/torch/default.nix | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 49ed033bd6a1..d18dd97df5b4 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -56,10 +56,7 @@ let inherit (lib) attrsets lists strings trivial; - inherit (cudaPackages) cudaFlags cudnn; - - # Some packages are not available on all platforms - nccl = cudaPackages.nccl or null; + inherit (cudaPackages) cudaFlags cudnn nccl; setBool = v: if v then "1" else "0"; @@ -212,10 +209,11 @@ in buildPythonPackage rec { # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195. preConfigure = lib.optionalString cudaSupport '' export TORCH_CUDA_ARCH_LIST="${gpuTargetString}" - export CUDNN_INCLUDE_DIR=${cudnn.dev}/include - export CUDNN_LIB_DIR=${cudnn.lib}/lib export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib + '' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) '' + export CUDNN_INCLUDE_DIR=${cudnn.dev}/include + export CUDNN_LIB_DIR=${cudnn.lib}/lib '' + lib.optionalString rocmSupport '' export ROCM_PATH=${rocmtoolkit_joined} export ROCM_SOURCE_DIR=${rocmtoolkit_joined} @@ -273,7 +271,7 @@ in buildPythonPackage rec { PYTORCH_BUILD_VERSION = version; PYTORCH_BUILD_NUMBER = 0; - USE_NCCL = setBool (nccl != null); + USE_NCCL = setBool (cudaPackages ? nccl); USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL USE_STATIC_NCCL = setBool useSystemNccl; @@ -348,8 +346,6 @@ in buildPythonPackage rec { cuda_nvrtc.lib cuda_nvtx.dev cuda_nvtx.lib # -llibNVToolsExt - cudnn.dev - cudnn.lib libcublas.dev libcublas.lib libcufft.dev @@ -360,7 +356,10 @@ in buildPythonPackage rec { libcusolver.lib libcusparse.dev libcusparse.lib - ] ++ lists.optionals (nccl != null) [ + ] ++ lists.optionals (cudaPackages ? cudnn) [ + cudnn.dev + cudnn.lib + ] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [ # Some platforms do not support NCCL (i.e., Jetson) nccl.dev # Provides nccl.h AND a static copy of NCCL! ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ |