diff options
Diffstat (limited to 'pkgs/development/python-modules/diffusers/default.nix')
-rw-r--r-- | pkgs/development/python-modules/diffusers/default.nix | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/diffusers/default.nix b/pkgs/development/python-modules/diffusers/default.nix new file mode 100644 index 000000000000..3485f9e3351d --- /dev/null +++ b/pkgs/development/python-modules/diffusers/default.nix @@ -0,0 +1,153 @@ +{ lib +, stdenv +, buildPythonPackage +, fetchFromGitHub +, pythonOlder +, writeText +, setuptools +, wheel +, filelock +, huggingface-hub +, importlib-metadata +, numpy +, pillow +, regex +, requests +, safetensors +# optional dependencies +, accelerate +, datasets +, flax +, jax +, jaxlib +, jinja2 +, protobuf +, tensorboard +, torch +# test dependencies +, parameterized +, pytest-timeout +, pytest-xdist +, pytestCheckHook +, requests-mock +, ruff +, scipy +, sentencepiece +, torchsde +, transformers +}: + +buildPythonPackage rec { + pname = "diffusers"; + version = "0.24.0"; + pyproject = true; + + disabled = pythonOlder "3.8"; + + src = fetchFromGitHub { + owner = "huggingface"; + repo = "diffusers"; + rev = "refs/tags/v${version}"; + hash = "sha256-ccWF8hQzPhFY/kqRum2tbanI+cQiT25MmvPZN+hGadc="; + }; + + nativeBuildInputs = [ + setuptools + wheel + ]; + + propagatedBuildInputs = [ + filelock + huggingface-hub + importlib-metadata + numpy + pillow + regex + requests + safetensors + ]; + + passthru.optional-dependencies = { + flax = [ + flax + jax + jaxlib + ]; + torch = [ + accelerate + torch + ]; + training = [ + accelerate + datasets + jinja2 + protobuf + tensorboard + ]; + }; + + pythonImportsCheck = [ + "diffusers" + ]; + + # tests crash due to torch segmentation fault + doCheck = !(stdenv.isLinux && stdenv.isAarch64); + + nativeCheckInputs = [ + parameterized + pytest-timeout + pytest-xdist + pytestCheckHook + requests-mock + ruff + scipy + sentencepiece + torchsde + transformers + ] ++ passthru.optional-dependencies.torch; + + preCheck = let + # This pytest hook mocks and catches attempts at accessing the network + # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed. + # cf. python3Packages.shap + conftestSkipNetworkErrors = writeText "conftest.py" '' + from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport + import urllib3 + + class NetworkAccessDeniedError(RuntimeError): pass + def deny_network_access(*a, **kw): + raise NetworkAccessDeniedError + + urllib3.connection.HTTPSConnection._new_conn = deny_network_access + + def pytest_runtest_makereport(item, call): + tr = orig_pytest_runtest_makereport(item, call) + if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError: + tr.outcome = 'skipped' + tr.wasxfail = "reason: Requires network access." + return tr + ''; + in '' + export HOME=$TMPDIR + cat ${conftestSkipNetworkErrors} >> tests/conftest.py + ''; + + pytestFlagsArray = [ + "tests/" + ]; + + disabledTests = [ + # depends on current working directory + "test_deprecate_stacklevel" + # fails due to precision of floating point numbers + "test_model_cpu_offload_forward_pass" + ]; + + meta = with lib; { + description = "State-of-the-art diffusion models for image and audio generation in PyTorch"; + homepage = "https://github.com/huggingface/diffusers"; + changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.rev}"; + license = licenses.asl20; + maintainers = with maintainers; [ natsukium ]; + }; +} |