diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/xformers/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/xformers/default.nix | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/xformers/default.nix b/nixpkgs/pkgs/development/python-modules/xformers/default.nix index 164513da94c8..c909559ca59a 100644 --- a/nixpkgs/pkgs/development/python-modules/xformers/default.nix +++ b/nixpkgs/pkgs/development/python-modules/xformers/default.nix @@ -25,7 +25,8 @@ #, flash-attn }: let - version = "0.03"; + inherit (torch) cudaCapabilities cudaPackages cudaSupport; + version = "0.0.23.post1"; in buildPythonPackage { pname = "xformers"; @@ -38,17 +39,34 @@ buildPythonPackage { owner = "facebookresearch"; repo = "xformers"; rev = "refs/tags/v${version}"; - hash = "sha256-G8f7tny5B8SAQ6+2uOjhY7nD0uOT4sskIwtTdwivQXo="; + hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY="; fetchSubmodules = true; }; + patches = [ + ./0001-fix-allow-building-without-git.patch + ]; + preBuild = '' cat << EOF > ./xformers/version.py # noqa: C801 __version__ = "${version}" EOF + '' + lib.optionalString cudaSupport '' + export CUDA_HOME=${cudaPackages.cuda_nvcc} + export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}" ''; + buildInputs = lib.optionals cudaSupport (with cudaPackages; [ + # flash-attn build + cuda_cudart # cuda_runtime_api.h + libcusparse.dev # cusparse.h + cuda_cccl.dev # nv/target + libcublas.dev # cublas_v2.h + libcusolver.dev # cusolverDn.h + libcurand.dev # curand_kernel.h + ]); + nativeBuildInputs = [ which ]; |