about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/torch/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torch/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/torch/default.nix5
1 files changed, 3 insertions, 2 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torch/default.nix b/nixpkgs/pkgs/development/python-modules/torch/default.nix
index 1d9fd2a469f6..0dcc2fdba2d9 100644
--- a/nixpkgs/pkgs/development/python-modules/torch/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/torch/default.nix
@@ -196,7 +196,8 @@ in buildPythonPackage rec {
     export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
     export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++
   '' + lib.optionalString (cudaSupport && cudnn != null) ''
-    export CUDNN_INCLUDE_DIR=${cudnn}/include
+    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
+    export CUDNN_LIB_DIR=${cudnn.lib}/lib
   '' + lib.optionalString rocmSupport ''
     export ROCM_PATH=${rocmtoolkit_joined}
     export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -290,7 +291,7 @@ in buildPythonPackage rec {
 
   buildInputs = [ blas blas.provider pybind11 ]
     ++ lib.optionals stdenv.isLinux [ linuxHeaders_5_19 ] # TMP: avoid "flexible array member" errors for now
-    ++ lib.optionals cudaSupport [ cudnn nccl ]
+    ++ lib.optionals cudaSupport [ cudnn.dev cudnn.lib nccl ]
     ++ lib.optionals rocmSupport [ openmp ]
     ++ lib.optionals (cudaSupport || rocmSupport) [ magma ]
     ++ lib.optionals stdenv.isLinux [ numactl ]