about summary refs log tree commit diff
path: root/nixos/lib/test-driver/test-driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test-driver.py')
-rw-r--r--nixos/lib/test-driver/test-driver.py73
1 files changed, 59 insertions, 14 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index 7e575189209a..c27947bc610d 100644
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -1,13 +1,17 @@
 #! /somewhere/python3
 from contextlib import contextmanager, _GeneratorContextManager
+from queue import Queue, Empty
+from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List
 from xml.sax.saxutils import XMLGenerator
 import _thread
 import atexit
+import base64
 import os
+import pathlib
 import ptpython.repl
 import pty
-from queue import Queue, Empty
 import re
+import shlex
 import shutil
 import socket
 import subprocess
@@ -15,9 +19,6 @@ import sys
 import tempfile
 import time
 import unicodedata
-from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List
-import shlex
-import pathlib
 
 CHAR_TO_KEY = {
     "A": "shift-a",
@@ -84,7 +85,7 @@ CHAR_TO_KEY = {
 
 # Forward references
 nr_tests: int
-nr_succeeded: int
+failed_tests: list
 log: "Logger"
 machines: "List[Machine]"
 
@@ -221,7 +222,7 @@ class Machine:
             return path
 
         self.state_dir = create_dir("vm-state-{}".format(self.name))
-        self.shared_dir = create_dir("{}/xchg".format(self.state_dir))
+        self.shared_dir = create_dir("shared-xchg")
 
         self.booted = False
         self.connected = False
@@ -395,7 +396,7 @@ class Machine:
         status_code_pattern = re.compile(r"(.*)\|\!EOF\s+(\d+)")
 
         while True:
-            chunk = self.shell.recv(4096).decode()
+            chunk = self.shell.recv(4096).decode(errors="ignore")
             match = status_code_pattern.match(chunk)
             if match:
                 output += match[1]
@@ -566,6 +567,41 @@ class Machine:
             if ret.returncode != 0:
                 raise Exception("Cannot convert screenshot")
 
+    def copy_from_host_via_shell(self, source: str, target: str) -> None:
+        """Copy a file from the host into the guest by piping it over the
+        shell into the destination file. Works without host-guest shared folder.
+        Prefer copy_from_host for whenever possible.
+        """
+        with open(source, "rb") as fh:
+            content_b64 = base64.b64encode(fh.read()).decode()
+            self.succeed(
+                f"mkdir -p $(dirname {target})",
+                f"echo -n {content_b64} | base64 -d > {target}",
+            )
+
+    def copy_from_host(self, source: str, target: str) -> None:
+        """Copy a file from the host into the guest via the `shared_dir` shared
+        among all the VMs (using a temporary directory).
+        """
+        host_src = pathlib.Path(source)
+        vm_target = pathlib.Path(target)
+        with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
+            shared_temp = pathlib.Path(shared_td)
+            host_intermediate = shared_temp / host_src.name
+            vm_shared_temp = pathlib.Path("/tmp/shared") / shared_temp.name
+            vm_intermediate = vm_shared_temp / host_src.name
+
+            self.succeed(make_command(["mkdir", "-p", vm_shared_temp]))
+            if host_src.is_dir():
+                shutil.copytree(host_src, host_intermediate)
+            else:
+                shutil.copy(host_src, host_intermediate)
+            self.succeed("sync")
+            self.succeed(make_command(["mkdir", "-p", vm_target.parent]))
+            self.succeed(make_command(["cp", "-r", vm_intermediate, vm_target]))
+        # Make sure the cleanup is synced into VM
+        self.succeed("sync")
+
     def copy_from_vm(self, source: str, target_dir: str = "") -> None:
         """Copy a file from the VM (specified by an in-VM source path) to a path
         relative to `$out`. The file is copied via the `shared_dir` shared among
@@ -576,7 +612,7 @@ class Machine:
         vm_src = pathlib.Path(source)
         with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
             shared_temp = pathlib.Path(shared_td)
-            vm_shared_temp = pathlib.Path("/tmp/xchg") / shared_temp.name
+            vm_shared_temp = pathlib.Path("/tmp/shared") / shared_temp.name
             vm_intermediate = vm_shared_temp / vm_src.name
             intermediate = shared_temp / vm_src.name
             # Copy the file to the shared directory inside VM
@@ -704,7 +740,8 @@ class Machine:
 
         def process_serial_output() -> None:
             for _line in self.process.stdout:
-                line = _line.decode("unicode_escape").replace("\r", "").rstrip()
+                # Ignore undecodable bytes that may occur in boot menus
+                line = _line.decode(errors="ignore").replace("\r", "").rstrip()
                 eprint("{} # {}".format(self.name, line))
                 self.logger.enqueue({"msg": line, "machine": self.name})
 
@@ -841,23 +878,31 @@ def run_tests() -> None:
             machine.execute("sync")
 
     if nr_tests != 0:
+        nr_succeeded = nr_tests - len(failed_tests)
         eprint("{} out of {} tests succeeded".format(nr_succeeded, nr_tests))
-        if nr_tests > nr_succeeded:
+        if len(failed_tests) > 0:
+            eprint(
+                "The following tests have failed:\n - {}".format(
+                    "\n - ".join(failed_tests)
+                )
+            )
             sys.exit(1)
 
 
 @contextmanager
 def subtest(name: str) -> Iterator[None]:
     global nr_tests
-    global nr_succeeded
+    global failed_tests
 
     with log.nested(name):
         nr_tests += 1
         try:
             yield
-            nr_succeeded += 1
             return True
         except Exception as e:
+            failed_tests.append(
+                'Test "{}" failed with error: "{}"'.format(name, str(e))
+            )
             log.log("error: {}".format(str(e)))
 
     return False
@@ -866,7 +911,7 @@ def subtest(name: str) -> Iterator[None]:
 if __name__ == "__main__":
     log = Logger()
 
-    vlan_nrs = list(dict.fromkeys(os.environ["VLANS"].split()))
+    vlan_nrs = list(dict.fromkeys(os.environ.get("VLANS", "").split()))
     vde_sockets = [create_vlan(v) for v in vlan_nrs]
     for nr, vde_socket, _, _ in vde_sockets:
         os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
@@ -879,7 +924,7 @@ if __name__ == "__main__":
     exec("\n".join(machine_eval))
 
     nr_tests = 0
-    nr_succeeded = 0
+    failed_tests = []
 
     @atexit.register
     def clean_up() -> None: