about summary refs log tree commit diff
path: root/nixos/lib/test-driver
diff options
context:
space:
mode:
authorDavid Arnold <dar@xoe.solutions>2021-06-12 17:47:25 -0500
committerDavid Arnold <david.arnold@iohk.io>2021-10-05 14:38:48 -0500
commitb0fc9da879812e47c1ed3438fb0fd51db00a3494 (patch)
treec238d3e8ce9c6ad17c47e8414001a29e137d8e52 /nixos/lib/test-driver
parent3069ba0dd1dec75c5dc4f6a1ee238a4fab9828cd (diff)
downloadnixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.gz
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.bz2
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.lz
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.xz
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.zst
nixlib-b0fc9da879812e47c1ed3438fb0fd51db00a3494.zip
nixos/test/test-driver: Class-ify the test driver
This commit encapsulates the involved domain into classes and
defines explicit and typed arguments where untyped dicts were used.

It preserves backwards compatibility through legacy wrappers.
Diffstat (limited to 'nixos/lib/test-driver')
-rwxr-xr-xnixos/lib/test-driver/test-driver.py804
1 files changed, 511 insertions, 293 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index f8502188bde8..fdc440a896a0 100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -21,7 +21,6 @@ import shutil
 import socket
 import subprocess
 import sys
-import telnetlib
 import tempfile
 import time
 import unicodedata
@@ -89,55 +88,6 @@ CHAR_TO_KEY = {
     ")": "shift-0x0B",
 }
 
-global log, machines, test_script
-
-
-def eprint(*args: object, **kwargs: Any) -> None:
-    print(*args, file=sys.stderr, **kwargs)
-
-
-def make_command(args: list) -> str:
-    return " ".join(map(shlex.quote, (map(str, args))))
-
-
-def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
-    log.log("starting VDE switch for network {}".format(vlan_nr))
-    vde_socket = tempfile.mkdtemp(
-        prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
-    )
-    pty_master, pty_slave = pty.openpty()
-    vde_process = subprocess.Popen(
-        ["vde_switch", "-s", vde_socket, "--dirmode", "0700"],
-        stdin=pty_slave,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        shell=False,
-    )
-    fd = os.fdopen(pty_master, "w")
-    fd.write("version\n")
-    # TODO: perl version checks if this can be read from
-    # an if not, dies. we could hang here forever. Fix it.
-    assert vde_process.stdout is not None
-    vde_process.stdout.readline()
-    if not os.path.exists(os.path.join(vde_socket, "ctl")):
-        raise Exception("cannot start vde_switch")
-
-    return (vlan_nr, vde_socket, vde_process, fd)
-
-
-def retry(fn: Callable, timeout: int = 900) -> None:
-    """Call the given function repeatedly, with 1 second intervals,
-    until it returns True or a timeout is reached.
-    """
-
-    for _ in range(timeout):
-        if fn(False):
-            return
-        time.sleep(1)
-
-    if not fn(True):
-        raise Exception(f"action timed out after {timeout} seconds")
-
 
 class Logger:
     def __init__(self) -> None:
@@ -151,6 +101,10 @@ class Logger:
 
         self._print_serial_logs = True
 
+    @staticmethod
+    def _eprint(*args: object, **kwargs: Any) -> None:
+        print(*args, file=sys.stderr, **kwargs)
+
     def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
@@ -169,15 +123,27 @@ class Logger:
         self.xml.characters(message)
         self.xml.endElement("line")
 
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+        sys.exit(1)
+
     def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
-        eprint(self.maybe_prefix(message, attributes))
+        self._eprint(self.maybe_prefix(message, attributes))
         self.drain_log_queue()
         self.log_line(message, attributes)
 
     def log_serial(self, message: str, machine: str) -> None:
         self.enqueue({"msg": message, "machine": machine, "type": "serial"})
         if self._print_serial_logs:
