diff options
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/onnxruntime/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/libraries/onnxruntime/default.nix | 65 |
1 files changed, 53 insertions, 12 deletions
diff --git a/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix b/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix index 6faa3088fa3c..af4d061d015b 100644 --- a/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix +++ b/nixpkgs/pkgs/development/libraries/onnxruntime/default.nix @@ -1,7 +1,7 @@ -{ stdenv +{ config +, stdenv , lib , fetchFromGitHub -, fetchFromGitLab , Foundation , abseil-cpp , cmake @@ -18,10 +18,22 @@ , iconv , protobuf_21 , pythonSupport ? true -}: +, cudaSupport ? config.cudaSupport +, cudaPackages ? {} +}@inputs: let + version = "1.16.3"; + + stdenv = throw "Use effectiveStdenv instead"; + effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; + + cudaCapabilities = cudaPackages.cudaFlags.cudaCapabilities; + # E.g. [ "80" "86" "90" ] + cudaArchitectures = (builtins.map cudaPackages.cudaFlags.dropDot cudaCapabilities); + cudaArchitecturesString = lib.strings.concatStringsSep ";" cudaArchitectures; + howard-hinnant-date = fetchFromGitHub { owner = "HowardHinnant"; repo = "date"; @@ -74,10 +86,17 @@ let rev = "refs/tags/v1.14.1"; hash = "sha256-ZVSdk6LeAiZpQrrzLxphMbc1b3rNUMpcxcXPP8s/5tE="; }; + + cutlass = fetchFromGitHub { + owner = "NVIDIA"; + repo = "cutlass"; + rev = "v3.0.0"; + sha256 = "sha256-YPD5Sy6SvByjIcGtgeGH80TEKg2BtqJWSg46RvnJChY="; + }; in -stdenv.mkDerivation rec { +effectiveStdenv.mkDerivation rec { pname = "onnxruntime"; - version = "1.16.3"; + inherit version; src = fetchFromGitHub { owner = "microsoft"; @@ -96,6 +115,10 @@ stdenv.mkDerivation rec { # - use MakeAvailable instead of the low-level Populate, # - use Eigen3::Eigen as the target name (as declared by libeigen/eigen). ./0001-eigen-allow-dependency-injection.patch + ] ++ lib.optionals cudaSupport [ + # We apply the referenced 1064.patch ourselves to our nix dependency. + # FIND_PACKAGE_ARGS for CUDA was added in https://github.com/microsoft/onnxruntime/commit/87744e5 so it might be possible to delete this patch after upgrading to 1.17.0 + ./nvcc-gsl.patch ]; nativeBuildInputs = [ @@ -109,7 +132,9 @@ stdenv.mkDerivation rec { pythonOutputDistHook setuptools wheel - ]); + ]) ++ lib.optionals cudaSupport [ + cudaPackages.cuda_nvcc + ]; buildInputs = [ eigen @@ -118,16 +143,24 @@ stdenv.mkDerivation rec { nlohmann_json microsoft-gsl ] ++ lib.optionals pythonSupport (with python3Packages; [ + gtest' numpy pybind11 packaging - ]) ++ lib.optionals stdenv.isDarwin [ + ]) ++ lib.optionals effectiveStdenv.isDarwin [ Foundation iconv - ]; + ] ++ lib.optionals cudaSupport (with cudaPackages; [ + cuda_cccl # cub/cub.cuh + libcublas # cublas_v2.h + libcurand # curand.h + libcusparse # cusparse.h + libcufft # cufft.h + cudnn # cudnn.h + cuda_cudart + ]); nativeCheckInputs = lib.optionals pythonSupport (with python3Packages; [ - gtest' pytest sympy onnx @@ -159,23 +192,31 @@ stdenv.mkDerivation rec { "-Donnxruntime_BUILD_UNIT_TESTS=ON" "-Donnxruntime_ENABLE_LTO=ON" "-Donnxruntime_USE_FULL_PROTOBUF=OFF" + (lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport) + (lib.cmakeBool "onnxruntime_USE_NCCL" cudaSupport) ] ++ lib.optionals pythonSupport [ "-Donnxruntime_ENABLE_PYTHON=ON" + ] ++ lib.optionals cudaSupport [ + (lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" cutlass) + (lib.cmakeFeature "onnxruntime_CUDNN_HOME" cudaPackages.cudnn) + (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString) ]; - env = lib.optionalAttrs stdenv.cc.isClang { + env = lib.optionalAttrs effectiveStdenv.cc.isClang { NIX_CFLAGS_COMPILE = toString [ "-Wno-error=deprecated-declarations" "-Wno-error=unused-but-set-variable" ]; }; - doCheck = true; + doCheck = !cudaSupport; + + requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ]; postPatch = '' substituteInPlace cmake/libonnxruntime.pc.cmake.in \ --replace-fail '$'{prefix}/@CMAKE_INSTALL_ @CMAKE_INSTALL_ - '' + lib.optionalString (stdenv.hostPlatform.system == "aarch64-linux") '' + '' + lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") '' # https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691 rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc ''; |