diff options
Diffstat (limited to 'pkgs/development/python-modules/pytorch/default.nix')
-rw-r--r-- | pkgs/development/python-modules/pytorch/default.nix | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index f072972937a9..822586bf190c 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -1,7 +1,7 @@ { stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python, cudaSupport ? false, cudatoolkit ? null, cudnn ? null, nccl ? null, magma ? null, mklDnnSupport ? true, useSystemNccl ? true, - openMPISupport ? false, openmpi ? null, + MPISupport ? false, mpi, buildDocs ? false, cudaArchList ? null, @@ -29,8 +29,6 @@ isPy3k, pythonOlder }: -assert !openMPISupport || openmpi != null; - # assert that everything needed for cuda is present and that the correct cuda versions are used assert !cudaSupport || cudatoolkit != null; assert cudnn == null || cudatoolkit != null; @@ -38,7 +36,7 @@ assert !cudaSupport || (let majorIs = lib.versions.major cudatoolkit.version; in majorIs == "9" || majorIs == "10" || majorIs == "11"); # confirm that cudatoolkits are sync'd across dependencies -assert !(openMPISupport && cudaSupport) || openmpi.cudatoolkit == cudatoolkit; +assert !(MPISupport && cudaSupport) || mpi.cudatoolkit == cudatoolkit; assert !cudaSupport || magma.cudatoolkit == cudatoolkit; let @@ -224,7 +222,7 @@ in buildPythonPackage rec { typing-extensions # the following are required for tensorboard support pillow six future tensorflow-tensorboard protobuf - ] ++ lib.optionals openMPISupport [ openmpi ] + ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals (pythonOlder "3.7") [ dataclasses ]; checkInputs = [ hypothesis ninja psutil ]; |