-            eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL)
+            self._eprint(
+                Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
+            )
 
     def enqueue(self, item: Dict[str, str]) -> None:
         self.queue.put(item)
@@ -194,7 +160,7 @@ class Logger:
 
     @contextmanager
     def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
-        eprint(self.maybe_prefix(message, attributes))
+        self._eprint(self.maybe_prefix(message, attributes))
 
         self.xml.startElement("nest", attrs={})
         self.xml.startElement("head", attributes)
@@ -211,6 +177,27 @@ class Logger:
         self.xml.endElement("nest")
 
 
+rootlog = Logger()
+
+
+def make_command(args: list) -> str:
+    return " ".join(map(shlex.quote, (map(str, args))))
+
+
+def retry(fn: Callable, timeout: int = 900) -> None:
+    """Call the given function repeatedly, with 1 second intervals,
+    until it returns True or a timeout is reached.
+    """
+
+    for _ in range(timeout):
+        if fn(False):
+            return
+        time.sleep(1)
+
+    if not fn(True):
+        raise Exception(f"action timed out after {timeout} seconds")
+
+
 def _perform_ocr_on_screenshot(
     screenshot_path: str, model_ids: Iterable[int]
 ) -> List[str]:
@@ -242,113 +229,256 @@ def _perform_ocr_on_screenshot(
     return model_results
 
 
-class Machine:
-    def __repr__(self) -> str:
-        return f"<Machine '{self.name}'>"
-
-    def __init__(self, args: Dict[str, Any]) -> None:
-        if "name" in args:
-            self.name = args["name"]
-        else:
-            self.name = "machine"
-            cmd = args.get("startCommand", None)
-            if cmd:
-                match = re.search("run-(.+)-vm$", cmd)
-                if match:
-                    self.name = match.group(1)
-        self.logger = args["log"]
-        self.script = args.get("startCommand", self.create_startcommand(args))
-
-        tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
-
-        def create_dir(name: str) -> str:
-            path = os.path.join(tmp_dir, name)
-            os.makedirs(path, mode=0o700, exist_ok=True)
-            return path
+class StartCommand:
+    """The Base Start Command knows how to append the necessary
+    runtime qemu options as determined by a particular test driver
+    run. Any such start command is expected to happily receive and
+    append additional qemu args.
+    """
 
-        self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}")
-        if not args.get("keepVmState", False):
-            self.cleanup_statedir()
-        os.makedirs(self.state_dir, mode=0o700, exist_ok=True)
-        self.shared_dir = create_dir("shared-xchg")
+    _cmd: str
 
-        self.booted = False
-        self.connected = False
-        self.pid: Optional[int] = None
-        self.socket = None
-        self.monitor: Optional[socket.socket] = None
-        self.allow_reboot = args.get("allowReboot", False)
+    def cmd(
+        self,
+        monitor_socket_path: pathlib.Path,
+        shell_socket_path: pathlib.Path,
+        allow_reboot: bool = False,  # TODO: unused, legacy?
+    ) -> str:
+        display_opts = ""
+        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
+        if display_available:
+            display_opts += " -nographic"
+
+        # qemu options
+        qemu_opts = ""
+        qemu_opts += (
+            ""
+            if allow_reboot
+            else " -no-reboot"
+            " -device virtio-serial"
+            " -device virtconsole,chardev=shell"
+            " -device virtio-rng-pci"
+            " -serial stdio"
+        )
+        # TODO: qemu script already captures this env variable, legacy?
+        qemu_opts += " " + os.environ.get("QEMU_OPTS", "")
+
+        return (
+            f"{self._cmd}"
+            f" -monitor unix:{monitor_socket_path}"
+            f" -chardev socket,id=shell,path={shell_socket_path}"
+            f"{qemu_opts}"
+            f"{display_opts}"
+        )
 
     @staticmethod
-    def create_startcommand(args: Dict[str, str]) -> str:
-        net_backend = "-netdev user,id=net0"
-        net_frontend = "-device virtio-net-pci,netdev=net0"
+    def build_environment(
+        state_dir: pathlib.Path,
+        shared_dir: pathlib.Path,
+    ) -> dict:
+        # We make a copy to not update the current environment
+        env = dict(os.environ)
+        env.update(
+            {
+                "TMPDIR": str(state_dir),
+                "SHARED_DIR": str(shared_dir),
+                "USE_TMPDIR": "1",
+            }
+        )
+        return env
+
+    def run(
+        self,
+        state_dir: pathlib.Path,
+        shared_dir: pathlib.Path,
+        monitor_socket_path: pathlib.Path,
+        shell_socket_path: pathlib.Path,
+    ) -> subprocess.Popen:
+        return subprocess.Popen(
+            self.cmd(monitor_socket_path, shell_socket_path),
+            stdin=subprocess.DEVNULL,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            shell=True,
+            cwd=state_dir,
+            env=self.build_environment(state_dir, shared_dir),
+        )
+
 
-        if "netBackendArgs" in args:
-            net_backend += "," + args["netBackendArgs"]
+class NixStartScript(StartCommand):
+    """A start script from nixos/modules/virtualisation/qemu-vm.nix
+    that also satisfies the requirement of the BaseStartCommand.
+    These Nix commands have the particular characteristic that the
+    machine name can be extracted out of them via a regex match.
+    (Admittedly a _very_ implicit contract, possibly TODO: fix)
+    """
 
-        if "netFrontendArgs" in args:
-            net_frontend += "," + args["netFrontendArgs"]
+    def __init__(self, script: str):
+        self._cmd = script
 
-        start_command = (
-            args.get("qemuBinary", "qemu-kvm")
-            + " -m 384 "
-            + net_backend
-            + " "
-            + net_frontend
-            + " $QEMU_OPTS "
-        )
+    @property
+    def machine_name(self) -> str:
+        match = re.search("run-(.+)-vm$", self._cmd)
+        name = "machine"
+        if match:
+            name = match.group(1)
+        return name
 
