about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/torch/default.nix
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2023-09-01 11:51:02 +0000
committerAlyssa Ross <hi@alyssa.is>2023-09-01 11:51:02 +0000
commitaa4353b499e6950b7333578f936455a628145c31 (patch)
treec6332cedece2327a18d08794755b3fc0f9f1905b /nixpkgs/pkgs/development/python-modules/torch/default.nix
parentac456d475f4e50818499b804359355c0f3b4bbf7 (diff)
parent52185f4d76c18d8348f963795dfed1de018e8dfe (diff)
downloadnixlib-aa4353b499e6950b7333578f936455a628145c31.tar
nixlib-aa4353b499e6950b7333578f936455a628145c31.tar.gz
nixlib-aa4353b499e6950b7333578f936455a628145c31.tar.bz2
nixlib-aa4353b499e6950b7333578f936455a628145c31.tar.lz
nixlib-aa4353b499e6950b7333578f936455a628145c31.tar.xz
nixlib-aa4353b499e6950b7333578f936455a628145c31.tar.zst
nixlib-aa4353b499e6950b7333578f936455a628145c31.zip
Merge https://github.com/NixOS/nixpkgs
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torch/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/torch/default.nix5
1 files changed, 3 insertions, 2 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torch/default.nix b/nixpkgs/pkgs/development/python-modules/torch/default.nix
index 1d9fd2a469f6..0dcc2fdba2d9 100644
--- a/nixpkgs/pkgs/development/python-modules/torch/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/torch/default.nix
@@ -196,7 +196,8 @@ in buildPythonPackage rec {
     export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
     export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++
   '' + lib.optionalString (cudaSupport && cudnn != null) ''
-    export CUDNN_INCLUDE_DIR=${cudnn}/include
+    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
+    export CUDNN_LIB_DIR=${cudnn.lib}/lib
   '' + lib.optionalString rocmSupport ''
     export ROCM_PATH=${rocmtoolkit_joined}
     export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -290,7 +291,7 @@ in buildPythonPackage rec {
 
   buildInputs = [ blas blas.provider pybind11 ]
     ++ lib.optionals stdenv.isLinux [ linuxHeaders_5_19 ] # TMP: avoid "flexible array member" errors for now
-    ++ lib.optionals cudaSupport [ cudnn nccl ]
+    ++ lib.optionals cudaSupport [ cudnn.dev cudnn.lib nccl ]
     ++ lib.optionals rocmSupport [ openmp ]
     ++ lib.optionals (cudaSupport || rocmSupport) [ magma ]
     ++ lib.optionals stdenv.isLinux [ numactl ]