diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torch/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/torch/default.nix | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torch/default.nix b/nixpkgs/pkgs/development/python-modules/torch/default.nix index 1d9fd2a469f6..0dcc2fdba2d9 100644 --- a/nixpkgs/pkgs/development/python-modules/torch/default.nix +++ b/nixpkgs/pkgs/development/python-modules/torch/default.nix @@ -196,7 +196,8 @@ in buildPythonPackage rec { export TORCH_CUDA_ARCH_LIST="${gpuTargetString}" export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++ '' + lib.optionalString (cudaSupport && cudnn != null) '' - export CUDNN_INCLUDE_DIR=${cudnn}/include + export CUDNN_INCLUDE_DIR=${cudnn.dev}/include + export CUDNN_LIB_DIR=${cudnn.lib}/lib '' + lib.optionalString rocmSupport '' export ROCM_PATH=${rocmtoolkit_joined} export ROCM_SOURCE_DIR=${rocmtoolkit_joined} @@ -290,7 +291,7 @@ in buildPythonPackage rec { buildInputs = [ blas blas.provider pybind11 ] ++ lib.optionals stdenv.isLinux [ linuxHeaders_5_19 ] # TMP: avoid "flexible array member" errors for now - ++ lib.optionals cudaSupport [ cudnn nccl ] + ++ lib.optionals cudaSupport [ cudnn.dev cudnn.lib nccl ] ++ lib.optionals rocmSupport [ openmp ] ++ lib.optionals (cudaSupport || rocmSupport) [ magma ] ++ lib.optionals stdenv.isLinux [ numactl ] |