diff options
author | Alyssa Ross <hi@alyssa.is> | 2021-09-08 17:57:14 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2021-09-13 11:31:47 +0000 |
commit | ee7984efa14902a2ddd820c937457667a4f40c6a (patch) | |
tree | c9c1d046733cefe5e21fdd8a52104175d47b2443 /nixpkgs/pkgs/development/python-modules/torchvision | |
parent | ffc9d4ba381da62fd08b361bacd1e71e2a3d934d (diff) | |
parent | b3c692172e5b5241b028a98e1977f9fb12eeaf42 (diff) | |
download | nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar.gz nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar.bz2 nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar.lz nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar.xz nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.tar.zst nixlib-ee7984efa14902a2ddd820c937457667a4f40c6a.zip |
Merge commit 'b3c692172e5b5241b028a98e1977f9fb12eeaf42'
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torchvision')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/torchvision/default.nix | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torchvision/default.nix b/nixpkgs/pkgs/development/python-modules/torchvision/default.nix index a42c517ede96..fc9905881cb6 100644 --- a/nixpkgs/pkgs/development/python-modules/torchvision/default.nix +++ b/nixpkgs/pkgs/development/python-modules/torchvision/default.nix @@ -1,4 +1,5 @@ { lib +, symlinkJoin , buildPythonPackage , fetchFromGitHub , ninja @@ -10,9 +11,18 @@ , pillow , pytorch , pytest +, cudatoolkit +, cudnn +, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch }: -buildPythonPackage rec { +let + cudatoolkit_joined = symlinkJoin { + name = "${cudatoolkit.name}-unsplit"; + paths = [ cudatoolkit.out cudatoolkit.lib ]; + }; + cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList; +in buildPythonPackage rec { pname = "torchvision"; version = "0.10.0"; @@ -23,15 +33,22 @@ buildPythonPackage rec { sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc"; }; - nativeBuildInputs = [ libpng ninja which ]; + nativeBuildInputs = [ libpng ninja which ] + ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/"; TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/"; - buildInputs = [ libjpeg_turbo libpng ]; + buildInputs = [ libjpeg_turbo libpng ] + ++ lib.optionals cudaSupport [ cudnn ]; propagatedBuildInputs = [ numpy pillow pytorch scipy ]; + preBuild = lib.optionalString cudaSupport '' + export TORCH_CUDA_ARCH_LIST="${cudaArchStr}" + export FORCE_CUDA=1 + ''; + # tries to download many datasets for tests doCheck = false; @@ -45,6 +62,7 @@ buildPythonPackage rec { description = "PyTorch vision library"; homepage = "https://pytorch.org/"; license = licenses.bsd3; + platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin; maintainers = with maintainers; [ ericsagnes ]; }; } |