about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/jax/default.nix
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2023-10-20 22:09:03 +0000
committerAlyssa Ross <hi@alyssa.is>2023-10-20 22:09:03 +0000
commit50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e (patch)
treef2556b911180125ccbb7ed0e78a54e92da89adce /nixpkgs/pkgs/development/python-modules/jax/default.nix
parent4c16d4548a98563c9d9ad76f4e5b2202864ccd54 (diff)
parentcfc75eec4603c06503ae750f88cf397e00796ea8 (diff)
downloadnixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.gz
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.bz2
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.lz
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.xz
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.tar.zst
nixlib-50c21d167f7114fa1dbd95e5c4fb30eeb1a2d02e.zip
Merge commit 'cfc75eec4603c06503ae750f88cf397e00796ea8'
Conflicts:
	nixpkgs/pkgs/build-support/rust/build-rust-package/default.nix
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/jax/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/jax/default.nix22
1 files changed, 18 insertions, 4 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/jax/default.nix b/nixpkgs/pkgs/development/python-modules/jax/default.nix
index a92148ada6bb..d9293e073480 100644
--- a/nixpkgs/pkgs/development/python-modules/jax/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/jax/default.nix
@@ -12,6 +12,7 @@
 , numpy
 , opt-einsum
 , pytestCheckHook
+, pytest-xdist
 , pythonOlder
 , scipy
 , stdenv
@@ -26,17 +27,17 @@ let
 in
 buildPythonPackage rec {
   pname = "jax";
-  version = "0.4.16";
-  format = "pyproject";
+  version = "0.4.19";
+  pyproject = true;
 
   disabled = pythonOlder "3.9";
 
   src = fetchFromGitHub {
     owner = "google";
-    repo = pname;
+    repo = "jax";
     # google/jax contains tags for jax and jaxlib. Only use jax tags!
     rev = "refs/tags/${pname}-v${version}";
-    hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg=";
+    hash = "sha256-l5uLPqhg/hqtO9oJSaioow5cH/0jKHDVziGezkfnVcc=";
   };
 
   nativeBuildInputs = [
@@ -61,13 +62,18 @@ buildPythonPackage rec {
     jaxlib'
     matplotlib
     pytestCheckHook
+    pytest-xdist
   ];
 
+  # high parallelism will result in the tests getting stuck
+  dontUsePytestXdist = true;
+
   # NOTE: Don't run the tests in the expiremental directory as they require flax
   # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
   # Not a big deal, this is how the JAX docs suggest running the test suite
   # anyhow.
   pytestFlagsArray = [
+    "--numprocesses=4"
     "-W ignore::DeprecationWarning"
     "tests/"
   ];
@@ -94,6 +100,14 @@ buildPythonPackage rec {
     "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
     "testQdwhWithRandomMatrix3"
     "testScanGrad_jit_scan"
+
+    # See https://github.com/google/jax/issues/17867.
+    "test_array"
+    "test_async"
+    "test_copy0"
+    "test_device_put"
+    "test_make_array_from_callback"
+    "test_make_array_from_single_device_arrays"
   ];
 
   disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [