diff options
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix index e69807871f46..60f9b5ad8846 100644 --- a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix +++ b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix @@ -1,6 +1,28 @@ -{ stdenv, cmake, libtorch-bin, symlinkJoin }: +{ lib +, stdenv +, cmake +, libtorch-bin +, linkFarm +, symlinkJoin -stdenv.mkDerivation { +, cudaSupport +, cudatoolkit +, cudnn +}: +let + cudatoolkit_joined = symlinkJoin { + name = "${cudatoolkit.name}-unsplit"; + paths = [ cudatoolkit.out cudatoolkit.lib ]; + }; + + # We do not have access to /run/opengl-driver/lib in the sandbox, + # so use a stub instead. + cudaStub = linkFarm "cuda-stub" [{ + name = "libcuda.so.1"; + path = "${cudatoolkit}/lib/stubs/libcuda.so"; + }]; + +in stdenv.mkDerivation { pname = "libtorch-test"; version = libtorch-bin.version; @@ -8,7 +30,11 @@ stdenv.mkDerivation { nativeBuildInputs = [ cmake ]; - buildInputs = [ libtorch-bin ]; + buildInputs = [ libtorch-bin ] ++ + lib.optionals cudaSupport [ cudnn ]; + + cmakeFlags = lib.optionals cudaSupport + [ "-DCUDA_TOOLKIT_ROOT_DIR=${cudatoolkit_joined}" ]; doCheck = true; @@ -17,6 +43,7 @@ stdenv.mkDerivation { ''; checkPhase = '' - ./test + LD_LIBRARY_PATH=${cudaStub}''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH \ + ./test ''; } |