about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/xformers/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/xformers/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/xformers/default.nix22
1 files changed, 20 insertions, 2 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/xformers/default.nix b/nixpkgs/pkgs/development/python-modules/xformers/default.nix
index 164513da94c8..c909559ca59a 100644
--- a/nixpkgs/pkgs/development/python-modules/xformers/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/xformers/default.nix
@@ -25,7 +25,8 @@
 #, flash-attn
 }:
 let
-  version = "0.03";
+  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
+  version = "0.0.23.post1";
 in
 buildPythonPackage {
   pname = "xformers";
@@ -38,17 +39,34 @@ buildPythonPackage {
     owner = "facebookresearch";
     repo = "xformers";
     rev = "refs/tags/v${version}";
-    hash = "sha256-G8f7tny5B8SAQ6+2uOjhY7nD0uOT4sskIwtTdwivQXo=";
+    hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY=";
     fetchSubmodules = true;
   };
 
+  patches = [
+    ./0001-fix-allow-building-without-git.patch
+  ];
+
   preBuild = ''
     cat << EOF > ./xformers/version.py
     # noqa: C801
     __version__ = "${version}"
     EOF
+  '' + lib.optionalString cudaSupport ''
+    export CUDA_HOME=${cudaPackages.cuda_nvcc}
+    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
   '';
 
+  buildInputs = lib.optionals cudaSupport (with cudaPackages; [
+    # flash-attn build
+    cuda_cudart # cuda_runtime_api.h
+    libcusparse.dev # cusparse.h
+    cuda_cccl.dev # nv/target
+    libcublas.dev # cublas_v2.h
+    libcusolver.dev # cusolverDn.h
+    libcurand.dev # curand_kernel.h
+  ]);
+
   nativeBuildInputs = [
     which
   ];