about summary refs log tree commit diff
path: root/nixos/lib/test-driver
diff options
context:
space:
mode:
authorDavid Arnold <dar@xoe.solutions>2021-06-06 13:50:02 -0500
committerDavid Arnold <dgx.arnold@gmail.com>2021-08-19 23:55:26 -0500
commitdb614e11d672cf8e3c1268d34e74e0c9981ab5be (patch)
treecf12e553455dfa05225da8eb5ec66ce082ff8222 /nixos/lib/test-driver
parent5edf5b60c3d8f82b5fc5e73e822b6f7460584945 (diff)
downloadnixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar.gz
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar.bz2
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar.lz
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar.xz
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.tar.zst
nixlib-db614e11d672cf8e3c1268d34e74e0c9981ab5be.zip
nixos/tests/test-driver: better control test env symbols
Previous to this commit, the entire test driver environment was shared
with the actual python test environment.

This is a hefty api surface. This commit selectively exposes only those
symbols to the test environment that are actually meant to be used by
tests.
Diffstat (limited to 'nixos/lib/test-driver')
-rwxr-xr-xnixos/lib/test-driver/test-driver.py57
1 files changed, 42 insertions, 15 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index 0372148cb33c..488789e119d0 100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -89,9 +89,7 @@ CHAR_TO_KEY = {
     ")": "shift-0x0B",
 }
 
-# Forward references
-log: "Logger"
-machines: "List[Machine]"
+global log, machines, test_script
 
 
 def eprint(*args: object, **kwargs: Any) -> None:
@@ -103,7 +101,6 @@ def make_command(args: list) -> str:
 
 
 def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
-    global log
     log.log("starting VDE switch for network {}".format(vlan_nr))
     vde_socket = tempfile.mkdtemp(
         prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
@@ -246,6 +243,9 @@ def _perform_ocr_on_screenshot(
 
 
 class Machine:
+    def __repr__(self) -> str:
+        return f"<Machine '{self.name}'>"
+
     def __init__(self, args: Dict[str, Any]) -> None:
         if "name" in args:
             self.name = args["name"]
@@ -910,29 +910,25 @@ class Machine:
 
 
 def create_machine(args: Dict[str, Any]) -> Machine:
-    global log
     args["log"] = log
     return Machine(args)
 
 
 def start_all() -> None:
-    global machines
     with log.nested("starting all VMs"):
         for machine in machines:
             machine.start()
 
 
 def join_all() -> None:
-    global machines
     with log.nested("waiting for all VMs to finish"):
         for machine in machines:
             machine.wait_for_shutdown()
 
 
 def run_tests(interactive: bool = False) -> None:
-    global machines
     if interactive:
-        ptpython.repl.embed(globals(), locals())
+        ptpython.repl.embed(test_symbols(), {})
     else:
         test_script()
         # TODO: Collect coverage data
@@ -942,12 +938,10 @@ def run_tests(interactive: bool = False) -> None:
 
 
 def serial_stdout_on() -> None:
-    global log
     log._print_serial_logs = True
 
 
 def serial_stdout_off() -> None:
-    global log
     log._print_serial_logs = False
 
 
@@ -989,6 +983,37 @@ def subtest(name: str) -> Iterator[None]:
     return False
 
 
+def _test_symbols() -> Dict[str, Any]:
+    general_symbols = dict(
+        start_all=start_all,
+        test_script=globals().get("test_script"),  # same
+        machines=globals().get("machines"),  # without being initialized
+        log=globals().get("log"),  # extracting those symbol keys
+        os=os,
+        create_machine=create_machine,
+        subtest=subtest,
+        run_tests=run_tests,
+        join_all=join_all,
+        serial_stdout_off=serial_stdout_off,
+        serial_stdout_on=serial_stdout_on,
+    )
+    return general_symbols
+
+
+def test_symbols() -> Dict[str, Any]:
+
+    general_symbols = _test_symbols()
+
+    machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
+    print(
+        "additionally exposed symbols:\n    "
+        + ", ".join(map(lambda m: m.name, machines))
+        + ",\n    "
+        + ", ".join(list(general_symbols.keys()))
+    )
+    return {**general_symbols, **machine_symbols}
+
+
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
     arg_parser.add_argument(
@@ -1028,12 +1053,9 @@ if __name__ == "__main__":
     )
 
     args = arg_parser.parse_args()
-    global test_script
     testscript = pathlib.Path(args.testscript).read_text()
 
-    def test_script() -> None:
-        with log.nested("running the VM test script"):
-            exec(testscript, globals())
+    global log, machines, test_script
 
     log = Logger()
 
@@ -1062,6 +1084,11 @@ if __name__ == "__main__":
                 process.terminate()
         log.close()
 
+    def test_script() -> None:
+        with log.nested("running the VM test script"):
+            symbols = test_symbols()  # call eagerly
+            exec(testscript, symbols, None)
+
     interactive = args.interactive or (not bool(testscript))
     tic = time.time()
     run_tests(interactive)