about summary refs log tree commit diff
path: root/pkgs/development
diff options
context:
space:
mode:
authorSomeone <sergei.kozlukov@aalto.fi>2024-01-25 11:33:59 +0000
committerGitHub <noreply@github.com>2024-01-25 11:33:59 +0000
commitb63791c6e5ca0aa4d9637450a7555a1a5273ce1f (patch)
treeba57d24e40866417b74193550ac5b3092d926dcb /pkgs/development
parent39b0be0d82e2fa40ec181615c0dbff4e77914e41 (diff)
parentc5ac684ad8eff90e45f1459844604114c121d096 (diff)
downloadnixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.gz
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.bz2
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.lz
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.xz
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.zst
nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.zip
Merge pull request #275195 from GaetanLepage/torchrl
python311Packages.torchrl: init at 0.2.1
Diffstat (limited to 'pkgs/development')
-rw-r--r--pkgs/development/python-modules/ale-py/default.nix23
-rw-r--r--pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch17
-rw-r--r--pkgs/development/python-modules/tensordict/default.nix63
-rw-r--r--pkgs/development/python-modules/torchrl/default.nix154
4 files changed, 252 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/ale-py/default.nix b/pkgs/development/python-modules/ale-py/default.nix
index 77978654e68f..9cc5f6105cf5 100644
--- a/pkgs/development/python-modules/ale-py/default.nix
+++ b/pkgs/development/python-modules/ale-py/default.nix
@@ -2,7 +2,7 @@
 , SDL2
 , cmake
 , fetchFromGitHub
-, git
+, fetchpatch
 , gym
 , importlib-metadata
 , importlib-resources
@@ -11,7 +11,6 @@
 , numpy
 , pybind11
 , pytestCheckHook
-, python
 , pythonOlder
 , setuptools
 , stdenv
