about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/cuda-modules/flags.nix
blob: d5e01be01fd51b49d804c31ee7c251079abf6061 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
# Type aliases
# Gpu :: AttrSet
#   - See the documentation in ./gpus.nix.
{
  # Attribute set consulted only for the defaults of `cudaCapabilities` and `cudaForwardCompat`.
  config,
  # cudaCapabilities :: List Capability
  # Capabilities requested by the user. Defaults to `config.cudaCapabilities`, or the empty list
  # when that is unset; an empty list makes this file fall back to `defaultCapabilities`.
  cudaCapabilities ? (config.cudaCapabilities or []),
  # cudaForwardCompat :: Bool
  # Forwarded to `formatCapabilities` as `enableForwardCompat`. Defaults to
  # `config.cudaForwardCompat`, or `true` when that is unset.
  cudaForwardCompat ? (config.cudaForwardCompat or true),
  # Source of the `asserts`, `attrsets`, `lists`, `strings` and `trivial` helpers used below.
  lib,
  # Version string compared against each GPU's `minCudaVersion`, `maxCudaVersion` and
  # `dontDefaultAfter`.
  cudaVersion,
  # Only `hostPlatform.isAarch64` and `hostPlatform.system` are read (Jetson checks).
  hostPlatform,
  # gpus :: List Gpu
  gpus,
}:
let
  inherit (lib)
    asserts
    attrsets
    lists
    strings
    trivial
    ;

  # Flags are determined based on your CUDA toolkit by default.  You may benefit
  # from improved performance, reduced file size, or greater hardware support by
  # passing a configuration based on your specific GPU environment.
  #
  # cudaCapabilities :: List Capability
  # List of hardware generations to build.
  # E.g. [ "8.0" ]
  # Currently, the last item is considered the optional forward-compatibility arch,
  # but this may change in the future.
  #
  # cudaForwardCompat :: Bool
  # Whether to include the forward compatibility gencode (+PTX)
  # to support future GPU generations.
  # E.g. true
  #
  # Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351

  # isSupported :: Gpu -> Bool
  # True when `cudaVersion` is at least the GPU's `minCudaVersion` and, if the GPU declares a
  # `maxCudaVersion` (non-null), not newer than it.
  isSupported =
    gpu:
    let
      atLeastMin = strings.versionAtLeast cudaVersion gpu.minCudaVersion;
      # A null maximum means there is no upper bound.
      atMostMax = gpu.maxCudaVersion == null || strings.versionAtLeast gpu.maxCudaVersion cudaVersion;
    in
    atLeastMin && atMostMax;

  # isDefault :: Gpu -> Bool
  # True for non-Jetson GPUs whose `dontDefaultAfter` is either null or not older than
  # `cudaVersion`.
  # NOTE: Jetson is never built by default.
  isDefault =
    gpu:
    let
      cutoff = gpu.dontDefaultAfter;
      # A null cutoff means the GPU never stops being a default target.
      withinCutoff = if cutoff == null then true else strings.versionAtLeast cutoff cudaVersion;
    in
    withinCutoff && !gpu.isJetson;

  # supportedGpus :: List Gpu
  # The entries of `gpus` that pass `isSupported` for the provided CUDA version.
  supportedGpus = lists.filter (gpu: isSupported gpu) gpus;

  # defaultGpus :: List Gpu
  # The supported GPUs that additionally pass `isDefault`, i.e. the ones built when the user does
  # not request specific capabilities.
  defaultGpus = lists.filter (gpu: isDefault gpu) supportedGpus;

  # supportedCapabilities :: List Capability
  # The compute capability of every supported GPU, in the order of `supportedGpus`.
  supportedCapabilities = lists.forEach supportedGpus (gpu: gpu.computeCapability);

  # defaultCapabilities :: List Capability
  # The compute capability of every default GPU; these are targeted unless the user overrides them.
  defaultCapabilities = lists.forEach defaultGpus (gpu: gpu.computeCapability);

  # cudaArchNameToVersions :: AttrSet String (List String)
  # Groups the supported GPUs by architecture name and keeps, for each name, the compute
  # capabilities of its GPUs in their original order.
  # For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
  cudaArchNameToVersions = builtins.mapAttrs (_archName: lists.map (gpu: gpu.computeCapability)) (
    lists.groupBy (gpu: gpu.archName) supportedGpus
  );

  # cudaComputeCapabilityToName :: AttrSet String String
  # Looks up the architecture name of a supported compute capability.
  # For example, "8.0" maps to "Ampere".
  cudaComputeCapabilityToName = trivial.pipe supportedGpus [
    (lists.map (gpu: {
      name = gpu.computeCapability;
      value = gpu.archName;
    }))
    builtins.listToAttrs
  ];

  # cudaComputeCapabilityToIsJetson :: AttrSet String Boolean
  # Looks up whether a supported compute capability is flagged as a Jetson device.
  cudaComputeCapabilityToIsJetson =
    let
      toEntry = gpu: {
        name = gpu.computeCapability;
        value = gpu.isJetson;
      };
    in
    builtins.listToAttrs (builtins.map toEntry supportedGpus);

  # jetsonComputeCapabilities :: List String
  # The supported compute capabilities flagged as Jetson, in attribute-name order.
  jetsonComputeCapabilities =
    builtins.filter (capability: cudaComputeCapabilityToIsJetson.${capability})
      (builtins.attrNames cudaComputeCapabilityToIsJetson);

  # jetsonTargets :: List String
  # The user-specified cudaCapabilities that are Jetson capabilities, in the order the user gave.
  # NOTE: Jetson devices are never built by default because they cannot be targeted along with
  # non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
  # they must be in the user-specified cudaCapabilities.
  # NOTE: Mixes of Jetson and non-Jetson devices are not rejected here; `isJetsonBuild` in
  # `formatCapabilities` performs that sanity check.
  jetsonTargets = builtins.filter (
    capability: builtins.elem capability jetsonComputeCapabilities
  ) cudaCapabilities;

  # dropDot :: String -> String
  # Removes every "." from a version string, e.g. "8.6" becomes "86".
  dropDot = builtins.replaceStrings [ "." ] [ "" ];

  # archMapper :: String -> List String -> List String
  # Prefixes each dot-less compute capability with `<feat>_`.
  # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
  archMapper =
    feat: computeCapabilities:
    lists.forEach computeCapabilities (computeCapability: "${feat}_${dropDot computeCapability}");

  # gencodeMapper :: String -> List String -> List String
  # Builds one `-gencode` argument per compute capability; `feat` selects the `code=` prefix while
  # `arch=` always uses the `compute_` prefix.
  # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
  # "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
  gencodeMapper =
    feat: computeCapabilities:
    lists.forEach computeCapabilities (
      computeCapability:
      let
        digits = dropDot computeCapability;
      in
      "-gencode=arch=compute_${digits},code=${feat}_${digits}"
    );

  # getRedistArch :: String -> String
  # Translates a Nix system string into the matching NVIDIA redist arch, or "unsupported" when the
  # system has no entry in the table below.
  # NOTE: On aarch64-linux the `linux-aarch64` redist (for Jetson devices) is chosen when any Jetson
  # capability was requested; otherwise the `linux-sbsa` redist (for server-grade ARM chips) is
  # used. Both are based on aarch64, so only one of them can be selected at a time.
  # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
  # `all-packages.nix`, which is evaluated on all systems, hence the fallback instead of a throw.
  getRedistArch =
    nixSystem:
    let
      redistArchBySystem = {
        aarch64-linux = if jetsonTargets != [] then "linux-aarch64" else "linux-sbsa";
        x86_64-linux = "linux-x86_64";
        ppc64le-linux = "linux-ppc64le";
        x86_64-windows = "windows-x86_64";
      };
    in
    redistArchBySystem.${nixSystem} or "unsupported";

  # getNixSystem :: String -> String
  # Translates an NVIDIA redist arch into the matching Nix system string, or
  # "unsupported-<redistArch>" when the redist arch has no entry in the table below.
  # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
  # `all-packages.nix`, which is evaluated on all systems, hence the fallback instead of a throw.
  getNixSystem =
    redistArch:
    let
      nixSystemByRedistArch = {
        linux-sbsa = "aarch64-linux";
        linux-aarch64 = "aarch64-linux";
        linux-x86_64 = "x86_64-linux";
        linux-ppc64le = "ppc64le-linux";
        windows-x86_64 = "x86_64-windows";
      };
    in
    nixSystemByRedistArch.${redistArch} or "unsupported-${redistArch}";

  # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } -> AttrSet
  # Derives architecture names, real/virtual arch lists, NVCC gencode flags and the Jetson-build
  # flag from a list of compute capabilities.
  #
  # - cudaCapabilities: capabilities to build for, e.g. [ "7.5" "8.6" ]. The last item is the
  #   forward-compatibility target.
  # - enableForwardCompat: whether to append the forward-compatibility (virtual) entry for the last
  #   capability. Defaults to true.
  #
  # Evaluating `isJetsonBuild` throws when Jetson capabilities are mixed with non-Jetson ones, or
  # when Jetson capabilities are requested for a host platform that is not aarch64.
  formatCapabilities =
    {
      cudaCapabilities,
      enableForwardCompat ? true,
    }:
    let
      # The forward-compatibility entry is derived from the *last* requested capability, so it can
      # only be emitted when at least one capability was requested. Without this guard `arches` and
      # `gencode` abort in `lists.last` on an empty capability list.
      emitForwardCompat = enableForwardCompat && cudaCapabilities != [];
    in
    rec {
      inherit cudaCapabilities enableForwardCompat;

      # archNames :: List String
      # E.g. [ "Turing" "Ampere" ]
      #
      # Capabilities missing from cudaComputeCapabilityToName are rendered as "sm_XX" instead.
      archNames = lists.unique (
        lists.map (cap: cudaComputeCapabilityToName.${cap} or "sm_${dropDot cap}") cudaCapabilities
      );

      # realArches :: List String
      # The real architectures are physical architectures supported by the CUDA version.
      # E.g. [ "sm_75" "sm_86" ]
      realArches = archMapper "sm" cudaCapabilities;

      # virtualArches :: List String
      # The virtual architectures are typically used for forward compatibility, when trying to support
      # an architecture newer than the CUDA version allows.
      # E.g. [ "compute_75" "compute_86" ]
      virtualArches = archMapper "compute" cudaCapabilities;

      # arches :: List String
      # All real architectures, followed by the virtual architecture of the last capability when
      # forward compatibility is enabled and at least one capability was requested.
      # E.g. [ "sm_75" "sm_86" "compute_86" ]
      arches = realArches ++ lists.optional emitForwardCompat (lists.last virtualArches);

      # gencode :: List String
      # A list of CUDA gencode arguments to pass to NVCC.
      # E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
      gencode =
        let
          base = gencodeMapper "sm" cudaCapabilities;
          # Only forced when emitForwardCompat holds, i.e. when the list is non-empty.
          forward = gencodeMapper "compute" [(lists.last cudaCapabilities)];
        in
        base ++ lists.optionals emitForwardCompat forward;

      # gencodeString :: String
      # A space-separated string of CUDA gencode arguments to pass to NVCC.
      # E.g. "-gencode=arch=compute_75,code=sm_75 ... -gencode=arch=compute_86,code=compute_86"
      gencodeString = strings.concatStringsSep " " gencode;

      # Jetson devices cannot be targeted by the same binaries which target non-Jetson devices. While
      # NVIDIA provides both `linux-aarch64` and `linux-sbsa` packages, which both target `aarch64`,
      # they are built with different settings and cannot be mixed.
      # isJetsonBuild :: Boolean
      isJetsonBuild =
        let
          # Capabilities missing from cudaComputeCapabilityToIsJetson count as non-Jetson.
          requestedJetsonDevices =
            lists.filter (cap: cudaComputeCapabilityToIsJetson.${cap} or false)
              cudaCapabilities;
          requestedNonJetsonDevices =
            lists.filter (cap: !(builtins.elem cap requestedJetsonDevices))
              cudaCapabilities;
          jetsonBuildSufficientCondition = requestedJetsonDevices != [];
          jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [] && hostPlatform.isAarch64;
        in
        # Throw when a Jetson device was requested but the request is mixed or the host is not
        # aarch64; otherwise the result is true exactly for a valid Jetson-only request.
        trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition)
          ''
            Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
            You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
            Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
            Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
            Exactly one of the following must be true:
            - All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
            - No CUDA capabilities belong to Jetson devices.
            See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
          ''
          jetsonBuildSufficientCondition
        && jetsonBuildNecessaryCondition;
    };
