diff options
Diffstat (limited to 'nixpkgs/pkgs/servers/apache-airflow/update-providers.py')
-rwxr-xr-x | nixpkgs/pkgs/servers/apache-airflow/update-providers.py | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/servers/apache-airflow/update-providers.py b/nixpkgs/pkgs/servers/apache-airflow/update-providers.py new file mode 100755 index 000000000000..207c38119258 --- /dev/null +++ b/nixpkgs/pkgs/servers/apache-airflow/update-providers.py @@ -0,0 +1,228 @@ +#! /usr/bin/env python3 + +from itertools import chain +import json +import logging +from pathlib import Path +import os +import re +import subprocess +import sys +from typing import Dict, List, Optional, Set, TextIO +from urllib.request import urlopen +from urllib.error import HTTPError +import yaml + +PKG_SET = "apache-airflow.pythonPackages" + +# If some requirements are matched by multiple or no Python packages, the +# following can be used to choose the correct one +PKG_PREFERENCES = { + "dnspython": "dnspython", + "elasticsearch-dsl": "elasticsearch-dsl", + "google-api-python-client": "google-api-python-client", + "protobuf": "protobuf", + "psycopg2-binary": "psycopg2", + "requests_toolbelt": "requests-toolbelt", +} + +# Requirements missing from the airflow provider metadata +EXTRA_REQS = { + "sftp": ["pysftp"], +} + + +def get_version(): + with open(os.path.dirname(sys.argv[0]) + "/default.nix") as fh: + # A version consists of digits, dots, and possibly a "b" (for beta) + m = re.search('version = "([\\d\\.b]+)";', fh.read()) + return m.group(1) + + +def get_file_from_github(version: str, path: str): + with urlopen( + f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}" + ) as response: + return yaml.safe_load(response) + + +def repository_root() -> Path: + return Path(os.path.dirname(sys.argv[0])) / "../../.." + + +def dump_packages() -> Dict[str, Dict[str, str]]: + # Store a JSON dump of Nixpkgs' python3Packages + output = subprocess.check_output( + [ + "nix-env", + "-f", + repository_root(), + "-qa", + "-A", + PKG_SET, + "--arg", + "config", + "{ allowAliases = false; }", + "--json", + ] + ) + return json.loads(output) + + +def remove_version_constraint(req: str) -> str: + return re.sub(r"[=><~].*$", "", req) + + +def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]: + if req in PKG_PREFERENCES: + return f"{PKG_SET}.{PKG_PREFERENCES[req]}" + attr_paths = [] + names = [req] + # E.g. python-mpd2 is actually called python3.6-mpd2 + # instead of python-3.6-python-mpd2 inside Nixpkgs + if req.startswith("python-") or req.startswith("python_"): + names.append(req[len("python-") :]) + for name in names: + # treat "-" and "_" equally + name = re.sub("[-_]", "[-_]", name) + # python(minor).(major)-(pname)-(version or unstable-date) + # we need the version qualifier, or we'll have multiple matches + # (e.g. pyserial and pyserial-asyncio when looking for pyserial) + pattern = re.compile( + f"^python\\d+\\.\\d+-{name}-(?:\\d|unstable-.*)", re.I + ) + for attr_path, package in packages.items(): + # logging.debug("Checking match for %s with %s", name, package["name"]) + if pattern.match(package["name"]): + attr_paths.append(attr_path) + # Let's hope there's only one derivation with a matching name + assert len(attr_paths) <= 1, f"{req} matches more than one derivation: {attr_paths}" + if attr_paths: + return attr_paths[0] + return None + + +def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List: + no_version_reqs = map(remove_version_constraint, reqs) + filtered_reqs = [ + req for req in no_version_reqs if not re.match(r"^apache-airflow", req) + ] + attr_paths = [] + for req in filtered_reqs: + attr_path = name_to_attr_path(req, packages) + if attr_path is not None: + # Add attribute path without "python3Packages." prefix + pname = attr_path[len(PKG_SET + ".") :] + attr_paths.append(pname) + else: + # If we can't find it, we just skip and warn the user + logging.warning("Could not find package attr for %s", req) + return attr_paths + + +def get_cross_provider_reqs( + provider: str, provider_reqs: Dict, cross_provider_deps: Dict, seen: List = None +) -> Set: + # Unfortunately there are circular cross-provider dependencies, so keep a + # list of ones we've seen already + seen = seen or [] + reqs = set(provider_reqs[provider]) + if len(cross_provider_deps[provider]) > 0: + reqs.update( + chain.from_iterable( + get_cross_provider_reqs( + d, provider_reqs, cross_provider_deps, seen + [provider] + ) + if d not in seen + else [] + for d in cross_provider_deps[provider] + ) + ) + return reqs + + +def get_provider_reqs(version: str, packages: Dict) -> Dict: + provider_dependencies = get_file_from_github( + version, "generated/provider_dependencies.json" + ) + provider_reqs = {} + cross_provider_deps = {} + for provider, provider_data in provider_dependencies.items(): + provider_reqs[provider] = list( + provider_reqs_to_attr_paths(provider_data["deps"], packages) + ) + EXTRA_REQS.get(provider, []) + cross_provider_deps[provider] = [ + d for d in provider_data["cross-providers-deps"] if d != "common.sql" + ] + transitive_provider_reqs = {} + # Add transitive cross-provider reqs + for provider in provider_reqs: + transitive_provider_reqs[provider] = get_cross_provider_reqs( + provider, provider_reqs, cross_provider_deps + ) + return transitive_provider_reqs + + +def get_provider_yaml(version: str, provider: str) -> Dict: + provider_dir = provider.replace(".", "/") + path = f"airflow/providers/{provider_dir}/provider.yaml" + try: + return get_file_from_github(version, path) + except HTTPError: + logging.warning("Couldn't get provider yaml for %s", provider) + return {} + + +def get_provider_imports(version: str, providers) -> Dict: + provider_imports = {} + for provider in providers: + provider_yaml = get_provider_yaml(version, provider) + imports: List[str] = [] + if "hooks" in provider_yaml: + imports.extend( + chain.from_iterable( + hook["python-modules"] for hook in provider_yaml["hooks"] + ) + ) + if "operators" in provider_yaml: + imports.extend( + chain.from_iterable( + operator["python-modules"] + for operator in provider_yaml["operators"] + ) + ) + provider_imports[provider] = imports + return provider_imports + + +def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None: + fh.write("# Warning: generated by update-providers.py, do not update manually\n") + fh.write("{\n") + for provider, reqs in provider_reqs.items(): + provider_name = provider.replace(".", "_") + fh.write(f" {provider_name} = {{\n") + fh.write( + " deps = [ " + " ".join(sorted(f'"{req}"' for req in reqs)) + " ];\n" + ) + fh.write( + " imports = [ " + + " ".join(sorted(f'"{imp}"' for imp in provider_imports[provider])) + + " ];\n" + ) + fh.write(" };\n") + fh.write("}\n") + + +def main() -> None: + logging.basicConfig(level=logging.INFO) + version = get_version() + packages = dump_packages() + logging.info("Generating providers.nix for version %s", version) + provider_reqs = get_provider_reqs(version, packages) + provider_imports = get_provider_imports(version, provider_reqs.keys()) + with open("providers.nix", "w") as fh: + to_nix_expr(provider_reqs, provider_imports, fh) + + +if __name__ == "__main__": + main() |