@@ -23,10 +22,10 @@
 buildPythonPackage rec {
   pname = "ale-py";
   version = "0.8.1";
-  format = "pyproject";
+  pyproject = true;
 
   src = fetchFromGitHub {
-    owner = "mgbellemare";
+    owner = "Farama-Foundation";
     repo = "Arcade-Learning-Environment";
     rev = "refs/tags/v${version}";
     hash = "sha256-B2AxhlzvBy1lJ3JttJjImgTjMtEUyZBv+xHU2IC7BVE=";
@@ -35,6 +34,20 @@ buildPythonPackage rec {
   patches = [
     # don't download pybind11, use local pybind11
     ./cmake-pybind11.patch
+    ./patch-sha-check-in-setup.patch
+
+    # The following two patches add the required `include <cstdint>` for compilation to work with GCC 13.
+    # See https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/503
+    (fetchpatch {
+      name = "fix-gcc13-compilation-1";
+      url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/ebd64c03cdaa3d8df7da7c62ec3ae5795105e27a.patch";
+      hash = "sha256-NMz0hw8USOj88WryHRkMQNWznnP6+5aWovEYNuocQ2c=";
+    })
+    (fetchpatch {
+      name = "fix-gcc13-compilation-2";
+      url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/4c99c7034f17810f3ff6c27436bfc3b40d08da21.patch";
+      hash = "sha256-66/bDCyMr1RsKk63T9GnFZGloLlkdr/bf5WHtWbX6VY=";
+    })
   ];
 
   nativeBuildInputs = [
@@ -67,7 +80,7 @@ buildPythonPackage rec {
     substituteInPlace pyproject.toml \
       --replace 'dynamic = ["version"]' 'version = "${version}"'
     substituteInPlace setup.py \
-      --replace 'subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here)' 'b"${src.rev}"'
+      --replace '@sha@' '"${version}"'
   '';
 
   dontUseCmakeConfigure = true;
diff --git a/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch b/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch
new file mode 100644
index 000000000000..f387346ded37
--- /dev/null
+++ b/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch
@@ -0,0 +1,17 @@
+diff --git a/setup.py b/setup.py
+index ff1b1c5..ce40df0 100644
+--- a/setup.py
++++ b/setup.py
+@@ -141,11 +141,7 @@ def parse_version(version_file):
+ 
+         version = ci_version
+     else:
+-        sha = (
+-            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here)
+-            .decode("ascii")
+-            .strip()
+-        )
++        sha = @sha@
+         version += f"+{sha}"
+ 
+     return version
diff --git a/pkgs/development/python-modules/tensordict/default.nix b/pkgs/development/python-modules/tensordict/default.nix
new file mode 100644
index 000000000000..c6a563bf55e5
--- /dev/null
+++ b/pkgs/development/python-modules/tensordict/default.nix
@@ -0,0 +1,63 @@
+{ lib
+, buildPythonPackage
+, pythonOlder
+, fetchFromGitHub
+, setuptools
+, torch
+, wheel
+, which
+, cloudpickle
+, numpy
+, h5py
+, pytestCheckHook
+}:
+
+buildPythonPackage rec {
+  pname = "tensordict";
+  version = "0.2.1";
+  pyproject = true;
+
+  disabled = pythonOlder "3.8";
+
+  src = fetchFromGitHub {
+    owner = "pytorch";
+    repo = "tensordict";
+    rev = "refs/tags/v${version}";
+    hash = "sha256-+Osoz1632F/dEkG/o8RUqCIDok2Qc9Qdak+CCr9m26g=";
+  };
+
+  nativeBuildInputs = [
+    setuptools
+    torch
+    wheel
+    which
+  ];
+
+  propagatedBuildInputs = [
+    cloudpickle
+    numpy
+    torch
+  ];
+
+  pythonImportsCheck = [
+    "tensordict"
+  ];
+
+  # We have to delete the source because otherwise it is used instead of the installed package.
+  preCheck = ''
+    rm -rf tensordict
+  '';
+
+  nativeCheckInputs = [
+    h5py
+    pytestCheckHook
+  ];
+
+  meta = with lib; {
+    description = "A pytorch dedicated tensor container";
+    changelog = "https://github.com/pytorch/tensordict/releases/tag/v${version}";
+    homepage = "https://github.com/pytorch/tensordict";
+    license = licenses.mit;
+    maintainers = with maintainers; [ GaetanLepage ];
+  };
+}
diff --git a/pkgs/development/python-modules/torchrl/default.nix b/pkgs/development/python-modules/torchrl/default.nix
new file mode 100644
index 000000000000..bbf1fccd76ba
--- /dev/null
+++ b/pkgs/development/python-modules/torchrl/default.nix
@@ -0,0 +1,154 @@
+{ lib
+, buildPythonPackage
+, pythonOlder
+, fetchFromGitHub
+, fetchpatch
+, ninja
+, setuptools
+, wheel
+, which
+, cloudpickle
+, numpy
+, torch
+, ale-py
+, gym
+, pygame
+, gymnasium
+, mujoco
+, moviepy
+, git
+, hydra-core
+, tensorboard
+, tqdm
+, wandb
+, packaging
+, tensordict
+, imageio
+, pytest-rerunfailures
+, pytestCheckHook
+, pyyaml
+, scipy
+}:
+
+buildPythonPackage rec {
+  pname = "torchrl";
+  version = "0.2.1";
+  pyproject = true;
+
+  disabled = pythonOlder "3.8";
+
+  src = fetchFromGitHub {
+    owner = "pytorch";
+    repo = "rl";
+    rev = "refs/tags/v${version}";
+    hash = "sha256-Y3WbSMGXS6fb4RyXk2SAKHT6RencGTZXM3tc65AQx74=";
+  };
+
+  patches = [
+    (fetchpatch {  # https://github.com/pytorch/rl/pull/1828
+      name = "pyproject.toml-remove-unknown-properties";
+      url = "https://github.com/pytorch/rl/commit/c390cf602fc79cb37d5f7bda6e44b5e9546ecda0.patch";
+      hash = "sha256-cUBBvKJ8vIHprcGzMojkUxcOrrmNPIoIBfLwHXWkjOc=";
+    })
+  ];
+
+  nativeBuildInputs = [
+    ninja
+    setuptools
+    wheel
+    which
+  ];
+
+  propagatedBuildInputs = [
+    cloudpickle
+    numpy
+    packaging
+    tensordict
+    torch
+  ];
+
+  passthru.optional-dependencies = {
+    atari = [
+      ale-py
+      gym
+      pygame
+    ];
+    gym-continuous = [
+      gymnasium
+      mujoco
+    ];
+    rendering = [
+      moviepy
+    ];
+    utils = [
+      git
+      hydra-core
+      tensorboard
+      tqdm
+      wandb
+    ];
+  };
+
+  # torchrl needs to create a folder to store datasets
+  preBuild = ''
+    export D4RL_DATASET_DIR=$(mktemp -d)
+  '';
+
+  pythonImportsCheck = [
+    "torchrl"
+  ];
+
+  # We have to delete the source because otherwise it is used instead of the installed package.
+  preCheck = ''
+    rm -rf torchrl
+
+    export XDG_RUNTIME_DIR=$(mktemp -d)
+  ''
+  # Otherwise, tochrl will try to use unpackaged torchsnapshot.
+  # TODO: This should be the default from next release so remove when updating from 0.2.1
+  + ''
+    export CKPT_BACKEND="torch"
+  '';
+
+  nativeCheckInputs = [
+    gymnasium
+    imageio
+    pytest-rerunfailures
+    pytestCheckHook
+    pyyaml
+    scipy
+  ]
+  ++ passthru.optional-dependencies.atari
+  ++ passthru.optional-dependencies.gym-continuous
+  ++ passthru.optional-dependencies.rendering;
+
+  disabledTests = [
+    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
+    "test_vecenvs_env"
+
+    # ValueError: Can't write images with one color channel.
+    "test_log_video"
+
+    # Those tests require the ALE environments (provided by unpackaged shimmy)
+    "test_collector_env_reset"
+    "test_gym"
+    "test_gym_fake_td"
+    "test_recorder"
+    "test_recorder_load"
+    "test_rollout"
+    "test_parallel_trans_env_check"
+    "test_serial_trans_env_check"
+    "test_single_trans_env_check"
+    "test_td_creation_from_spec"
+    "test_trans_parallel_env_check"
+    "test_trans_serial_env_check"
+    "test_transform_env"
+  ];
+
+  meta = with lib; {
+    description = "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
+    homepage = "https://github.com/pytorch/rl";
+    license = licenses.mit;
+    maintainers = with maintainers; [ GaetanLepage ];
+  };
+}