in
# Self-test: formatting a non-Jetson capability list must yield exactly the values below.
# When changing names or formats: pause, validate, and update the assert
assert
  let
    requested = [ "7.5" "8.6" ];
    want = {
      cudaCapabilities = requested;
      enableForwardCompat = true;

      archNames = [ "Turing" "Ampere" ];
      realArches = [ "sm_75" "sm_86" ];
      virtualArches = [ "compute_75" "compute_86" ];
      arches = [ "sm_75" "sm_86" "compute_86" ];

      gencode = [
        "-gencode=arch=compute_75,code=sm_75"
        "-gencode=arch=compute_86,code=sm_86"
        "-gencode=arch=compute_86,code=compute_86"
      ];
      gencodeString = "-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86";

      isJetsonBuild = false;
    };
    result = formatCapabilities { cudaCapabilities = requested; };
    # Force the whole result; `tryEval` yields `false` as the value if forcing it throws.
    got = (builtins.tryEval (builtins.deepSeq result result)).value;
  in
  # The comparison is only enforced for CUDA 11.2 and newer.
  asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (want == got)) ''
    This test should only fail when using a version of CUDA older than 11.2, the first to support
    8.6.
    Expected: ${builtins.toJSON want}
    Actual: ${builtins.toJSON got}
  '';
