diff options
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/rlax/default.nix')
-rw-r--r-- | nixpkgs/pkgs/development/python-modules/rlax/default.nix | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/rlax/default.nix b/nixpkgs/pkgs/development/python-modules/rlax/default.nix index c73433e146a1..ceb8e9758619 100644 --- a/nixpkgs/pkgs/development/python-modules/rlax/default.nix +++ b/nixpkgs/pkgs/development/python-modules/rlax/default.nix @@ -1,40 +1,56 @@ { lib -, fetchPypi , buildPythonPackage +, fetchFromGitHub +, fetchpatch +, absl-py , chex +, distrax +, dm-env +, jax , jaxlib +, numpy , tensorflow-probability -, optax , dm-haiku -, bsuite -, frozendict +, optax +, pytest-xdist , pytestCheckHook -, dm-env -, distrax }: +}: buildPythonPackage rec { pname = "rlax"; version = "0.1.6"; format = "setuptools"; - src = fetchPypi { - inherit pname version; - hash = "sha256-C3nFOv/zxvAoz6WZ0RAZffzEbxIx/XrGabO4QPxrik8="; + src = fetchFromGitHub { + owner = "google-deepmind"; + repo = "rlax"; + rev = "refs/tags/v${version}"; + hash = "sha256-v2Lbzya+E9d7tlUVlQQa4fuPp2q3E309Qvyt70mcdb0="; }; - buildInputs = [ + patches = [ + (fetchpatch { # Follow chex API change (https://github.com/google-deepmind/chex/pull/52) + name = "replace-deprecated-chex-assertions"; + url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch"; + hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs="; + }) + ]; + + propagatedBuildInputs = [ + absl-py chex - jaxlib distrax + dm-env + jax + jaxlib + numpy tensorflow-probability ]; nativeCheckInputs = [ - bsuite - dm-env dm-haiku - frozendict optax + pytest-xdist pytestCheckHook ]; |