-        if "hda" in args:
-            hda_path = os.path.abspath(args["hda"])
-            if args.get("hdaInterface", "") == "scsi":
-                start_command += (
-                    "-drive id=hda,file="
-                    + hda_path
-                    + ",werror=report,if=none "
-                    + "-device scsi-hd,drive=hda "
+
+class LegacyStartCommand(StartCommand):
+    """Used in some places to create an ad-hoc machine instead of
+    using nix test instrumentation + module system for that purpose.
+    Legacy.
+    """
+
+    def __init__(
+        self,
+        netBackendArgs: Optional[str] = None,
+        netFrontendArgs: Optional[str] = None,
+        hda: Optional[Tuple[pathlib.Path, str]] = None,
+        cdrom: Optional[str] = None,
+        usb: Optional[str] = None,
+        bios: Optional[str] = None,
+        qemuFlags: Optional[str] = None,
+    ):
+        self._cmd = "qemu-kvm -m 384"
+
+        # networking
+        net_backend = "-netdev user,id=net0"
+        net_frontend = "-device virtio-net-pci,netdev=net0"
+        if netBackendArgs is not None:
+            net_backend += "," + netBackendArgs
+        if netFrontendArgs is not None:
+            net_frontend += "," + netFrontendArgs
+        self._cmd += f" {net_backend} {net_frontend}"
+
+        # hda
+        hda_cmd = ""
+        if hda is not None:
+            hda_path = hda[0].resolve()
+            hda_interface = hda[1]
+            if hda_interface == "scsi":
+                hda_cmd += (
+                    f" -drive id=hda,file={hda_path},werror=report,if=none"
+                    " -device scsi-hd,drive=hda"
                 )
             else:
-                start_command += (
-                    "-drive file="
-                    + hda_path
-                    + ",if="
-                    + args["hdaInterface"]
-                    + ",werror=report "
-                )
+                hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report"
+        self._cmd += hda_cmd
 
-        if "cdrom" in args:
-            start_command += "-cdrom " + args["cdrom"] + " "
+        # cdrom
+        if cdrom is not None:
+            self._cmd += f" -cdrom {cdrom}"
 
-        if "usb" in args:
+        # usb
+        usb_cmd = ""
+        if usb is not None:
             # https://github.com/qemu/qemu/blob/master/docs/usb2.txt
-            start_command += (
-                "-device usb-ehci -drive "
-                + "id=usbdisk,file="
-                + args["usb"]
-                + ",if=none,readonly "
-                + "-device usb-storage,drive=usbdisk "
+            usb_cmd += (
+                " -device usb-ehci"
+                f" -drive id=usbdisk,file={usb},if=none,readonly"
+                " -device usb-storage,drive=usbdisk "
             )
-        if "bios" in args:
-            start_command += "-bios " + args["bios"] + " "
+        self._cmd += usb_cmd
+
+        # bios
+        if bios is not None:
+            self._cmd += f" -bios {bios}"
+
+        # qemu flags
+        if qemuFlags is not None:
+            self._cmd += f" {qemuFlags}"
+
+
+class Machine:
+    """A handle to the machine with this name, that also knows how to manage
+    the machine lifecycle with the help of a start script / command."""
+
+    name: str
+    tmp_dir: pathlib.Path
+    shared_dir: pathlib.Path
+    state_dir: pathlib.Path
+    monitor_path: pathlib.Path
+    shell_path: pathlib.Path
+
+    start_command: StartCommand
+    keep_vm_state: bool
+    allow_reboot: bool
+
+    process: Optional[subprocess.Popen] = None
+    pid: Optional[int] = None
+    monitor: Optional[socket.socket] = None
+    shell: Optional[socket.socket] = None
+
+    booted: bool = False
+    connected: bool = False
+    # Store last serial console lines for use
+    # of wait_for_console_text
+    last_lines: Queue = Queue()
 
-        start_command += args.get("qemuFlags", "")
+    def __repr__(self) -> str:
+        return f"<Machine '{self.name}'>"
+
+    def __init__(
+        self,
+        tmp_dir: pathlib.Path,
+        start_command: StartCommand,
+        name: str = "machine",
+        keep_vm_state: bool = False,
+        allow_reboot: bool = False,
+    ) -> None:
+        self.tmp_dir = tmp_dir
+        self.keep_vm_state = keep_vm_state
+        self.allow_reboot = allow_reboot
+        self.name = name
+        self.start_command = start_command
+
+        # set up directories
+        self.shared_dir = self.tmp_dir / "shared-xchg"
+        self.shared_dir.mkdir(mode=0o700, exist_ok=True)
+
+        self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
+        self.monitor_path = self.state_dir / "monitor"
+        self.shell_path = self.state_dir / "shell"
+        if (not self.keep_vm_state) and self.state_dir.exists():
+            self.cleanup_statedir()
+        self.state_dir.mkdir(mode=0o700, exist_ok=True)
 
-        return start_command
+    @staticmethod
+    def create_startcommand(args: Dict[str, str]) -> StartCommand:
+        rootlog.warning(
+            "Using legacy create_startcommand(),"
+            "please use proper nix test vm instrumentation, instead"
+            "to generate the appropriate nixos test vm qemu startup script"
+        )
+        hda = None
+        if args.get("hda"):
+            hda_arg: str = args.get("hda", "")
+            hda_arg_path: pathlib.Path = pathlib.Path(hda_arg)
+            hda = (hda_arg_path, args.get("hdaInterface", ""))
+        return LegacyStartCommand(
+            netBackendArgs=args.get("netBackendArgs"),
+            netFrontendArgs=args.get("netFrontendArgs"),
+            hda=hda,
+            cdrom=args.get("cdrom"),
+            usb=args.get("usb"),
+            bios=args.get("bios"),
+            qemuFlags=args.get("qemuFlags"),
+        )
 
     def is_up(self) -> bool:
         return self.booted and self.connected
 
     def log(self, msg: str) -> None:
-        self.logger.log(msg, {"machine": self.name})
+        rootlog.log(msg, {"machine": self.name})
 
     def log_serial(self, msg: str) -> None:
-        self.logger.log_serial(msg, self.name)
+        rootlog.log_serial(msg, self.name)
 
     def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
         my_attrs = {"machine": self.name}
         my_attrs.update(attrs)
-        return self.logger.nested(msg, my_attrs)
+        return rootlog.nested(msg, my_attrs)
 
     def wait_for_monitor_prompt(self) -> str:
         assert self.monitor is not None
@@ -446,6 +576,7 @@ class Machine:
         self.connect()
 
         out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command)
+        assert self.shell
         self.shell.send(out_command.encode())
 
         output = ""
@@ -466,6 +597,8 @@ class Machine:
         Should only be used during test development, not in the production test."""
         self.connect()
         self.log("Terminal is ready (there is no prompt):")
+
+        assert self.shell
         subprocess.run(
             ["socat", "READLINE", f"FD:{self.shell.fileno()}"],
             pass_fds=[self.shell.fileno()],
@@ -534,6 +667,7 @@ class Machine:
 
         with self.nested("waiting for the VM to power off"):
             sys.stdout.flush()
+            assert self.process
             self.process.wait()
 
             self.pid = None
@@ -611,6 +745,8 @@ class Machine:
         with self.nested("waiting for the VM to finish booting"):
             self.start()
 
+            assert self.shell
+
             tic = time.time()
             self.shell.recv(1024)
             # TODO: Timeout
@@ -750,65 +886,35 @@ class Machine:
 
         self.log("starting vm")
 
-        def create_socket(path: str) -> socket.socket:
-            if os.path.exists(path):
-                os.unlink(path)
+        def clear(path: pathlib.Path) -> pathlib.Path:
+            if path.exists():
+                path.unlink()
+            return path
+
+        def create_socket(path: pathlib.Path) -> socket.socket:
             s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
-            s.bind(path)
+            s.bind(str(path))
             s.listen(1)
             return s
 
-        monitor_path = os.path.join(self.state_dir, "monitor")
-        self.monitor_socket = create_socket(monitor_path)
-
-        shell_path = os.path.join(self.state_dir, "shell")
-        self.shell_socket = create_socket(shell_path)
-
-        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
-        qemu_options = (
-            " ".join(
-                [
-                    "" if self.allow_reboot else "-no-reboot",
-                    "-monitor unix:{}".format(monitor_path),
-                    "-chardev socket,id=shell,path={}".format(shell_path),
-                    "-device virtio-serial",
-                    "-device virtconsole,chardev=shell",
-                    "-device virtio-rng-pci",
-                    "-serial stdio" if display_available else "-nographic",
-                ]
-            )
-            + " "
-            + os.environ.get("QEMU_OPTS", "")
+        monitor_socket = create_socket(clear(self.monitor_path))
+        shell_socket = create_socket(clear(self.shell_path))
+        self.process = self.start_command.run(
+            self.state_dir,
+            self.shared_dir,
+            self.monitor_path,
+            self.shell_path,
         )
-
-        environment = dict(os.environ)
-        environment.update(
-            {
-                "TMPDIR": self.state_dir,
-                "SHARED_DIR": self.shared_dir,
-                "USE_TMPDIR": "1",
-                "QEMU_OPTS": qemu_options,
-            }
-        )
-
-        self.process = subprocess.Popen(
-            self.script,
-            stdin=subprocess.DEVNULL,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            shell=True,
-            cwd=self.state_dir,
-            env=environment,
-        )
-        self.monitor, _ = self.monitor_socket.accept()
-        self.shell, _ = self.shell_socket.accept()
+        self.monitor, _ = monitor_socket.accept()
+        self.shell, _ = shell_socket.accept()
 
         # Store last serial console lines for use
         # of wait_for_console_text
         self.last_lines: Queue = Queue()
 
         def process_serial_output() -> None:
-            assert self.process.stdout is not None
+            assert self.process
+            assert self.process.stdout
             for _line in self.process.stdout:
                 # Ignore undecodable bytes that may occur in boot menus
                 line = _line.decode(errors="ignore").replace("\r", "").rstrip()
@@ -825,15 +931,15 @@ class Machine:
         self.log("QEMU running (pid {})".format(self.pid))
 
     def cleanup_statedir(self) -> None:
-        if os.path.isdir(self.state_dir):
-            shutil.rmtree(self.state_dir)
-            self.logger.log(f"deleting VM state directory {self.state_dir}")
-            self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
+        shutil.rmtree(self.state_dir)
+        rootlog.log(f"deleting VM state directory {self.state_dir}")
+        rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
 
     def shutdown(self) -> None:
         if not self.booted:
             return
 
+        assert self.shell
         self.shell.send("poweroff\n".encode())
         self.wait_for_shutdown()
 
@@ -908,41 +1014,225 @@ class Machine:
         """Make the machine reachable."""
         self.send_monitor_command("set_link virtio-net-pci.1 on")
 
+    def release(self) -> None:
+        if self.pid is None:
+            return
+        rootlog.info(f"kill machine (pid {self.pid})")
+        assert self.process
+        assert self.shell
+        assert self.monitor
+        self.process.terminate()
+        self.shell.close()
+        self.monitor.close()
+
+
+class VLan:
+    """A handle to the vlan with this number, that also knows how to manage
+    its lifecycle.
+    """
 
-def create_machine(args: Dict[str, Any]) -> Machine:
-    args["log"] = log
-    return Machine(args)
+    nr: int
+    socket_dir: pathlib.Path
 
+    process: Optional[subprocess.Popen]
+    pid: Optional[int]
+    fd: Optional[io.TextIOBase]
 
-def start_all() -> None:
-    with log.nested("starting all VMs"):
-        for machine in machines:
-            machine.start()
+    def __repr__(self) -> str:
+        return f"<Vlan Nr. {self.nr}>"
 
+    def __init__(self, nr: int, tmp_dir: pathlib.Path):
+        self.nr = nr
+        self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
 
-def join_all() -> None:
-    with log.nested("waiting for all VMs to finish"):
-        for machine in machines:
-            machine.wait_for_shutdown()
+        # TODO: don't side-effect environment here
+        os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
 
+    def start(self) -> None:
 
-def run_tests(interactive: bool = False) -> None:
-    if interactive:
-        ptpython.repl.embed(test_symbols(), {})
-    else:
-        test_script()
+        rootlog.info("start vlan")
+        pty_master, pty_slave = pty.openpty()
+
+        self.process = subprocess.Popen(
+            ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
+            stdin=pty_slave,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=False,
+        )
+        self.pid = self.process.pid
+        self.fd = os.fdopen(pty_master, "w")
+        self.fd.write("version\n")
+
+        # TODO: perl version checks if this can be read from
+        # and if not, dies. We could hang here forever. Fix it.
+        assert self.process.stdout is not None
+        self.process.stdout.readline()
+        if not (self.socket_dir / "ctl").exists():
+            rootlog.error("cannot start vde_switch")
+
+        rootlog.info(f"running vlan (pid {self.pid})")
+
+    def release(self) -> None:
+        if self.pid is None:
+            return
+        rootlog.info(f"kill vlan (pid {self.pid})")
+        assert self.fd
+        assert self.process
+        self.fd.close()
+        self.process.terminate()
+
+
+class Driver:
+    """A handle to the driver that sets up the environment
+    and runs the tests"""
+
+    tests: str
+    vlans: List[VLan]
+    machines: List[Machine]
+
+    def __init__(
+        self,
+        start_scripts: List[str],
+        vlans: List[int],
+        tests: str,
+        keep_vm_state: bool = False,
+    ):
+        self.tests = tests
+
+        tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+
+        self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
+        with rootlog.nested("start all VLans"):
+            for vlan in self.vlans:
+                vlan.start()
+
+        def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
+            for s in scripts:
+                yield NixStartScript(s)
+
+        self.machines = [
+            Machine(
+                start_command=cmd,
+                keep_vm_state=keep_vm_state,
+                name=cmd.machine_name,
+                tmp_dir=tmp_dir,
+            )
+            for cmd in cmd(start_scripts)
+        ]
+
+        @atexit.register
+        def clean_up() -> None:
+            with rootlog.nested("clean up"):
+                for machine in self.machines:
+                    machine.release()
+                for vlan in self.vlans:
+                    vlan.release()
+
+    def subtest(self, name: str) -> Iterator[None]:
+        """Group logs under a given test name"""
+        with rootlog.nested(name):
+            try:
+                yield
+                return True
+            except:
+                rootlog.error(f'Test "{name}" failed with error:')
+                raise
+
+    def test_symbols(self) -> Dict[str, Any]:
+        @contextmanager
+        def subtest(name: str) -> Iterator[None]:
+            return self.subtest(name)
+
+        general_symbols = dict(
+            start_all=self.start_all,
+            test_script=self.test_script,
+            machines=self.machines,
+            vlans=self.vlans,
+            driver=self,
+            log=rootlog,
+            os=os,
+            create_machine=self.create_machine,
+            subtest=subtest,
+            run_tests=self.run_tests,
+            join_all=self.join_all,
+            retry=retry,
+            serial_stdout_off=self.serial_stdout_off,
+            serial_stdout_on=self.serial_stdout_on,
+            Machine=Machine,  # for typing
+        )
+        machine_symbols = {
+            m.name: self.machines[idx] for idx, m in enumerate(self.machines)
+        }
+        vlan_symbols = {
+            f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
+        }
+        print(
+            "additionally exposed symbols:\n    "
+            + ", ".join(map(lambda m: m.name, self.machines))
+            + ",\n    "
+            + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+            + ",\n    "
+            + ", ".join(list(general_symbols.keys()))
+        )
+        return {**general_symbols, **machine_symbols, **vlan_symbols}
+
+    def test_script(self) -> None:
+        """Run the test script"""
+        with rootlog.nested("run the VM test script"):
+            symbols = self.test_symbols()  # call eagerly
+            exec(self.tests, symbols, None)
+
+    def run_tests(self) -> None:
+        """Run the test script (for non-interactive test runs)"""
+        self.test_script()
         # TODO: Collect coverage data
-        for machine in machines:
+        for machine in self.machines:
             if machine.is_up():
                 machine.execute("sync")
 
+    def start_all(self) -> None:
+        """Start all machines"""
+        with rootlog.nested("start all VMs"):
+            for machine in self.machines:
+                machine.start()
+
+    def join_all(self) -> None:
+        """Wait for all machines to shut down"""
+        with rootlog.nested("wait for all VMs to finish"):
+            for machine in self.machines:
+                machine.wait_for_shutdown()
+
+    def create_machine(self, args: Dict[str, Any]) -> Machine:
+        rootlog.warning(
+            "Using legacy create_machine(), please instantiate the"
+            "Machine class directly, instead"
+        )
+        tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+        tmp_dir.mkdir(mode=0o700, exist_ok=True)
 
-def serial_stdout_on() -> None:
-    log._print_serial_logs = True
+        if args.get("startCommand"):
+            start_command: str = args.get("startCommand", "")
+            cmd = NixStartScript(start_command)
+            name = args.get("name", cmd.machine_name)
+        else:
+            cmd = Machine.create_startcommand(args)  # type: ignore
+            name = args.get("name", "machine")
+
+        return Machine(
+            tmp_dir=tmp_dir,
+            start_command=cmd,
+            name=name,
+            keep_vm_state=args.get("keep_vm_state", False),
+            allow_reboot=args.get("allow_reboot", False),
+        )
 
+    def serial_stdout_on(self) -> None:
+        rootlog._print_serial_logs = True
 
-def serial_stdout_off() -> None:
-    log._print_serial_logs = False
+    def serial_stdout_off(self) -> None:
+        rootlog._print_serial_logs = False
 
 
 class EnvDefault(argparse.Action):
@@ -970,52 +1260,6 @@ class EnvDefault(argparse.Action):
         setattr(namespace, self.dest, values)
 
 
-@contextmanager
-def subtest(name: str) -> Iterator[None]:
-    with log.nested(name):
-        try:
-            yield
-            return True
-        except Exception as e:
-            log.log(f'Test "{name}" failed with error: "{e}"')
-            raise e
-
-    return False
-
-
-def _test_symbols() -> Dict[str, Any]:
-    general_symbols = dict(
-        start_all=start_all,
-        test_script=globals().get("test_script"),  # same
-        machines=globals().get("machines"),  # without being initialized
-        log=globals().get("log"),  # extracting those symbol keys
-        os=os,
-        create_machine=create_machine,
-        subtest=subtest,
-        run_tests=run_tests,
-        join_all=join_all,
-        retry=retry,
-        serial_stdout_off=serial_stdout_off,
-        serial_stdout_on=serial_stdout_on,
-        Machine=Machine,  # for typing
-    )
-    return general_symbols
-
-
-def test_symbols() -> Dict[str, Any]:
-
-    general_symbols = _test_symbols()
-
-    machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
-    print(
-        "additionally exposed symbols:\n    "
-        + ", ".join(map(lambda m: m.name, machines))
-        + ",\n    "
-        + ", ".join(list(general_symbols.keys()))
-    )
-    return {**general_symbols, **machine_symbols}
-
-
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
     arg_parser.add_argument(
@@ -1055,44 +1299,18 @@ if __name__ == "__main__":
     )
 
     args = arg_parser.parse_args()
-    testscript = pathlib.Path(args.testscript).read_text()
-
-    global log, machines, test_script
-
-    log = Logger()
-
-    vde_sockets = [create_vlan(v) for v in args.vlans]
-    for nr, vde_socket, _, _ in vde_sockets:
-        os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
-
-    machines = [
-        create_machine({"startCommand": s, "keepVmState": args.keep_vm_state})
-        for s in args.start_scripts
-    ]
-    machine_eval = [
-        "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
-    ]
-    exec("\n".join(machine_eval))
-
-    @atexit.register
-    def clean_up() -> None:
-        with log.nested("cleaning up"):
-            for machine in machines:
-                if machine.pid is None:
-                    continue
-                log.log("killing {} (pid {})".format(machine.name, machine.pid))
-                machine.process.kill()
-            for _, _, process, _ in vde_sockets:
-                process.terminate()
-        log.close()
-
-    def test_script() -> None:
-        with log.nested("running the VM test script"):
-            symbols = test_symbols()  # call eagerly
-            exec(testscript, symbols, None)
-
-    interactive = args.interactive or (not bool(testscript))
-    tic = time.time()
-    run_tests(interactive)
-    toc = time.time()
-    print("test script finished in {:.2f}s".format(toc - tic))
+
+    if not args.keep_vm_state:
+        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
+
+    driver = Driver(
+        args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
+    )
+
+    if args.interactive:
+        ptpython.repl.embed(driver.test_symbols(), {})
+    else:
+        tic = time.time()
+        driver.run_tests()
+        toc = time.time()
+        rootlog.info(f"test script finished in {(toc-tic):.2f}s")