# Check mixed Jetson and non-Jetson devices: forcing the formatted result is expected to throw.
assert
  let
    want = false;
    mixed = formatCapabilities { cudaCapabilities = [ "7.2" "7.5" ]; };
    # `tryEval` yields `false` as the value when forcing `mixed` throws, which is what we want here.
    got = (builtins.tryEval (builtins.deepSeq mixed mixed)).value;
  in
  asserts.assertMsg (want == got) ''
    Jetson devices capabilities cannot be mixed with non-jetson devices.
    Capability 7.5 is non-Jetson and should not be allowed with Jetson 7.2.
    Expected: ${builtins.toJSON want}
    Actual: ${builtins.toJSON got}
  '';
# Check Jetson-only: a list made up solely of Jetson capabilities must format to the values below.
assert
  let
    requested = [ "6.2" "7.2" ];
    want = {
      cudaCapabilities = requested;
      enableForwardCompat = true;

      archNames = [ "Pascal" "Volta" ];
      realArches = [ "sm_62" "sm_72" ];
      virtualArches = [ "compute_62" "compute_72" ];
      arches = [ "sm_62" "sm_72" "compute_72" ];

      gencode = [
        "-gencode=arch=compute_62,code=sm_62"
        "-gencode=arch=compute_72,code=sm_72"
        "-gencode=arch=compute_72,code=compute_72"
      ];
      gencodeString = "-gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_72,code=compute_72";

      isJetsonBuild = true;
    };
    result = formatCapabilities { cudaCapabilities = requested; };
    # Force the whole result; `tryEval` yields `false` as the value if forcing it throws.
    got = (builtins.tryEval (builtins.deepSeq result result)).value;
  in
  asserts.assertMsg
    # We can't do this test unless we're targeting aarch64
    (hostPlatform.isAarch64 -> (want == got))
    ''
      Jetson devices can only be built with other Jetson devices.
      Both 6.2 and 7.2 are Jetson devices.
      Expected: ${builtins.toJSON want}
      Actual: ${builtins.toJSON got}
    '';
let
  # Fall back to the default capabilities for this CUDA version when the user requested none.
  selectedCapabilities = if cudaCapabilities == [] then defaultCapabilities else cudaCapabilities;
in
{
  inherit
    # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } ->  { ... }
    formatCapabilities
    # cudaArchNameToVersions :: AttrSet String (List String)
    cudaArchNameToVersions
    # cudaComputeCapabilityToName :: AttrSet String String
    cudaComputeCapabilityToName
    # dropDot :: String -> String
    dropDot
    defaultCapabilities
    supportedCapabilities
    jetsonComputeCapabilities
    jetsonTargets
    getNixSystem
    getRedistArch
    ;
}
# The formatted view of the selected capabilities is merged into the top level of the result.
// formatCapabilities {
  cudaCapabilities = selectedCapabilities;
  enableForwardCompat = cudaForwardCompat;
}