diff options
author | Someone <sergei.kozlukov@aalto.fi> | 2024-01-25 11:33:59 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-25 11:33:59 +0000 |
commit | b63791c6e5ca0aa4d9637450a7555a1a5273ce1f (patch) | |
tree | ba57d24e40866417b74193550ac5b3092d926dcb /pkgs/development | |
parent | 39b0be0d82e2fa40ec181615c0dbff4e77914e41 (diff) | |
parent | c5ac684ad8eff90e45f1459844604114c121d096 (diff) | |
download | nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.gz nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.bz2 nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.lz nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.xz nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.tar.zst nixlib-b63791c6e5ca0aa4d9637450a7555a1a5273ce1f.zip |
Merge pull request #275195 from GaetanLepage/torchrl
python311Packages.torchrl: init at 0.2.1
Diffstat (limited to 'pkgs/development')
4 files changed, 252 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/ale-py/default.nix b/pkgs/development/python-modules/ale-py/default.nix index 77978654e68f..9cc5f6105cf5 100644 --- a/pkgs/development/python-modules/ale-py/default.nix +++ b/pkgs/development/python-modules/ale-py/default.nix @@ -2,7 +2,7 @@ , SDL2 , cmake , fetchFromGitHub -, git +, fetchpatch , gym , importlib-metadata , importlib-resources @@ -11,7 +11,6 @@ , numpy , pybind11 , pytestCheckHook -, python , pythonOlder , setuptools , stdenv @@ -23,10 +22,10 @@ buildPythonPackage rec { pname = "ale-py"; version = "0.8.1"; - format = "pyproject"; + pyproject = true; src = fetchFromGitHub { - owner = "mgbellemare"; + owner = "Farama-Foundation"; repo = "Arcade-Learning-Environment"; rev = "refs/tags/v${version}"; hash = "sha256-B2AxhlzvBy1lJ3JttJjImgTjMtEUyZBv+xHU2IC7BVE="; @@ -35,6 +34,20 @@ buildPythonPackage rec { patches = [ # don't download pybind11, use local pybind11 ./cmake-pybind11.patch + ./patch-sha-check-in-setup.patch + + # The following two patches add the required `include <cstdint>` for compilation to work with GCC 13. + # See https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/503 + (fetchpatch { + name = "fix-gcc13-compilation-1"; + url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/ebd64c03cdaa3d8df7da7c62ec3ae5795105e27a.patch"; + hash = "sha256-NMz0hw8USOj88WryHRkMQNWznnP6+5aWovEYNuocQ2c="; + }) + (fetchpatch { + name = "fix-gcc13-compilation-2"; + url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/4c99c7034f17810f3ff6c27436bfc3b40d08da21.patch"; + hash = "sha256-66/bDCyMr1RsKk63T9GnFZGloLlkdr/bf5WHtWbX6VY="; + }) ]; nativeBuildInputs = [ @@ -67,7 +80,7 @@ buildPythonPackage rec { substituteInPlace pyproject.toml \ --replace 'dynamic = ["version"]' 'version = "${version}"' substituteInPlace setup.py \ - --replace 'subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here)' 'b"${src.rev}"' + --replace '@sha@' '"${version}"' ''; dontUseCmakeConfigure = true; diff --git a/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch b/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch new file mode 100644 index 000000000000..f387346ded37 --- /dev/null +++ b/pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch @@ -0,0 +1,17 @@ +diff --git a/setup.py b/setup.py +index ff1b1c5..ce40df0 100644 +--- a/setup.py ++++ b/setup.py +@@ -141,11 +141,7 @@ def parse_version(version_file): + + version = ci_version + else: +- sha = ( +- subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here) +- .decode("ascii") +- .strip() +- ) ++ sha = @sha@ + version += f"+{sha}" + + return version diff --git a/pkgs/development/python-modules/tensordict/default.nix b/pkgs/development/python-modules/tensordict/default.nix new file mode 100644 index 000000000000..c6a563bf55e5 --- /dev/null +++ b/pkgs/development/python-modules/tensordict/default.nix @@ -0,0 +1,63 @@ +{ lib +, buildPythonPackage +, pythonOlder +, fetchFromGitHub +, setuptools +, torch +, wheel +, which +, cloudpickle +, numpy +, h5py +, pytestCheckHook +}: + +buildPythonPackage rec { + pname = "tensordict"; + version = "0.2.1"; + pyproject = true; + + disabled = pythonOlder "3.8"; + + src = fetchFromGitHub { + owner = "pytorch"; + repo = "tensordict"; + rev = "refs/tags/v${version}"; + hash = "sha256-+Osoz1632F/dEkG/o8RUqCIDok2Qc9Qdak+CCr9m26g="; + }; + + nativeBuildInputs = [ + setuptools + torch + wheel + which + ]; + + propagatedBuildInputs = [ + cloudpickle + numpy + torch + ]; + + pythonImportsCheck = [ + "tensordict" + ]; + + # We have to delete the source because otherwise it is used instead of the installed package. + preCheck = '' + rm -rf tensordict + ''; + + nativeCheckInputs = [ + h5py + pytestCheckHook + ]; + + meta = with lib; { + description = "A pytorch dedicated tensor container"; + changelog = "https://github.com/pytorch/tensordict/releases/tag/v${version}"; + homepage = "https://github.com/pytorch/tensordict"; + license = licenses.mit; + maintainers = with maintainers; [ GaetanLepage ]; + }; +} diff --git a/pkgs/development/python-modules/torchrl/default.nix b/pkgs/development/python-modules/torchrl/default.nix new file mode 100644 index 000000000000..bbf1fccd76ba --- /dev/null +++ b/pkgs/development/python-modules/torchrl/default.nix @@ -0,0 +1,154 @@ +{ lib +, buildPythonPackage +, pythonOlder +, fetchFromGitHub +, fetchpatch +, ninja +, setuptools +, wheel +, which +, cloudpickle +, numpy +, torch +, ale-py +, gym +, pygame +, gymnasium +, mujoco +, moviepy +, git +, hydra-core +, tensorboard +, tqdm +, wandb +, packaging +, tensordict +, imageio +, pytest-rerunfailures +, pytestCheckHook +, pyyaml +, scipy +}: + +buildPythonPackage rec { + pname = "torchrl"; + version = "0.2.1"; + pyproject = true; + + disabled = pythonOlder "3.8"; + + src = fetchFromGitHub { + owner = "pytorch"; + repo = "rl"; + rev = "refs/tags/v${version}"; + hash = "sha256-Y3WbSMGXS6fb4RyXk2SAKHT6RencGTZXM3tc65AQx74="; + }; + + patches = [ + (fetchpatch { # https://github.com/pytorch/rl/pull/1828 + name = "pyproject.toml-remove-unknown-properties"; + url = "https://github.com/pytorch/rl/commit/c390cf602fc79cb37d5f7bda6e44b5e9546ecda0.patch"; + hash = "sha256-cUBBvKJ8vIHprcGzMojkUxcOrrmNPIoIBfLwHXWkjOc="; + }) + ]; + + nativeBuildInputs = [ + ninja + setuptools + wheel + which + ]; + + propagatedBuildInputs = [ + cloudpickle + numpy + packaging + tensordict + torch + ]; + + passthru.optional-dependencies = { + atari = [ + ale-py + gym + pygame + ]; + gym-continuous = [ + gymnasium + mujoco + ]; + rendering = [ + moviepy + ]; + utils = [ + git + hydra-core + tensorboard + tqdm + wandb + ]; + }; + + # torchrl needs to create a folder to store datasets + preBuild = '' + export D4RL_DATASET_DIR=$(mktemp -d) + ''; + + pythonImportsCheck = [ + "torchrl" + ]; + + # We have to delete the source because otherwise it is used instead of the installed package. + preCheck = '' + rm -rf torchrl + + export XDG_RUNTIME_DIR=$(mktemp -d) + '' + # Otherwise, tochrl will try to use unpackaged torchsnapshot. + # TODO: This should be the default from next release so remove when updating from 0.2.1 + + '' + export CKPT_BACKEND="torch" + ''; + + nativeCheckInputs = [ + gymnasium + imageio + pytest-rerunfailures + pytestCheckHook + pyyaml + scipy + ] + ++ passthru.optional-dependencies.atari + ++ passthru.optional-dependencies.gym-continuous + ++ passthru.optional-dependencies.rendering; + + disabledTests = [ + # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called + "test_vecenvs_env" + + # ValueError: Can't write images with one color channel. + "test_log_video" + + # Those tests require the ALE environments (provided by unpackaged shimmy) + "test_collector_env_reset" + "test_gym" + "test_gym_fake_td" + "test_recorder" + "test_recorder_load" + "test_rollout" + "test_parallel_trans_env_check" + "test_serial_trans_env_check" + "test_single_trans_env_check" + "test_td_creation_from_spec" + "test_trans_parallel_env_check" + "test_trans_serial_env_check" + "test_transform_env" + ]; + + meta = with lib; { + description = "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning"; + homepage = "https://github.com/pytorch/rl"; + license = licenses.mit; + maintainers = with maintainers; [ GaetanLepage ]; + }; +} |