about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix')
-rw-r--r--nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix35
1 files changed, 31 insertions, 4 deletions
diff --git a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
index e69807871f46..60f9b5ad8846 100644
--- a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
+++ b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
@@ -1,6 +1,28 @@
-{ stdenv, cmake, libtorch-bin, symlinkJoin }:
+{ lib
+, stdenv
+, cmake
+, libtorch-bin
+, linkFarm
+, symlinkJoin
 
-stdenv.mkDerivation {
+, cudaSupport
+, cudatoolkit
+, cudnn
+}:
+let
+  cudatoolkit_joined = symlinkJoin {
+    name = "${cudatoolkit.name}-unsplit";
+    paths = [ cudatoolkit.out cudatoolkit.lib ];
+  };
+
+  # We do not have access to /run/opengl-driver/lib in the sandbox,
+  # so use a stub instead.
+  cudaStub = linkFarm "cuda-stub" [{
+    name = "libcuda.so.1";
+    path = "${cudatoolkit}/lib/stubs/libcuda.so";
+  }];
+
+in stdenv.mkDerivation {
   pname = "libtorch-test";
   version = libtorch-bin.version;
 
@@ -8,7 +30,11 @@ stdenv.mkDerivation {
 
   nativeBuildInputs = [ cmake ];
 
-  buildInputs = [ libtorch-bin ];
+  buildInputs = [ libtorch-bin ] ++
+    lib.optionals cudaSupport [ cudnn ];
+
+  cmakeFlags = lib.optionals cudaSupport
+    [ "-DCUDA_TOOLKIT_ROOT_DIR=${cudatoolkit_joined}" ];
 
   doCheck = true;
 
@@ -17,6 +43,7 @@ stdenv.mkDerivation {
   '';
 
   checkPhase = ''
-    ./test
+    LD_LIBRARY_PATH=${cudaStub}''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH \
+      ./test
   '';
 }