about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/libraries/onnxruntime
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2024-02-20 12:16:56 +0100
committerAlyssa Ross <hi@alyssa.is>2024-02-20 12:16:56 +0100
commitb24d64b3b1ef897f07cd072a88a9881cb330aa7f (patch)
treea87bb2eed9af3ef1efd51dd65221d91f0c949041 /nixpkgs/pkgs/development/libraries/onnxruntime
parent73338df7473bb3810e70a16b8b0cba4f0f606f2b (diff)
parentfa15b53dbea5028db38d6e09b4cef6eba42aeebb (diff)
downloadnixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar.gz
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar.bz2
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar.lz
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar.xz
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.tar.zst
nixlib-b24d64b3b1ef897f07cd072a88a9881cb330aa7f.zip
Merge branch 'nixos-unstable-small' of https://github.com/NixOS/nixpkgs
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/onnxruntime')
-rw-r--r--nixpkgs/pkgs/development/libraries/onnxruntime/default.nix65
-rw-r--r--nixpkgs/pkgs/development/libraries/onnxruntime/nvcc-gsl.patch32
2 files changed, 85 insertions, 12 deletions
diff --git a/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix b/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix
index 6faa3088fa3c..af4d061d015b 100644
--- a/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix
+++ b/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix
@@ -1,7 +1,7 @@
-{ stdenv
+{ config
+, stdenv
 , lib
 , fetchFromGitHub
-, fetchFromGitLab
 , Foundation
 , abseil-cpp
 , cmake
@@ -18,10 +18,22 @@
 , iconv
 , protobuf_21
 , pythonSupport ? true
-}:
+, cudaSupport ? config.cudaSupport
+, cudaPackages ? {}
+}@inputs:
 
 
 let
+  version = "1.16.3";
+
+  stdenv = throw "Use effectiveStdenv instead";
+  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
+
+  cudaCapabilities = cudaPackages.cudaFlags.cudaCapabilities;
+  # E.g. [ "80" "86" "90" ]
+  cudaArchitectures = (builtins.map cudaPackages.cudaFlags.dropDot cudaCapabilities);
+  cudaArchitecturesString = lib.strings.concatStringsSep ";" cudaArchitectures;
+
   howard-hinnant-date = fetchFromGitHub {
     owner = "HowardHinnant";
     repo = "date";
@@ -74,10 +86,17 @@ let
     rev = "refs/tags/v1.14.1";
     hash = "sha256-ZVSdk6LeAiZpQrrzLxphMbc1b3rNUMpcxcXPP8s/5tE=";
   };
+
+   cutlass = fetchFromGitHub {
+    owner = "NVIDIA";
+    repo = "cutlass";
+    rev = "v3.0.0";
+    sha256 = "sha256-YPD5Sy6SvByjIcGtgeGH80TEKg2BtqJWSg46RvnJChY=";
+   };
 in
