about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix')
-rw-r--r--nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix40
1 files changed, 19 insertions, 21 deletions
diff --git a/nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix b/nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix
index 155e863bf21e..2eb391dda46b 100644
--- a/nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix
+++ b/nixpkgs/pkgs/development/libraries/science/math/nccl/default.nix
@@ -2,24 +2,25 @@
 , backendStdenv
 , fetchFromGitHub
 , which
-, cudaPackages ? { }
-, addOpenGLRunpath
+, autoAddOpenGLRunpathHook
+, cuda_cccl
+, cuda_cudart
+, cuda_nvcc
+, cudaFlags
+, cudaVersion
 }:
-
-with cudaPackages;
-
 let
   # Output looks like "-gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86"
   gencode = lib.concatStringsSep " " cudaFlags.gencode;
 in
-backendStdenv.mkDerivation rec {
-  name = "nccl-${version}-cuda-${cudaPackages.cudaMajorVersion}";
+backendStdenv.mkDerivation (finalAttrs: {
+  name = "nccl-${finalAttrs.version}-cuda-${cudaVersion}";
   version = "2.16.5-1";
 
   src = fetchFromGitHub {
     owner = "NVIDIA";
     repo = "nccl";
-    rev = "v${version}";
+    rev = "v${finalAttrs.version}";
     hash = "sha256-JyhhYKSVIqUKIbC1rCJozPT1IrIyRLGrTjdPjJqsYaU=";
   };
 
@@ -27,13 +28,18 @@ backendStdenv.mkDerivation rec {
 
   nativeBuildInputs = [
     which
-    addOpenGLRunpath
+    autoAddOpenGLRunpathHook
     cuda_nvcc
   ];
 
   buildInputs = [
     cuda_cudart
-  ] ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0.0") [
+  ]
+  # NOTE: CUDA versions in Nixpkgs only use a major and minor version. When we do comparisons
+  # against other version, like below, it's important that we use the same format. Otherwise,
+  # we'll get incorrect results.
+  # For example, lib.versionAtLeast "12.0" "12.0.0" == false.
+  ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [
     cuda_cccl
   ];
 
@@ -46,27 +52,19 @@ backendStdenv.mkDerivation rec {
 
   makeFlags = [
     "CUDA_HOME=${cuda_nvcc}"
-    "CUDA_LIB=${cuda_cudart}/lib64"
-    "CUDA_INC=${cuda_cudart}/include"
+    "CUDA_LIB=${lib.getLib cuda_cudart}/lib"
+    "CUDA_INC=${lib.getDev cuda_cudart}/include"
     "PREFIX=$(out)"
   ];
 
   postFixup = ''
     moveToOutput lib/libnccl_static.a $dev
-
-    # Set RUNPATH so that libnvidia-ml in /run/opengl-driver(-32)/lib can be found.
-    # See the explanation in addOpenGLRunpath.
-    addOpenGLRunpath $out/lib/lib*.so
   '';
 
   env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];
 
   enableParallelBuilding = true;
 
-  passthru = {
-    inherit cudaPackages;
-  };
-
   meta = with lib; {
     description = "Multi-GPU and multi-node collective communication primitives for NVIDIA GPUs";
     homepage = "https://developer.nvidia.com/nccl";
@@ -74,4 +72,4 @@ backendStdenv.mkDerivation rec {
     platforms = [ "x86_64-linux" ];
     maintainers = with maintainers; [ mdaiter orivej ];
   };
-}
+})