about summary refs log tree commit diff
path: root/nixpkgs/pkgs/servers/apache-airflow/update-providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/servers/apache-airflow/update-providers.py')
-rwxr-xr-xnixpkgs/pkgs/servers/apache-airflow/update-providers.py228
1 files changed, 228 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/servers/apache-airflow/update-providers.py b/nixpkgs/pkgs/servers/apache-airflow/update-providers.py
new file mode 100755
index 000000000000..207c38119258
--- /dev/null
+++ b/nixpkgs/pkgs/servers/apache-airflow/update-providers.py
@@ -0,0 +1,228 @@
+#! /usr/bin/env python3
+
+from itertools import chain
+import json
+import logging
+from pathlib import Path
+import os
+import re
+import subprocess
+import sys
+from typing import Dict, List, Optional, Set, TextIO
+from urllib.request import urlopen
+from urllib.error import HTTPError
+import yaml
+
+PKG_SET = "apache-airflow.pythonPackages"
+
+# If some requirements are matched by multiple or no Python packages, the
+# following can be used to choose the correct one
+PKG_PREFERENCES = {
+    "dnspython": "dnspython",
+    "elasticsearch-dsl": "elasticsearch-dsl",
+    "google-api-python-client": "google-api-python-client",
+    "protobuf": "protobuf",
+    "psycopg2-binary": "psycopg2",
+    "requests_toolbelt": "requests-toolbelt",
+}
+
+# Requirements missing from the airflow provider metadata
+EXTRA_REQS = {
+    "sftp": ["pysftp"],
+}
+
+
+def get_version():
+    with open(os.path.dirname(sys.argv[0]) + "/default.nix") as fh:
+        # A version consists of digits, dots, and possibly a "b" (for beta)
+        m = re.search('version = "([\\d\\.b]+)";', fh.read())
+        return m.group(1)
+
+
+def get_file_from_github(version: str, path: str):
+    with urlopen(
+        f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
+    ) as response:
+        return yaml.safe_load(response)
+
+
+def repository_root() -> Path:
+    return Path(os.path.dirname(sys.argv[0])) / "../../.."
+
+
+def dump_packages() -> Dict[str, Dict[str, str]]:
+    # Store a JSON dump of Nixpkgs' python3Packages
+    output = subprocess.check_output(
+        [
+            "nix-env",
+            "-f",
+            repository_root(),
+            "-qa",
+            "-A",
+            PKG_SET,
+            "--arg",
+            "config",
+            "{ allowAliases = false; }",
+            "--json",
+        ]
+    )
+    return json.loads(output)
+
+
+def remove_version_constraint(req: str) -> str:
+    return re.sub(r"[=><~].*$", "", req)
+
+
+def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
+    if req in PKG_PREFERENCES:
+        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"
+    attr_paths = []
+    names = [req]
+    # E.g. python-mpd2 is actually called python3.6-mpd2
+    # instead of python-3.6-python-mpd2 inside Nixpkgs
+    if req.startswith("python-") or req.startswith("python_"):
+        names.append(req[len("python-") :])
+    for name in names:
+        # treat "-" and "_" equally
+        name = re.sub("[-_]", "[-_]", name)
+        # python(minor).(major)-(pname)-(version or unstable-date)
+        # we need the version qualifier, or we'll have multiple matches
+        # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
+        pattern = re.compile(
+            f"^python\\d+\\.\\d+-{name}-(?:\\d|unstable-.*)", re.I
+        )
+        for attr_path, package in packages.items():
+            # logging.debug("Checking match for %s with %s", name, package["name"])
+            if pattern.match(package["name"]):
+                attr_paths.append(attr_path)
+    # Let's hope there's only one derivation with a matching name
+    assert len(attr_paths) <= 1, f"{req} matches more than one derivation: {attr_paths}"
+    if attr_paths:
+        return attr_paths[0]
+    return None
+
+
+def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
+    no_version_reqs = map(remove_version_constraint, reqs)
+    filtered_reqs = [
+        req for req in no_version_reqs if not re.match(r"^apache-airflow", req)
+    ]
+    attr_paths = []
+    for req in filtered_reqs:
+        attr_path = name_to_attr_path(req, packages)
+        if attr_path is not None:
+            # Add attribute path without "python3Packages." prefix
+            pname = attr_path[len(PKG_SET + ".") :]
+            attr_paths.append(pname)
+        else:
+            # If we can't find it, we just skip and warn the user
+            logging.warning("Could not find package attr for %s", req)
+    return attr_paths
+
+
+def get_cross_provider_reqs(
+    provider: str, provider_reqs: Dict, cross_provider_deps: Dict, seen: List = None
+) -> Set:
+    # Unfortunately there are circular cross-provider dependencies, so keep a
+    # list of ones we've seen already
+    seen = seen or []
+    reqs = set(provider_reqs[provider])
+    if len(cross_provider_deps[provider]) > 0:
+        reqs.update(
+            chain.from_iterable(
+                get_cross_provider_reqs(
+                    d, provider_reqs, cross_provider_deps, seen + [provider]
+                )
+                if d not in seen
+                else []
+                for d in cross_provider_deps[provider]
+            )
+        )
+    return reqs
+
+
+def get_provider_reqs(version: str, packages: Dict) -> Dict:
+    provider_dependencies = get_file_from_github(
+        version, "generated/provider_dependencies.json"
+    )
+    provider_reqs = {}
+    cross_provider_deps = {}
+    for provider, provider_data in provider_dependencies.items():
+        provider_reqs[provider] = list(
+            provider_reqs_to_attr_paths(provider_data["deps"], packages)
+        ) + EXTRA_REQS.get(provider, [])
+        cross_provider_deps[provider] = [
+            d for d in provider_data["cross-providers-deps"] if d != "common.sql"
+        ]
+    transitive_provider_reqs = {}
+    # Add transitive cross-provider reqs
+    for provider in provider_reqs:
+        transitive_provider_reqs[provider] = get_cross_provider_reqs(
+            provider, provider_reqs, cross_provider_deps
+        )
+    return transitive_provider_reqs
+
+
+def get_provider_yaml(version: str, provider: str) -> Dict:
+    provider_dir = provider.replace(".", "/")
+    path = f"airflow/providers/{provider_dir}/provider.yaml"
+    try:
+        return get_file_from_github(version, path)
+    except HTTPError:
+        logging.warning("Couldn't get provider yaml for %s", provider)
+        return {}
+
+
+def get_provider_imports(version: str, providers) -> Dict:
+    provider_imports = {}
+    for provider in providers:
+        provider_yaml = get_provider_yaml(version, provider)
+        imports: List[str] = []
+        if "hooks" in provider_yaml:
+            imports.extend(
+                chain.from_iterable(
+                    hook["python-modules"] for hook in provider_yaml["hooks"]
+                )
+            )
+        if "operators" in provider_yaml:
+            imports.extend(
+                chain.from_iterable(
+                    operator["python-modules"]
+                    for operator in provider_yaml["operators"]
+                )
+            )
+        provider_imports[provider] = imports
+    return provider_imports
+
+
+def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
+    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
+    fh.write("{\n")
+    for provider, reqs in provider_reqs.items():
+        provider_name = provider.replace(".", "_")
+        fh.write(f"  {provider_name} = {{\n")
+        fh.write(
+            "    deps = [ " + " ".join(sorted(f'"{req}"' for req in reqs)) + " ];\n"
+        )
+        fh.write(
+            "    imports = [ "
+            + " ".join(sorted(f'"{imp}"' for imp in provider_imports[provider]))
+            + " ];\n"
+        )
+        fh.write("  };\n")
+    fh.write("}\n")
+
+
+def main() -> None:
+    logging.basicConfig(level=logging.INFO)
+    version = get_version()
+    packages = dump_packages()
+    logging.info("Generating providers.nix for version %s", version)
+    provider_reqs = get_provider_reqs(version, packages)
+    provider_imports = get_provider_imports(version, provider_reqs.keys())
+    with open("providers.nix", "w") as fh:
+        to_nix_expr(provider_reqs, provider_imports, fh)
+
+
+if __name__ == "__main__":
+    main()