diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torch/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/torch/default.nix | 36 |
1 files changed, 17 insertions, 19 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torch/default.nix b/nixpkgs/pkgs/development/python-modules/torch/default.nix index f9f6e377b139..9efb0facaff3 100644 --- a/nixpkgs/pkgs/development/python-modules/torch/default.nix +++ b/nixpkgs/pkgs/development/python-modules/torch/default.nix @@ -43,11 +43,7 @@ # ROCm dependencies rocmSupport ? false, - gpuTargets ? [ ], - openmp, rocm-core, hip, rccl, miopen, miopengemm, rocrand, rocblas, - rocfft, rocsparse, hipsparse, rocthrust, rocprim, hipcub, roctracer, - rocsolver, hipfft, hipsolver, hipblas, rocminfo, rocm-thunk, rocm-comgr, - rocm-device-libs, rocm-runtime, rocm-opencl-runtime, hipify + gpuTargets ? [ ], rocmPackages }: let @@ -89,7 +85,7 @@ let else if cudaSupport then gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities else if rocmSupport then - hip.gpuTargets + rocmPackages.clr.gpuTargets else throw "No GPU targets specified" ); @@ -97,19 +93,25 @@ let rocmtoolkit_joined = symlinkJoin { name = "rocm-merged"; - paths = [ - rocm-core hip rccl miopen miopengemm rocrand rocblas - rocfft rocsparse hipsparse rocthrust rocprim hipcub - roctracer rocfft rocsolver hipfft hipsolver hipblas + paths = with rocmPackages; [ + rocm-core clr rccl miopen miopengemm rocrand rocblas + rocsparse hipsparse rocthrust rocprim hipcub + roctracer # Unfree at the moment due to hsa-amd-aqlprofile hard dependency in rocprofiler + rocfft rocsolver hipfft hipsolver hipblas rocminfo rocm-thunk rocm-comgr rocm-device-libs - rocm-runtime rocm-opencl-runtime hipify + rocm-runtime clr.icd hipify ]; + + # Fix `setuptools` not being found + postBuild = '' + rm -rf $out/nix-support + ''; }; brokenConditions = attrsets.filterAttrs (_: cond: cond) { "CUDA and ROCm are not mutually exclusive" = cudaSupport && rocmSupport; "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux; - "Unsupported CUDA version" = cudaSupport && (cudaPackages.cudaMajorVersion != "11"); + "Unsupported CUDA version" = cudaSupport && !(builtins.elem cudaPackages.cudaMajorVersion [ "11" "12" ]); "MPI cudatoolkit does not match cudaPackages.cudatoolkit" = MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit); "Magma cudaPackages does not match cudaPackages" = cudaSupport && (magma.cudaPackages != cudaPackages); }; @@ -170,7 +172,7 @@ in buildPythonPackage rec { # Strangely, this is never set in cmake substituteInPlace cmake/public/LoadHIP.cmake \ --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \ - "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." hip.version))})" + "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." rocmPackages.clr.version))})" '' # Detection of NCCL version doesn't work particularly well when using the static binary. + lib.optionalString cudaSupport '' @@ -323,7 +325,7 @@ in buildPythonPackage rec { ] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [ cuda_profiler_api.dev # <cuda_profiler_api.h> ]) - ++ lib.optionals rocmSupport [ openmp ] + ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ] ++ lib.optionals (cudaSupport || rocmSupport) [ magma ] ++ lib.optionals stdenv.isLinux [ numactl ] ++ lib.optionals stdenv.isDarwin [ Accelerate CoreServices libobjc ]; @@ -436,11 +438,7 @@ in buildPythonPackage rec { blasProvider = blas.provider; # To help debug when a package is broken due to CUDA support inherit brokenConditions; - } // lib.optionalAttrs cudaSupport { - # NOTE: supportedCudaCapabilities isn't computed unless cudaSupport is true, so we can't use - # it in the passthru set above because a downstream package might try to access it even - # when cudaSupport is false. Better to have it missing than null or an empty list by default. - cudaCapabilities = supportedCudaCapabilities; + cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ]; }; meta = with lib; { |