-stdenv.mkDerivation rec {
+effectiveStdenv.mkDerivation rec {
   pname = "onnxruntime";
-  version = "1.16.3";
+  inherit version;
 
   src = fetchFromGitHub {
     owner = "microsoft";
@@ -96,6 +115,10 @@ stdenv.mkDerivation rec {
     # - use MakeAvailable instead of the low-level Populate,
     # - use Eigen3::Eigen as the target name (as declared by libeigen/eigen).
     ./0001-eigen-allow-dependency-injection.patch
+  ] ++ lib.optionals cudaSupport [
+    # We apply the referenced 1064.patch ourselves to our nix dependency.
+    #  FIND_PACKAGE_ARGS for CUDA was added in https://github.com/microsoft/onnxruntime/commit/87744e5 so it might be possible to delete this patch after upgrading to 1.17.0
+    ./nvcc-gsl.patch
   ];
 
   nativeBuildInputs = [
@@ -109,7 +132,9 @@ stdenv.mkDerivation rec {
     pythonOutputDistHook
     setuptools
     wheel
-  ]);
+  ]) ++ lib.optionals cudaSupport [
+    cudaPackages.cuda_nvcc
+  ];
 
   buildInputs = [
     eigen
@@ -118,16 +143,24 @@ stdenv.mkDerivation rec {
     nlohmann_json
     microsoft-gsl
   ] ++ lib.optionals pythonSupport (with python3Packages; [
+    gtest'
     numpy
     pybind11
     packaging
-  ]) ++ lib.optionals stdenv.isDarwin [
+  ]) ++ lib.optionals effectiveStdenv.isDarwin [
     Foundation
     iconv
-  ];
+  ] ++ lib.optionals cudaSupport (with cudaPackages; [
+    cuda_cccl # cub/cub.cuh
+    libcublas # cublas_v2.h
+    libcurand # curand.h
+    libcusparse # cusparse.h
+    libcufft # cufft.h
+    cudnn # cudnn.h
+    cuda_cudart
+  ]);
 
   nativeCheckInputs = lib.optionals pythonSupport (with python3Packages; [
-    gtest'
     pytest
     sympy
     onnx
@@ -159,23 +192,31 @@ stdenv.mkDerivation rec {
     "-Donnxruntime_BUILD_UNIT_TESTS=ON"
     "-Donnxruntime_ENABLE_LTO=ON"
     "-Donnxruntime_USE_FULL_PROTOBUF=OFF"
+    (lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
+    (lib.cmakeBool "onnxruntime_USE_NCCL" cudaSupport)
   ] ++ lib.optionals pythonSupport [
     "-Donnxruntime_ENABLE_PYTHON=ON"
+  ] ++ lib.optionals cudaSupport [
+    (lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" cutlass)
+    (lib.cmakeFeature "onnxruntime_CUDNN_HOME" cudaPackages.cudnn)
+    (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
   ];
 
-  env = lib.optionalAttrs stdenv.cc.isClang {
+  env = lib.optionalAttrs effectiveStdenv.cc.isClang {
     NIX_CFLAGS_COMPILE = toString [
       "-Wno-error=deprecated-declarations"
       "-Wno-error=unused-but-set-variable"
     ];
   };
 
-  doCheck = true;
+  doCheck = !cudaSupport;
+
+  requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ];
 
   postPatch = ''
     substituteInPlace cmake/libonnxruntime.pc.cmake.in \
       --replace-fail '$'{prefix}/@CMAKE_INSTALL_ @CMAKE_INSTALL_
-  '' + lib.optionalString (stdenv.hostPlatform.system == "aarch64-linux") ''
+  '' + lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
     # https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
     rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc
   '';
diff --git a/nixpkgs/pkgs/development/libraries/onnxruntime/nvcc-gsl.patch b/nixpkgs/pkgs/development/libraries/onnxruntime/nvcc-gsl.patch
new file mode 100644
index 000000000000..948de62e7e75
--- /dev/null
+++ b/nixpkgs/pkgs/development/libraries/onnxruntime/nvcc-gsl.patch
@@ -0,0 +1,32 @@
+diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
+index 9effd1a2db..faff5e8de7 100644
+--- a/cmake/external/onnxruntime_external_deps.cmake
++++ b/cmake/external/onnxruntime_external_deps.cmake
+@@ -280,21 +280,12 @@ if (NOT WIN32)
+   endif()
+ endif()
+ 
+-if(onnxruntime_USE_CUDA)
+-  FetchContent_Declare(
+-    GSL
+-    URL ${DEP_URL_microsoft_gsl}
+-    URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
+-    PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/gsl/1064.patch
+-  )
+-else()
+-  FetchContent_Declare(
+-    GSL
+-    URL ${DEP_URL_microsoft_gsl}
+-    URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
+-    FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL
+-  )
+-endif()
++FetchContent_Declare(
++  GSL
++  URL ${DEP_URL_microsoft_gsl}
++  URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
++  FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL
++)
+ 
+ FetchContent_Declare(
+     safeint