about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/cuda-modules/nccl/default.nix
blob: 25296c21365d1caa984b9d481b02aff5e1ea4c10 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# NOTE: Although NCCL is built as part of the cudaPackages package set, we do not take its
# dependencies from that set as direct function arguments. Instead they are inherited from
# `cudaPackages` below, which avoids evaluation errors when redistributable packages are not available.
{
  lib,
  fetchFromGitHub,
  python3,
  which,
  cudaPackages,
  # passthru.updateScript
  gitUpdater,
}:
let
  inherit (cudaPackages)
    autoAddOpenGLRunpathHook
    backendStdenv
    cuda_cccl
    cuda_cudart
    cuda_nvcc
    cudaFlags
    cudatoolkit
    cudaVersion
    ;

  # CUDA releases older than 11.4 are consumed through the monolithic `cudatoolkit`
  # package; from 11.4 onwards the split packages (cuda_nvcc, cuda_cudart, ...) are
  # used instead. Every toolkit-dependent choice below keys off this single predicate,
  # so the two cases are guaranteed to be mutually exclusive and exhaustive.
  useCudatoolkit = lib.versionOlder cudaVersion "11.4";
in
backendStdenv.mkDerivation (
  finalAttrs: {
    pname = "nccl";
    version = "2.20.3-1";

    src = fetchFromGitHub {
      owner = "NVIDIA";
      repo = finalAttrs.pname;
      rev = "v${finalAttrs.version}";
      hash = "sha256-7gI1q6uN3saz/twwLjWl7XmMucYjvClDPDdbVpVM0vU=";
    };

    strictDeps = true;

    # The static archive is relocated to `dev` in postFixup below.
    outputs = [
      "out"
      "dev"
    ];

    nativeBuildInputs = [
      which
      autoAddOpenGLRunpathHook
      python3
    ] ++ (if useCudatoolkit then [ cudatoolkit ] else [ cuda_nvcc ]);

    buildInputs =
      (
        if useCudatoolkit then
          [ cudatoolkit ]
        else
          [
            cuda_nvcc.dev # crt/host_config.h
            cuda_cudart
          ]
      )
      # NOTE: CUDA versions in Nixpkgs only use a major and minor version. When we do comparisons
      # against other versions, like below, it's important that we use the same format. Otherwise,
      # we'll get incorrect results.
      # For example, lib.versionAtLeast "12.0" "12.0.0" == false.
      ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ];

    # Only silences -Wunused-function warnings from the C/C++ sources.
    env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];

    preConfigure = ''
      patchShebangs ./src/device/generate.py
      # NVCC_GENCODE is a space-separated list; entries in `makeFlags` are word-split
      # by the builder, so it has to be passed through makeFlagsArray to stay one argument.
      makeFlagsArray+=(
        "NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}"
      )
    '';

    makeFlags =
      [ "PREFIX=$(out)" ]
      ++ (
        if useCudatoolkit then
          [
            "CUDA_HOME=${cudatoolkit}"
            "CUDA_LIB=${lib.getLib cudatoolkit}/lib"
            "CUDA_INC=${lib.getDev cudatoolkit}/include"
          ]
        else
          [
            "CUDA_HOME=${cuda_nvcc}"
            "CUDA_LIB=${lib.getLib cuda_cudart}/lib"
            "CUDA_INC=${lib.getDev cuda_cudart}/include"
          ]
      );

    enableParallelBuilding = true;

    # Keep the static library out of the runtime closure.
    postFixup = ''
      moveToOutput lib/libnccl_static.a $dev
    '';

    # Upstream tags releases as "v<version>", matching `src.rev` above.
    passthru.updateScript = gitUpdater {
      inherit (finalAttrs) pname version;
      rev-prefix = "v";
    };

    meta = {
      description = "Multi-GPU and multi-node collective communication primitives for NVIDIA GPUs";
      homepage = "https://developer.nvidia.com/nccl";
      license = lib.licenses.bsd3;
      platforms = lib.platforms.linux;
      # NCCL is not supported on Jetson, because it does not use NVLink or PCI-e for inter-GPU communication.
      # https://forums.developer.nvidia.com/t/can-jetson-orin-support-nccl/232845/9
      badPlatforms = lib.optionals cudaFlags.isJetsonBuild [ "aarch64-linux" ];
      maintainers = [
        lib.maintainers.mdaiter
        lib.maintainers.orivej
      ] ++ lib.teams.cuda.members;
    };
  }
)