about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/libraries/science/math/cudnn/generic.nix
blob: cdfa924b2242af90ce5c34c5c3d240e85a718bdf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
{ stdenv,
  backendStdenv,
  lib,
  zlib,
  useCudatoolkitRunfile ? false,
  cudaVersion,
  cudaMajorVersion,
  cudatoolkit, # For cuda < 11
  libcublas ? null, # cuda <11 doesn't ship redist packages
  autoPatchelfHook,
  autoAddOpenGLRunpathHook,
  fetchurl,
  # The distributed version of CUDNN includes both dynamically linked .so files,
  # as well as statically linked .a files.  However, CUDNN is quite large
  # (multiple gigabytes), so you can save some space in your nix store by
  # removing the statically linked libraries if you are not using them.
  #
  # Setting this to true removes the statically linked .a files.
  # Setting this to false keeps these statically linked .a files.
  removeStatic ? false,
}: {
  # Per-release arguments (presumably supplied by the release table in
  # default.nix -- confirm against the caller).
  version,
  url,
  hash,
  minCudaVersion,
  maxCudaVersion,
}:
# Without the runfile-based cudatoolkit we must have the redist libcublas,
# since it is what autoPatchelf resolves cuDNN's cuBLAS dependency against.
assert useCudatoolkitRunfile || (libcublas != null); let
  inherit (lib) lists strings trivial versions;

  # majorMinorPatch :: String -> String
  # Keeps at most the first three dot-separated components of a version,
  # e.g. "8.6.0.163" -> "8.6.0".
  majorMinorPatch = (trivial.flip trivial.pipe) [
    (versions.splitVersion)
    (lists.take 3)
    (strings.concatStringsSep ".")
  ];

  # versionTriple :: String
  # Version with three components: major.minor.patch
  versionTriple = majorMinorPatch version;

  # cudatoolkit_root :: Derivation
  # The package providing the CUDA libraries cuDNN links against: the
  # monolithic runfile toolkit, or the redist libcublas package.
  cudatoolkit_root =
    if useCudatoolkitRunfile
    then cudatoolkit
    else libcublas;
in
  backendStdenv.mkDerivation {
    pname = "cudatoolkit-${cudaMajorVersion}-cudnn";
    version = versionTriple;

    src = fetchurl {
      inherit url hash;
    };

    # Check and normalize Runpath against DT_NEEDED using autoPatchelf.
    # Prepend /run/opengl-driver/lib using addOpenGLRunpath for dlopen("libcuda.so")
    nativeBuildInputs = [
      autoPatchelfHook
      autoAddOpenGLRunpathHook
    ];

    # Used by autoPatchelfHook
    buildInputs = [
      # Note this libstdc++ isn't from the (possibly older) nvcc-compatible
      # stdenv, but from the (newer) stdenv that the rest of nixpkgs uses
      stdenv.cc.cc.lib

      zlib
      cudatoolkit_root
    ];

    # We used to patch Runpath here, but now we use autoPatchelfHook
    #
    # Note also that version <=8.3.0 contained a subdirectory "lib64/" but in
    # version 8.3.2 it seems to have been renamed to simply "lib/".
    #
    # The optional copies use `if ...; then ...; fi` rather than
    # `[ -d ... ] && ...`: under `set -e` the latter leaves a non-zero status
    # when the directory is missing, which aborts the build if that line ever
    # becomes the last command of the phase.
    installPhase =
      ''
        runHook preInstall

        mkdir -p $out
        cp -a include $out/include
        if [ -d "lib/" ]; then
          cp -a lib $out/lib
        fi
        if [ -d "lib64/" ]; then
          cp -a lib64 $out/lib64
        fi
      ''
      + strings.optionalString removeStatic ''
        rm -f $out/lib/*.a
        rm -f $out/lib64/*.a
      ''
      + ''
        runHook postInstall
      '';

    # Without --add-needed, autoPatchelf drops $ORIGIN from the Runpath of
    # libcudnn.so for cuDNN >= 8.0.5. autoPatchelf keeps only Runpath entries
    # that satisfy a DT_NEEDED entry. Declaring a sibling library as needed
    # therefore keeps $ORIGIN. (Presumably libcudnn.so loads its sibling
    # sub-libraries at runtime via dlopen -- NOTE(review): confirm.)
    postFixup = strings.optionalString (strings.versionAtLeast versionTriple "8.0.5") ''
      patchelf $out/lib/libcudnn.so --add-needed libcudnn_cnn_infer.so
    '';

    passthru = {
      inherit useCudatoolkitRunfile;

      cudatoolkit =
        trivial.warn
        ''
          cudnn.cudatoolkit passthru attribute is deprecated;
          if your derivation uses cudnn directly, it should probably consume cudaPackages instead
        ''
        cudatoolkit;

      majorVersion = versions.major versionTriple;
    };

    meta = {
      # Check that the cudatoolkit version satisfies our min/max constraints (both
      # inclusive). We mark the package as broken if it fails to satisfy the
      # official version constraints (as recorded in default.nix). In some cases
      # you _may_ be able to smudge version constraints, just know that you're
      # embarking into unknown and unsupported territory when doing so.
      broken =
        strings.versionOlder cudaVersion minCudaVersion
        || strings.versionOlder maxCudaVersion cudaVersion;
      description = "NVIDIA CUDA Deep Neural Network library (cuDNN)";
      homepage = "https://developer.nvidia.com/cudnn";
      sourceProvenance = [lib.sourceTypes.binaryNativeCode];
      # TODO: consider marking unfreeRedistributable when not using runfile
      license = lib.licenses.unfree;
      platforms = ["x86_64-linux"];
      maintainers = with lib.maintainers; [mdaiter samuela];
    };
  }