#! /usr/bin/env python3 from itertools import chain import json import logging from pathlib import Path import os import re import subprocess import sys from typing import Dict, List, Optional, Set, TextIO from urllib.request import urlopen from urllib.error import HTTPError import yaml PKG_SET = "apache-airflow.pythonPackages" # If some requirements are matched by multiple or no Python packages, the # following can be used to choose the correct one PKG_PREFERENCES = { "dnspython": "dnspython", "elasticsearch-dsl": "elasticsearch-dsl", "google-api-python-client": "google-api-python-client", "protobuf": "protobuf", "psycopg2-binary": "psycopg2", "requests_toolbelt": "requests-toolbelt", } # Requirements missing from the airflow provider metadata EXTRA_REQS = { "sftp": ["pysftp"], } def get_version(): with open(os.path.dirname(sys.argv[0]) + "/default.nix") as fh: # A version consists of digits, dots, and possibly a "b" (for beta) m = re.search('version = "([\\d\\.b]+)";', fh.read()) return m.group(1) def get_file_from_github(version: str, path: str): with urlopen( f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}" ) as response: return yaml.safe_load(response) def repository_root() -> Path: return Path(os.path.dirname(sys.argv[0])) / "../../.." def dump_packages() -> Dict[str, Dict[str, str]]: # Store a JSON dump of Nixpkgs' python3Packages output = subprocess.check_output( [ "nix-env", "-f", repository_root(), "-qa", "-A", PKG_SET, "--arg", "config", "{ allowAliases = false; }", "--json", ] ) return json.loads(output) def remove_version_constraint(req: str) -> str: return re.sub(r"[=><~].*$", "", req) def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]: if req in PKG_PREFERENCES: return f"{PKG_SET}.{PKG_PREFERENCES[req]}" attr_paths = [] names = [req] # E.g. python-mpd2 is actually called python3.6-mpd2 # instead of python-3.6-python-mpd2 inside Nixpkgs if req.startswith("python-") or req.startswith("python_"): names.append(req[len("python-") :]) for name in names: # treat "-" and "_" equally name = re.sub("[-_]", "[-_]", name) # python(minor).(major)-(pname)-(version or unstable-date) # we need the version qualifier, or we'll have multiple matches # (e.g. pyserial and pyserial-asyncio when looking for pyserial) pattern = re.compile( f"^python\\d+\\.\\d+-{name}-(?:\\d|unstable-.*)", re.I ) for attr_path, package in packages.items(): # logging.debug("Checking match for %s with %s", name, package["name"]) if pattern.match(package["name"]): attr_paths.append(attr_path) # Let's hope there's only one derivation with a matching name assert len(attr_paths) <= 1, f"{req} matches more than one derivation: {attr_paths}" if attr_paths: return attr_paths[0] return None def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List: no_version_reqs = map(remove_version_constraint, reqs) filtered_reqs = [ req for req in no_version_reqs if not re.match(r"^apache-airflow", req) ] attr_paths = [] for req in filtered_reqs: attr_path = name_to_attr_path(req, packages) if attr_path is not None: # Add attribute path without "python3Packages." prefix pname = attr_path[len(PKG_SET + ".") :] attr_paths.append(pname) else: # If we can't find it, we just skip and warn the user logging.warning("Could not find package attr for %s", req) return attr_paths def get_cross_provider_reqs( provider: str, provider_reqs: Dict, cross_provider_deps: Dict, seen: List = None ) -> Set: # Unfortunately there are circular cross-provider dependencies, so keep a # list of ones we've seen already seen = seen or [] reqs = set(provider_reqs[provider]) if len(cross_provider_deps[provider]) > 0: reqs.update( chain.from_iterable( get_cross_provider_reqs( d, provider_reqs, cross_provider_deps, seen + [provider] ) if d not in seen else [] for d in cross_provider_deps[provider] ) ) return reqs def get_provider_reqs(version: str, packages: Dict) -> Dict: provider_dependencies = get_file_from_github( version, "generated/provider_dependencies.json" ) provider_reqs = {} cross_provider_deps = {} for provider, provider_data in provider_dependencies.items(): provider_reqs[provider] = list( provider_reqs_to_attr_paths(provider_data["deps"], packages) ) + EXTRA_REQS.get(provider, []) cross_provider_deps[provider] = [ d for d in provider_data["cross-providers-deps"] if d != "common.sql" ] transitive_provider_reqs = {} # Add transitive cross-provider reqs for provider in provider_reqs: transitive_provider_reqs[provider] = get_cross_provider_reqs( provider, provider_reqs, cross_provider_deps ) return transitive_provider_reqs def get_provider_yaml(version: str, provider: str) -> Dict: provider_dir = provider.replace(".", "/") path = f"airflow/providers/{provider_dir}/provider.yaml" try: return get_file_from_github(version, path) except HTTPError: logging.warning("Couldn't get provider yaml for %s", provider) return {} def get_provider_imports(version: str, providers) -> Dict: provider_imports = {} for provider in providers: provider_yaml = get_provider_yaml(version, provider) imports: List[str] = [] if "hooks" in provider_yaml: imports.extend( chain.from_iterable( hook["python-modules"] for hook in provider_yaml["hooks"] ) ) if "operators" in provider_yaml: imports.extend( chain.from_iterable( operator["python-modules"] for operator in provider_yaml["operators"] ) ) provider_imports[provider] = imports return provider_imports def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None: fh.write("# Warning: generated by update-providers.py, do not update manually\n") fh.write("{\n") for provider, reqs in provider_reqs.items(): provider_name = provider.replace(".", "_") fh.write(f" {provider_name} = {{\n") fh.write( " deps = [ " + " ".join(sorted(f'"{req}"' for req in reqs)) + " ];\n" ) fh.write( " imports = [ " + " ".join(sorted(f'"{imp}"' for imp in provider_imports[provider])) + " ];\n" ) fh.write(" };\n") fh.write("}\n") def main() -> None: logging.basicConfig(level=logging.INFO) version = get_version() packages = dump_packages() logging.info("Generating providers.nix for version %s", version) provider_reqs = get_provider_reqs(version, packages) provider_imports = get_provider_imports(version, provider_reqs.keys()) with open("providers.nix", "w") as fh: to_nix_expr(provider_reqs, provider_imports, fh) if __name__ == "__main__": main()