diff options
author | Alyssa Ross <hi@alyssa.is> | 2023-10-20 22:09:03 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2023-10-20 22:09:03 +0000 |
commit | 50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e (patch) | |
tree | f2556b911180125ccbb7ed0e78a54e92da89adce /nixpkgs/pkgs/development/python-modules/jax/default.nix | |
parent | 4c16d4548a98563c9d9ad76f4e5b2202864ccd54 (diff) | |
parent | cfc75eec4603c06503ae750f88cf397e00796ea8 (diff) | |
download | nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.gz nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.bz2 nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.lz nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.xz nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.zst nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.zip |
Merge commit 'cfc75eec4603c06503ae750f88cf397e00796ea8'
Conflicts: nixpkgs/pkgs/build-support/rust/build-rust-package/default.nix
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/jax/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/jax/default.nix | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/jax/default.nix b/nixpkgs/pkgs/development/python-modules/jax/default.nix index a92148ada6bb..d9293e073480 100644 --- a/nixpkgs/pkgs/development/python-modules/jax/default.nix +++ b/nixpkgs/pkgs/development/python-modules/jax/default.nix @@ -12,6 +12,7 @@ , numpy , opt-einsum , pytestCheckHook +, pytest-xdist , pythonOlder , scipy , stdenv @@ -26,17 +27,17 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.4.16"; - format = "pyproject"; + version = "0.4.19"; + pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; - repo = pname; + repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg="; + hash = "sha256-l5uLPqhg/hqtO9oJSaioow5cH/0jKHDVziGezkfnVcc="; }; nativeBuildInputs = [ @@ -61,13 +62,18 @@ buildPythonPackage rec { jaxlib' matplotlib pytestCheckHook + pytest-xdist ]; + # high parallelism will result in the tests getting stuck + dontUsePytestXdist = true; + # NOTE: Don't run the tests in the expiremental directory as they require flax # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ + "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ]; @@ -94,6 +100,14 @@ buildPythonPackage rec { "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" "testQdwhWithRandomMatrix3" "testScanGrad_jit_scan" + + # See https://github.com/google/jax/issues/17867. + "test_array" + "test_async" + "test_copy0" + "test_device_put" + "test_make_array_from_callback" + "test_make_array_from_single_device_arrays" ]; disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ |