about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/cuda-modules/flags.nix
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2023-12-15 19:32:38 +0100
committerAlyssa Ross <hi@alyssa.is>2023-12-15 19:32:38 +0100
commit6b8e2555ef013b579cda57025b17d662e0f1fe1f (patch)
tree5a83c673af26c9976acd5a5dfa20e09e06898047 /nixpkgs/pkgs/development/cuda-modules/flags.nix
parent66ca7a150b5c051f0728f13134e6265cc46f370c (diff)
parent02357adddd0889782362d999628de9d309d202dc (diff)
downloadnixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar.gz
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar.bz2
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar.lz
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar.xz
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.tar.zst
nixlib-6b8e2555ef013b579cda57025b17d662e0f1fe1f.zip
Merge branch 'nixos-unstable-small' of https://github.com/NixOS/nixpkgs
Diffstat (limited to 'nixpkgs/pkgs/development/cuda-modules/flags.nix')
-rw-r--r--nixpkgs/pkgs/development/cuda-modules/flags.nix390
1 files changed, 390 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/development/cuda-modules/flags.nix b/nixpkgs/pkgs/development/cuda-modules/flags.nix
new file mode 100644
index 000000000000..139952bc9dfd
--- /dev/null
+++ b/nixpkgs/pkgs/development/cuda-modules/flags.nix
@@ -0,0 +1,390 @@
+# Type aliases
+# Gpu :: AttrSet
+#   - See the documentation in ./gpus.nix.
+{
+  config,
+  cudaCapabilities ? (config.cudaCapabilities or []),
+  cudaForwardCompat ? (config.cudaForwardCompat or true),
+  lib,
+  cudaVersion,
+  hostPlatform,
+  # gpus :: List Gpu
+  gpus,
+}:
+let
+  inherit (lib)
+    asserts
+    attrsets
+    lists
+    strings
+    trivial
+    ;
+
+  # Flags are determined based on your CUDA toolkit by default.  You may benefit
+  # from improved performance, reduced file size, or greater hardware support by
+  # passing a configuration based on your specific GPU environment.
+  #
+  # cudaCapabilities :: List Capability
+  # List of hardware generations to build.
+  # E.g. [ "8.0" ]
+  # Currently, the last item is considered the optional forward-compatibility arch,
+  # but this may change in the future.
+  #
+  # cudaForwardCompat :: Bool
+  # Whether to include the forward compatibility gencode (+PTX)
+  # to support future GPU generations.
+  # E.g. true
+  #
+  # Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351
+
+  # isSupported :: Gpu -> Bool
+  isSupported =
+    gpu:
+    let
+      inherit (gpu) minCudaVersion maxCudaVersion;
+      lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion;
+      upperBoundSatisfied =
+        (maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaVersion);
+    in
+    lowerBoundSatisfied && upperBoundSatisfied;
+
+  # NOTE: Jetson is never built by default.
+  # isDefault :: Gpu -> Bool
+  isDefault =
+    gpu:
+    let
+      inherit (gpu) dontDefaultAfter isJetson;
+      newGpu = dontDefaultAfter == null;
+      recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
+    in
+    recentGpu && !isJetson;
+
+  # supportedGpus :: List Gpu
+  # GPUs which are supported by the provided CUDA version.
+  supportedGpus = builtins.filter isSupported gpus;
+
+  # defaultGpus :: List Gpu
+  # GPUs which are supported by the provided CUDA version and we want to build for by default.
+  defaultGpus = builtins.filter isDefault supportedGpus;
+
+  # supportedCapabilities :: List Capability
+  supportedCapabilities = lists.map (gpu: gpu.computeCapability) supportedGpus;
+
+  # defaultCapabilities :: List Capability
+  # The default capabilities to target, if not overridden by the user.
+  defaultCapabilities = lists.map (gpu: gpu.computeCapability) defaultGpus;
+
+  # cudaArchNameToVersions :: AttrSet String (List String)
+  # Maps the name of a GPU architecture to different versions of that architecture.
+  # For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
+  cudaArchNameToVersions =
+    lists.groupBy' (versions: gpu: versions ++ [gpu.computeCapability]) [] (gpu: gpu.archName)
+      supportedGpus;
+
+  # cudaComputeCapabilityToName :: AttrSet String String
+  # Maps the version of a GPU architecture to the name of that architecture.
+  # For example, "8.0" maps to "Ampere".
+  cudaComputeCapabilityToName = builtins.listToAttrs (
+    lists.map (gpu: attrsets.nameValuePair gpu.computeCapability gpu.archName) supportedGpus
+  );
+
+  # cudaComputeCapabilityToIsJetson :: AttrSet String Boolean
+  cudaComputeCapabilityToIsJetson = builtins.listToAttrs (
+    lists.map (attrs: attrsets.nameValuePair attrs.computeCapability attrs.isJetson) supportedGpus
+  );
+
+  # jetsonComputeCapabilities :: List String
+  jetsonComputeCapabilities = trivial.pipe cudaComputeCapabilityToIsJetson [
+    (attrsets.filterAttrs (_: isJetson: isJetson))
+    builtins.attrNames
+  ];
+
+  # Find the intersection with the user-specified list of cudaCapabilities.
+  # NOTE: Jetson devices are never built by default because they cannot be targeted along with
+  # non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
+  # they must be in the user-specified cudaCapabilities.
+  # NOTE: We don't need to worry about mixes of Jetson and non-Jetson devices here -- there's
+  # sanity-checking for all that in below.
+  jetsonTargets = lists.intersectLists jetsonComputeCapabilities cudaCapabilities;
+
+  # dropDot :: String -> String
+  dropDot = ver: builtins.replaceStrings ["."] [""] ver;
+
+  # archMapper :: String -> List String -> List String
+  # Maps a feature across a list of architecture versions to produce a list of architectures.
+  # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
+  archMapper = feat: lists.map (computeCapability: "${feat}_${dropDot computeCapability}");
+
+  # gencodeMapper :: String -> List String -> List String
+  # Maps a feature across a list of architecture versions to produce a list of gencode arguments.
+  # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
+  # "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
+  gencodeMapper =
+    feat:
+    lists.map (
+      computeCapability:
+      "-gencode=arch=compute_${dropDot computeCapability},code=${feat}_${dropDot computeCapability}"
+    );
+
+  # Maps Nix system to NVIDIA redist arch.
+  # NOTE: We swap out the default `linux-sbsa` redist (for server-grade ARM chips) with the
+  # `linux-aarch64` redist (which is for Jetson devices) if we're building any Jetson devices.
+  # Since both are based on aarch64, we can only have one or the other, otherwise there's an
+  # ambiguity as to which should be used.
+  # getRedistArch :: String -> String
+  getRedistArch =
+    nixSystem:
+    if nixSystem == "aarch64-linux" then
+      if jetsonTargets != [] then "linux-aarch64" else "linux-sbsa"
+    else if nixSystem == "x86_64-linux" then
+      "linux-x86_64"
+    else if nixSystem == "ppc64le-linux" then
+      "linux-ppc64le"
+    else if nixSystem == "x86_64-windows" then
+      "windows-x86_64"
+    else
+      builtins.throw "Unsupported Nix system: ${nixSystem}";
+
+  # Maps NVIDIA redist arch to Nix system.
+  # It is imperative that we include the boolean condition based on jetsonTargets to ensure
+  # we don't advertise availability of packages only available on server-grade ARM
+  # as being available for the Jetson, since both `linux-sbsa` and `linux-aarch64` are
+  # mapped to the Nix system `aarch64-linux`.
+  getNixSystem =
+    redistArch:
+    if redistArch == "linux-sbsa" && jetsonTargets == [] then
+      "aarch64-linux"
+    else if redistArch == "linux-aarch64" && jetsonTargets != [] then
+      "aarch64-linux"
+    else if redistArch == "linux-x86_64" then
+      "x86_64-linux"
+    else if redistArch == "linux-ppc64le" then
+      "ppc64le-linux"
+    else if redistArch == "windows-x86_64" then
+      "x86_64-windows"
+    else
+      builtins.throw "Unsupported NVIDIA redist arch: ${redistArch}";
+
+  formatCapabilities =
+    {
+      cudaCapabilities,
+      enableForwardCompat ? true,
+    }:
+    rec {
+      inherit cudaCapabilities enableForwardCompat;
+
+      # archNames :: List String
+      # E.g. [ "Turing" "Ampere" ]
+      archNames = lists.unique (
+        lists.map (cap: cudaComputeCapabilityToName.${cap} or (throw "missing cuda compute capability"))
+          cudaCapabilities
+      );
+
+      # realArches :: List String
+      # The real architectures are physical architectures supported by the CUDA version.
+      # E.g. [ "sm_75" "sm_86" ]
+      realArches = archMapper "sm" cudaCapabilities;
+
+      # virtualArches :: List String
+      # The virtual architectures are typically used for forward compatibility, when trying to support
+      # an architecture newer than the CUDA version allows.
+      # E.g. [ "compute_75" "compute_86" ]
+      virtualArches = archMapper "compute" cudaCapabilities;
+
+      # arches :: List String
+      # By default, build for all supported architectures and forward compatibility via a virtual
+      # architecture for the newest supported architecture.
+      # E.g. [ "sm_75" "sm_86" "compute_86" ]
+      arches = realArches ++ lists.optional enableForwardCompat (lists.last virtualArches);
+
+      # gencode :: List String
+      # A list of CUDA gencode arguments to pass to NVCC.
+      # E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
+      gencode =
+        let
+          base = gencodeMapper "sm" cudaCapabilities;
+          forward = gencodeMapper "compute" [(lists.last cudaCapabilities)];
+        in
+        base ++ lib.optionals enableForwardCompat forward;
+
+      # gencodeString :: String
+      # A space-separated string of CUDA gencode arguments to pass to NVCC.
+      # E.g. "-gencode=arch=compute_75,code=sm_75 ... -gencode=arch=compute_86,code=compute_86"
+      gencodeString = strings.concatStringsSep " " gencode;
+
+      # Jetson devices cannot be targeted by the same binaries which target non-Jetson devices. While
+      # NVIDIA provides both `linux-aarch64` and `linux-sbsa` packages, which both target `aarch64`,
+      # they are built with different settings and cannot be mixed.
+      # isJetsonBuild :: Boolean
+      isJetsonBuild =
+        let
+          requestedJetsonDevices =
+            lists.filter (cap: cudaComputeCapabilityToIsJetson.${cap})
+              cudaCapabilities;
+          requestedNonJetsonDevices =
+            lists.filter (cap: !(builtins.elem cap requestedJetsonDevices))
+              cudaCapabilities;
+          jetsonBuildSufficientCondition = requestedJetsonDevices != [];
+          jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [] && hostPlatform.isAarch64;
+        in
+        trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition)
+          ''
+            Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
+            You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
+            Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
+            Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
+            Exactly one of the following must be true:
+            - All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
+            - No CUDA capabilities belong to Jetson devices.
+            See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
+          ''
+          jetsonBuildSufficientCondition
+        && jetsonBuildNecessaryCondition;
+    };
+in
+# When changing names or formats: pause, validate, and update the assert
+assert let
+  expected = {
+    cudaCapabilities = [
+      "7.5"
+      "8.6"
+    ];
+    enableForwardCompat = true;
+
+    archNames = [
+      "Turing"
+      "Ampere"
+    ];
+    realArches = [
+      "sm_75"
+      "sm_86"
+    ];
+    virtualArches = [
+      "compute_75"
+      "compute_86"
+    ];
+    arches = [
+      "sm_75"
+      "sm_86"
+      "compute_86"
+    ];
+
+    gencode = [
+      "-gencode=arch=compute_75,code=sm_75"
+      "-gencode=arch=compute_86,code=sm_86"
+      "-gencode=arch=compute_86,code=compute_86"
+    ];
+    gencodeString = "-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86";
+
+    isJetsonBuild = false;
+  };
+  actual = formatCapabilities {
+    cudaCapabilities = [
+      "7.5"
+      "8.6"
+    ];
+  };
+  actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
+in
+asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) ''
+  This test should only fail when using a version of CUDA older than 11.2, the first to support
+  8.6.
+  Expected: ${builtins.toJSON expected}
+  Actual: ${builtins.toJSON actualWrapped}
+'';
+# Check mixed Jetson and non-Jetson devices
+assert let
+  expected = false;
+  actual = formatCapabilities {
+    cudaCapabilities = [
+      "7.2"
+      "7.5"
+    ];
+  };
+  actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
+in
+asserts.assertMsg (expected == actualWrapped) ''
+  Jetson devices capabilities cannot be mixed with non-jetson devices.
+  Capability 7.5 is non-Jetson and should not be allowed with Jetson 7.2.
+  Expected: ${builtins.toJSON expected}
+  Actual: ${builtins.toJSON actualWrapped}
+'';
+# Check Jetson-only
+assert let
+  expected = {
+    cudaCapabilities = [
+      "6.2"
+      "7.2"
+    ];
+    enableForwardCompat = true;
+
+    archNames = [
+      "Pascal"
+      "Volta"
+    ];
+    realArches = [
+      "sm_62"
+      "sm_72"
+    ];
+    virtualArches = [
+      "compute_62"
+      "compute_72"
+    ];
+    arches = [
+      "sm_62"
+      "sm_72"
+      "compute_72"
+    ];
+
+    gencode = [
+      "-gencode=arch=compute_62,code=sm_62"
+      "-gencode=arch=compute_72,code=sm_72"
+      "-gencode=arch=compute_72,code=compute_72"
+    ];
+    gencodeString = "-gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_72,code=compute_72";
+
+    isJetsonBuild = true;
+  };
+  actual = formatCapabilities {
+    cudaCapabilities = [
+      "6.2"
+      "7.2"
+    ];
+  };
+  actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
+in
+asserts.assertMsg
+  # We can't do this test unless we're targeting aarch64
+  (hostPlatform.isAarch64 -> (expected == actualWrapped))
+  ''
+    Jetson devices can only be built with other Jetson devices.
+    Both 6.2 and 7.2 are Jetson devices.
+    Expected: ${builtins.toJSON expected}
+    Actual: ${builtins.toJSON actualWrapped}
+  '';
+{
+  # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } ->  { ... }
+  inherit formatCapabilities;
+
+  # cudaArchNameToVersions :: String => String
+  inherit cudaArchNameToVersions;
+
+  # cudaComputeCapabilityToName :: String => String
+  inherit cudaComputeCapabilityToName;
+
+  # dropDot :: String -> String
+  inherit dropDot;
+
+  inherit
+    defaultCapabilities
+    supportedCapabilities
+    jetsonComputeCapabilities
+    jetsonTargets
+    getNixSystem
+    getRedistArch
+    ;
+}
+// formatCapabilities {
+  cudaCapabilities = if cudaCapabilities == [] then defaultCapabilities else cudaCapabilities;
+  enableForwardCompat = cudaForwardCompat;
+}