about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/rlax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/python-modules/rlax/default.nix')
-rw-r--r--nixpkgs/pkgs/development/python-modules/rlax/default.nix44
1 files changed, 30 insertions, 14 deletions
diff --git a/nixpkgs/pkgs/development/python-modules/rlax/default.nix b/nixpkgs/pkgs/development/python-modules/rlax/default.nix
index c73433e146a1..ceb8e9758619 100644
--- a/nixpkgs/pkgs/development/python-modules/rlax/default.nix
+++ b/nixpkgs/pkgs/development/python-modules/rlax/default.nix
@@ -1,40 +1,56 @@
 { lib
-, fetchPypi
 , buildPythonPackage
+, fetchFromGitHub
+, fetchpatch
+, absl-py
 , chex
+, distrax
+, dm-env
+, jax
 , jaxlib
+, numpy
 , tensorflow-probability
-, optax
 , dm-haiku
-, bsuite
-, frozendict
+, optax
+, pytest-xdist
 , pytestCheckHook
-, dm-env
-, distrax }:
+}:
 
 buildPythonPackage rec {
   pname = "rlax";
   version = "0.1.6";
   format = "setuptools";
 
-  src = fetchPypi {
-    inherit pname version;
-    hash = "sha256-C3nFOv/zxvAoz6WZ0RAZffzEbxIx/XrGabO4QPxrik8=";
+  src = fetchFromGitHub {
+    owner = "google-deepmind";
+    repo = "rlax";
+    rev = "refs/tags/v${version}";
+    hash = "sha256-v2Lbzya+E9d7tlUVlQQa4fuPp2q3E309Qvyt70mcdb0=";
   };
 
-  buildInputs = [
+  patches = [
+    (fetchpatch {  # Follow chex API change (https://github.com/google-deepmind/chex/pull/52)
+      name = "replace-deprecated-chex-assertions";
+      url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch";
+      hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs=";
+    })
+  ];
+
+  propagatedBuildInputs = [
+    absl-py
     chex
-    jaxlib
     distrax
+    dm-env
+    jax
+    jaxlib
+    numpy
     tensorflow-probability
   ];
 
   nativeCheckInputs = [
-    bsuite
-    dm-env
     dm-haiku
-    frozendict
     optax
+    pytest-xdist
     pytestCheckHook
   ];