about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/torchvision/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/torchvision/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/torchvision/default.nix24
1 files changed, 21 insertions, 3 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/torchvision/default.nix b/nixpkgs/pkgs/development/python-modules/torchvision/default.nix
index a42c517ede96..fc9905881cb6 100644
--- a/nixpkgs/pkgs/development/python-modules/torchvision/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/torchvision/default.nix
@@ -1,4 +1,5 @@
 { lib
+, symlinkJoin
 , buildPythonPackage
 , fetchFromGitHub
 , ninja
@@ -10,9 +11,18 @@
 , pillow
 , pytorch
 , pytest
+, cudatoolkit
+, cudnn
+, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch
 }:
 
-buildPythonPackage rec {
+let
+  cudatoolkit_joined = symlinkJoin {
+    name = "${cudatoolkit.name}-unsplit";
+    paths = [ cudatoolkit.out cudatoolkit.lib ];
+  };
+  cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList;
+in buildPythonPackage rec {
   pname = "torchvision";
   version = "0.10.0";
 
@@ -23,15 +33,22 @@ buildPythonPackage rec {
     sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc";
   };
 
-  nativeBuildInputs = [ libpng ninja which ];
+  nativeBuildInputs = [ libpng ninja which ]
+    ++ lib.optionals cudaSupport [ cudatoolkit_joined ];
 
   TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
   TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
 
-  buildInputs = [ libjpeg_turbo libpng ];
+  buildInputs = [ libjpeg_turbo libpng ]
+    ++ lib.optionals cudaSupport [ cudnn ];
 
   propagatedBuildInputs = [ numpy pillow pytorch scipy ];
 
+  preBuild = lib.optionalString cudaSupport ''
+    export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
+    export FORCE_CUDA=1
+  '';
+
   # tries to download many datasets for tests
   doCheck = false;
 
@@ -45,6 +62,7 @@ buildPythonPackage rec {
     description = "PyTorch vision library";
     homepage = "https://pytorch.org/";
     license = licenses.bsd3;
+    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
     maintainers = with maintainers; [ ericsagnes ];
   };
 }