about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/cuda-modules/tensorrt/fixup.nix
blob: d713189328ed794eb4bb36394d23468fa8007eb1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Attribute overrides that turn the generic CUDA redistributable builder
# output into the TensorRT derivation.
#
# Arguments read here:
#   cudaVersion            - CUDA version being targeted.
#   final                  - CUDA package set; `cudnn` and versioned cudnn
#                            attributes are looked up in it.
#   hostPlatform           - selects the `targets/<arch>` directory (preInstall).
#   mkVersionedPackageName - builds the attribute name of a versioned package.
#   package                - TensorRT release description; fields used:
#                            filename, hash, version, minCudaVersion,
#                            maxCudaVersion, cudnnVersion.
#
# Returns `finalAttrs: prevAttrs: { ... }`. List/string/attrset attributes
# below extend the corresponding `prevAttrs` value; `src` and `sourceRoot`
# are set outright.
{
  cudaVersion,
  final,
  hostPlatform,
  lib,
  mkVersionedPackageName,
  package,
  patchelf,
  requireFile,
  ...
}:
let
  inherit (lib)
    maintainers
    meta
    strings
    versions
    ;
in
finalAttrs: prevAttrs: {
  # Useful for inspecting why something went wrong.
  # Maps a human-readable reason to a boolean. Presumably the generic builder
  # marks the package broken when any value is true; confirm there.
  brokenConditions =
    let
      cudaTooOld = strings.versionOlder cudaVersion package.minCudaVersion;
      cudaTooNew =
        (package.maxCudaVersion != null) && strings.versionOlder package.maxCudaVersion cudaVersion;
      # cuDNN is only checked when the release pins a version. The comparison
      # is on major.minor of the pinned version vs. the cudnn actually chosen
      # in passthru.cudnn (which falls back to the default when the pinned one
      # is unavailable). The `&&` short-circuit means `versions.majorMinor` is
      # never forced on a null cudnnVersion.
      cudnnVersionIsSpecified = package.cudnnVersion != null;
      cudnnVersionSpecified = versions.majorMinor package.cudnnVersion;
      cudnnVersionProvided = versions.majorMinor finalAttrs.passthru.cudnn.version;
      cudnnTooOld =
        cudnnVersionIsSpecified && (strings.versionOlder cudnnVersionProvided cudnnVersionSpecified);
      cudnnTooNew =
        cudnnVersionIsSpecified && (strings.versionOlder cudnnVersionSpecified cudnnVersionProvided);
    in
    (prevAttrs.brokenConditions or { })
    // {
      "CUDA version is too old" = cudaTooOld;
      "CUDA version is too new" = cudaTooNew;
      "CUDNN version is too old" = cudnnTooOld;
      "CUDNN version is too new" = cudnnTooNew;
    };

  # The tarball cannot be fetched automatically: requireFile reports `message`
  # until the user has added the file to the store by hand.
  src = requireFile {
    name = package.filename;
    inherit (package) hash;
    message = ''
      To use the TensorRT derivation, you must join the NVIDIA Developer Program and
      download the ${package.version} TAR package for CUDA ${cudaVersion} from
      ${finalAttrs.meta.homepage}.

      Once you have downloaded the file, add it to the store with the following
      command, and try building this derivation again.

      $ nix-store --add-fixed sha256 ${package.filename}
    '';
  };

  # We need to look inside the extracted output to get the files we need.
  sourceRoot = "TensorRT-${finalAttrs.version}";

  # Link against the `lib` output of whichever cudnn passthru.cudnn selected.
  buildInputs = (prevAttrs.buildInputs or [ ]) ++ [ finalAttrs.passthru.cudnn.lib ];

  preInstall =
    let
      # Name of the per-architecture directory under `targets/`.
      # `throw` is lazy: it only fires if preInstall is actually evaluated.
      targetArch =
        if hostPlatform.isx86_64 then
          "x86_64-linux-gnu"
        else if hostPlatform.isAarch64 then
          "aarch64-linux-gnu"
        else
          throw "Unsupported architecture: ${hostPlatform.system}";
    in
    (prevAttrs.preInstall or "")
    + ''
      # Replace symlinks to bin and lib with the actual directories from targets.
      for dir in bin lib; do
        rm "$dir"
        mv "targets/${targetArch}/$dir" "$dir"
      done
    '';

  # Tell autoPatchelf about runtime dependencies.
  # Adds libnvinfer.so as a DT_NEEDED entry of the listed libraries (the
  # upstream binaries presumably do not declare it; confirm). The files are
  # addressed by their full major.minor.patch suffix, e.g. libnvinfer.so.X.Y.Z.
  postFixup =
    let
      versionTriple = "${versions.majorMinor finalAttrs.version}.${versions.patch finalAttrs.version}";
    in
    (prevAttrs.postFixup or "")
    + ''
      ${meta.getExe' patchelf "patchelf"} --add-needed libnvinfer.so \
        "$lib/lib/libnvinfer.so.${versionTriple}" \
        "$lib/lib/libnvinfer_plugin.so.${versionTriple}" \
        "$lib/lib/libnvinfer_builder_resource.so.${versionTriple}"
    '';

  passthru = (prevAttrs.passthru or { }) // {
    useCudatoolkitRunfile = strings.versionOlder cudaVersion "11.3.999";
    # The CUDNN used with TensorRT.
    # If null, the default cudnn derivation will be used.
    # If a version is specified, the cudnn derivation with that version will be used,
    # unless it is not available, in which case the default cudnn derivation will be used.
    cudnn =
      let
        desiredName = mkVersionedPackageName "cudnn" package.cudnnVersion;
        # `${...}` is required: a bare `final ? desiredName` would test for an
        # attribute literally called "desiredName" and always be false.
        desiredIsAvailable = final ? ${desiredName};
      in
      # `||` short-circuits, so desiredName is never built from a null version.
      if package.cudnnVersion == null || !desiredIsAvailable then final.cudnn else final.${desiredName};
  };

  meta = (prevAttrs.meta or { }) // {
    homepage = "https://developer.nvidia.com/tensorrt";
    maintainers = (prevAttrs.meta.maintainers or [ ]) ++ [ maintainers.aidalgol ];
  };
}