about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/cuda-modules/generic-builders/multiplex.nix
blob: b8053094bcc82579e313b044d7f0ef8108eaf321 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{
  # callPackage-provided arguments
  lib,
  cudaVersion,
  flags,
  hostPlatform,
  # Expected to be passed by the caller
  mkVersionedPackageName,
  # pname :: String
  # The package name. Selects `config.${pname}.releases` from the evaluated
  # modules and is the attribute name of the default (newest) derivation.
  pname,
  # releasesModule :: Path
  # A module evaluated together with ../modules; it must define
  # `${pname}.releases` (an attribute set of release lists keyed by redist arch).
  releasesModule,
  # shimsFn :: Path | ({package, redistArch} -> AttrSet)
  # A function, or a path to one, that is evaluated with callPackage and given
  # `package` (one release entry) and `redistArch`. The returned attribute set
  # must provide:
  # - redistribRelease: only used in ./manifest.nix for the package version and
  #   the package description (which NVIDIA's manifest calls the "name"). It's
  #   also used for fetching the source, but we override that since we can't
  #   re-use that portion of the functionality (different URLs, etc.).
  # - featureRelease: used to populate meta.platforms (by way of looking at the
  #   attribute names) and to determine the outputs of the package.
  shimsFn ? ({package, redistArch}: throw "shimsFn must be provided"),
  # fixupFn :: Path | Function
  # A path (or Nix function) to be evaluated with callPackage; the result is
  # passed to the package's overrideAttrs function.
  # It must accept at least the following arguments:
  # - final
  # - cudaVersion
  # - mkVersionedPackageName
  # - package
  fixupFn ? (
    {
      final,
      cudaVersion,
      mkVersionedPackageName,
      package,
      ...
    }:
    throw "fixupFn must be provided"
  ),
}:
let
  inherit (lib)
    attrsets
    lists
    modules
    strings
    ;

  # Evaluate the shared CUDA module definitions (../modules) together with the
  # caller-supplied releases module; `config.${pname}.releases` is read from it.
  evaluatedModules =
    let
      moduleList = [
        ../modules
        releasesModule
      ];
    in
    modules.evalModules { modules = moduleList; };

  # NOTE: Important types:
  # - Releases: ../modules/${pname}/releases/releases.nix
  # - Package: ../modules/${pname}/releases/package.nix

  # All releases across all platforms, keyed by redist architecture
  # (indexed with `redistArch` below).
  # See ../modules/${pname}/releases/releases.nix
  allReleases = evaluatedModules.config.${pname}.releases;

  # Compute the versioned attribute name used in this package set.
  # Patch version changes should not break the build, so only major and minor
  # are meant to appear in the name (presumably truncated by mkVersionedPackageName).
  # computeName :: Package -> String
  computeName = {version, ...}: mkVersionedPackageName pname version;

  # Check whether a package supports our CUDA version.
  # Both bounds are inclusive: minCudaVersion <= cudaVersion <= maxCudaVersion.
  # isSupported :: Package -> Bool
  isSupported =
    package:
    strings.versionAtLeast cudaVersion package.minCudaVersion
    && strings.versionAtLeast package.maxCudaVersion cudaVersion;

  # Redistributable architecture for the host platform, as returned by
  # flags.getRedistArch; used as the key into allReleases.
  # redistArch :: String
  redistArch = flags.getRedistArch hostPlatform.system;

  # All the supported packages we can build for our platform; empty when there
  # are no releases at all for `redistArch`.
  # supportedPackages :: List Package
  supportedPackages = builtins.filter isSupported (allReleases.${redistArch} or []);

  # Supported packages ordered by version, newest first.
  # Sort explicitly instead of trusting the releases data to be listed oldest to
  # newest: the head of this list names the default `${pname}` attribute, and
  # listToAttrs keeps the first entry when names collide, so order decides which
  # release wins. Reversing first and relying on builtins.sort being stable keeps
  # equal versions in the same relative order that a plain reverseList gave.
  # newestToOldestSupportedPackage :: List Package
  newestToOldestSupportedPackage =
    lists.sort (newer: older: strings.versionOlder older.version newer.version)
      (lists.reverseList supportedPackages);

  # Versioned attribute name of the newest supported package.
  # Uses `head`, so it only has a value when at least one package is supported;
  # consumers must check for that before forcing it (Nix evaluates it lazily).
  # nameOfNewest :: String
  nameOfNewest = computeName (builtins.head newestToOldestSupportedPackage);

  # Given the extension's `final` package set and one release `package`,
  # evaluate fixupFn with callPackage. The result is what gets handed to the
  # built derivation's `overrideAttrs`.
  overrideAttrsFixupFn =
    final: package:
    let
      fixupArgs = {
        inherit final package;
        inherit cudaVersion mkVersionedPackageName;
      };
    in
    final.callPackage fixupFn fixupArgs;

  # Extension function of the form `final: _: attrs`. It yields one attribute
  # per supported release (named by computeName), plus `${pname}` pointing at
  # the newest one whenever at least one release is supported.
  extension =
    final: _prev:
    let
      # Build one release `package` into a derivation and return it as a
      # `{ name; value; }` pair keyed by its versioned attribute name.
      mkVersionedEntry =
        package:
        let
          shims = final.callPackage shimsFn { inherit package redistArch; };
          baseDrv = final.callPackage ./manifest.nix {
            inherit pname;
            inherit (shims) featureRelease redistribRelease;
            redistName = pname;
          };
        in
        {
          name = computeName package;
          value = baseDrv.overrideAttrs (overrideAttrsFixupFn final package);
        };

      # versionedDerivations :: AttrSet Derivation
      versionedDerivations = builtins.listToAttrs (builtins.map mkVersionedEntry newestToOldestSupportedPackage);
    in
    versionedDerivations
    // (
      if versionedDerivations == {} then
        {}
      else
        { ${pname} = versionedDerivations.${nameOfNewest}; }
    );
in
extension