about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/torch/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torch/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/torch/default.nix36
1 files changed, 17 insertions, 19 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torch/default.nix b/nixpkgs/pkgs/development/python-modules/torch/default.nix
index f9f6e377b139..9efb0facaff3 100644
--- a/nixpkgs/pkgs/development/python-modules/torch/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/torch/default.nix
@@ -43,11 +43,7 @@
 
   # ROCm dependencies
   rocmSupport ? false,
-  gpuTargets ? [ ],
-  openmp, rocm-core, hip, rccl, miopen, miopengemm, rocrand, rocblas,
-  rocfft, rocsparse, hipsparse, rocthrust, rocprim, hipcub, roctracer,
-  rocsolver, hipfft, hipsolver, hipblas, rocminfo, rocm-thunk, rocm-comgr,
-  rocm-device-libs, rocm-runtime, rocm-opencl-runtime, hipify
+  gpuTargets ? [ ], rocmPackages
 }:
 
 let
@@ -89,7 +85,7 @@ let
     else if cudaSupport then
       gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
     else if rocmSupport then
-      hip.gpuTargets
+      rocmPackages.clr.gpuTargets
     else
       throw "No GPU targets specified"
   );
@@ -97,19 +93,25 @@ let
   rocmtoolkit_joined = symlinkJoin {
     name = "rocm-merged";
 
-    paths = [
-      rocm-core hip rccl miopen miopengemm rocrand rocblas
-      rocfft rocsparse hipsparse rocthrust rocprim hipcub
-      roctracer rocfft rocsolver hipfft hipsolver hipblas
+    paths = with rocmPackages; [
+      rocm-core clr rccl miopen miopengemm rocrand rocblas
+      rocsparse hipsparse rocthrust rocprim hipcub
+      roctracer # Unfree at the moment due to hsa-amd-aqlprofile hard dependency in rocprofiler
+      rocfft rocsolver hipfft hipsolver hipblas
       rocminfo rocm-thunk rocm-comgr rocm-device-libs
-      rocm-runtime rocm-opencl-runtime hipify
+      rocm-runtime clr.icd hipify
     ];
+
+    # Fix `setuptools` not being found
+    postBuild = ''
+      rm -rf $out/nix-support
+    '';
   };
 
   brokenConditions = attrsets.filterAttrs (_: cond: cond) {
     "CUDA and ROCm are not mutually exclusive" = cudaSupport && rocmSupport;
     "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
-    "Unsupported CUDA version" = cudaSupport && (cudaPackages.cudaMajorVersion != "11");
+    "Unsupported CUDA version" = cudaSupport && !(builtins.elem cudaPackages.cudaMajorVersion [ "11" "12" ]);
     "MPI cudatoolkit does not match cudaPackages.cudatoolkit" = MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
     "Magma cudaPackages does not match cudaPackages" = cudaSupport && (magma.cudaPackages != cudaPackages);
   };
@@ -170,7 +172,7 @@ in buildPythonPackage rec {
     # Strangely, this is never set in cmake
     substituteInPlace cmake/public/LoadHIP.cmake \
       --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
-        "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." hip.version))})"
+        "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." rocmPackages.clr.version))})"
   ''
   # Detection of NCCL version doesn't work particularly well when using the static binary.
   + lib.optionalString cudaSupport ''
@@ -323,7 +325,7 @@ in buildPythonPackage rec {
     ] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
       cuda_profiler_api.dev # <cuda_profiler_api.h>
     ])
-    ++ lib.optionals rocmSupport [ openmp ]
+    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
     ++ lib.optionals (cudaSupport || rocmSupport) [ magma ]
     ++ lib.optionals stdenv.isLinux [ numactl ]
     ++ lib.optionals stdenv.isDarwin [ Accelerate CoreServices libobjc ];
@@ -436,11 +438,7 @@ in buildPythonPackage rec {
     blasProvider = blas.provider;
     # To help debug when a package is broken due to CUDA support
     inherit brokenConditions;
-  } // lib.optionalAttrs cudaSupport {
-    # NOTE: supportedCudaCapabilities isn't computed unless cudaSupport is true, so we can't use
-    #   it in the passthru set above because a downstream package might try to access it even
-    #   when cudaSupport is false. Better to have it missing than null or an empty list by default.
-    cudaCapabilities = supportedCudaCapabilities;
+    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
   };
 
   meta = with lib; {