diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix | 101 |
1 files changed, 73 insertions, 28 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix b/nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix index 08c1716867a7..cb97e2f837af 100644 --- a/nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix +++ b/nixpkgs/pkgs/development/python-modules/dm-haiku/default.nix @@ -1,32 +1,39 @@ -{ buildPythonPackage +{ lib +, buildPythonPackage , fetchFromGitHub , fetchpatch -, callPackage -, lib +, absl-py +, flax +, jaxlib , jmp +, numpy , tabulate -, jaxlib +, pytest-xdist +, pytestCheckHook +, bsuite +, chex +, cloudpickle +, dill +, dm-env +, dm-tree +, optax +, rlax +, tensorflow }: -buildPythonPackage rec { +let dm-haiku = buildPythonPackage rec { pname = "dm-haiku"; - version = "0.0.10"; + version = "0.0.11"; format = "setuptools"; src = fetchFromGitHub { owner = "deepmind"; - repo = pname; + repo = "dm-haiku"; rev = "refs/tags/v${version}"; - hash = "sha256-EZx3o6PgTeFjTwI9Ko9H39EqPSE0yLWWpsdqX6ALlo4="; + hash = "sha256-xve1vNsVOC6/HVtzmzswM/Sk3uUNaTtqNAKheFb/tmI="; }; patches = [ - # https://github.com/deepmind/dm-haiku/issues/717 - (fetchpatch { - name = "remove-typing-extensions.patch"; - url = "https://github.com/deepmind/dm-haiku/commit/c22867db1a3314a382bd2ce36511e2b756dc32a8.patch"; - hash = "sha256-SxJc8FrImwMqTJ5OuJ1f4T+HfHgW/sGqXeIqlxEatlE="; - }) # https://github.com/deepmind/dm-haiku/pull/672 (fetchpatch { name = "fix-find-namespace-packages.patch"; @@ -35,14 +42,12 @@ buildPythonPackage rec { }) ]; - outputs = [ - "out" - "testsout" - ]; - propagatedBuildInputs = [ + absl-py + flax jaxlib jmp + numpy tabulate ]; @@ -50,17 +55,56 @@ buildPythonPackage rec { "haiku" ]; - postInstall = '' - mkdir $testsout - cp -R examples $testsout/examples - ''; + nativeCheckInputs = [ + bsuite + chex + cloudpickle + dill + dm-env + dm-haiku + dm-tree + jaxlib + optax + pytest-xdist + pytestCheckHook + rlax + tensorflow + ]; + + disabledTests = [ + # See https://github.com/deepmind/dm-haiku/issues/366. + "test_jit_Recurrent" + + # Assertion errors + "testShapeChecking0" + "testShapeChecking1" + + # This test requires a more recent version of tensorflow. The current one (2.13) is not enough. + "test_reshape_convert" + + # This test requires JAX support for double precision (64bit), but enabling this causes several + # other tests to fail. + # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + "test_doctest_haiku.experimental" + ]; + + disabledTestPaths = [ + # Those tests requires a more recent version of tensorflow. The current one (2.13) is not enough. + "haiku/_src/integration/jax2tf_test.py" + ]; - # check in passthru.tests.pytest to escape infinite recursion with bsuite doCheck = false; - passthru.tests = { - pytest = callPackage ./tests.nix { }; - }; + # check in passthru.tests.pytest to escape infinite recursion with bsuite + passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: { + pname = "${pname}-tests"; + doCheck = true; + + # We don't have to install because the only purpose + # of this passthru test is to, well, test. + # This fixes having to set `catchConflicts` to false. + dontInstall = true; + }); meta = with lib; { description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet."; @@ -68,4 +112,5 @@ buildPythonPackage rec { license = licenses.asl20; maintainers = with maintainers; [ ndl ]; }; -} +}; +in dm-haiku |