#!/usr/bin/env python3
import argparse
import codecs
import functools
import logging
import multiprocessing
import os.path
import pdb
import queue
import re
import shlex
import shutil
import signal
import subprocess
import sys
import threading
import tempfile
import textwrap
import time
import traceback
import yaml

import prettytable


# Boot modes understood by gen_esp() and selectable with --boot-mode.
MODE_NVRAM = 'nvram'
MODE_UEFI_SHELL = 'efi-shell'

# Content of a test's "execution-result" file when Runner.run() finished
# without a timeout, subprocess failure or other exception.
TEST_EXECUTED = "executed"

# Prefix passed to tempfile for temporary files and directories.
TMP_PREFIX = "stubbytest."
# startup.nsh written by gen_esp() in MODE_UEFI_SHELL: it runs launch.nsh
# from fs0:\efi\boot and then issues a shutdown reset carrying %lasterror%.
STARTUP_NSH_UEFI_SHELL = r"""setvar SecureBoot
fs0:
cd fs0:\efi\boot
launch.nsh
reset -s "exited with %lasterror%"
"""

# startup.nsh template written by gen_esp() in MODE_NVRAM; formatted with
# efi= (binary to boot) and args= (its arguments).  The script removes and
# re-adds boot entry 00 for {efi}, attaches boot.opts as its options, swaps
# startup.nsh for shutdown.nsh and then resets.
STARTUP_NSH_NVRAM = r"""setvar SecureBoot
fs0:
cd fs0:\efi\boot
bcfg boot rm 00
bcfg boot add 00 {efi} harness-efi
bcfg boot -opt 0x0 boot.opts
#type boot.opts
#bcfg boot dump -v
mv startup.nsh startup.nsh.orig
cp shutdown.nsh startup.nsh
echo -off
echo "rebooting into: {efi} {args}"
reset -c "reboot-into-harness-efi: {efi} {args}"
"""

# shutdown.nsh written by gen_esp() in MODE_NVRAM; the NVRAM startup script
# copies it over startup.nsh before resetting.
SHUTDOWN_NSH_CONTENT = r"""reset -s shutdown-from-shutdown-nsh"""

# Bytes pattern removed from the raw serial log by escape_ansi(): ESC followed
# by a byte in @-_ (or a single byte in \x80-\x9f), then any bytes in 0-?,
# any bytes in space-/, and one final byte in @-~.
ANSI_ESCAPE = re.compile(b'(?:\x1b[@-_]|[\x80-\x9f])[0-?]*[ -/]*[@-~]')


def booted_runtime_cli(_name, tdata, result):
    """Check that the booted kernel command line equals tdata["runtime"].

    result must provide kernel_cmdline().  Returns None on a match and
    raises TestError with both values when they differ.
    """
    found = result.kernel_cmdline()
    wanted = tdata["runtime"]
    if found == wanted:
        return

    report = (
        "kernel command line differed from runtime\n"
        " expected: '" + wanted + "'\n"
        " found   : '" + found + "'\n"
    )
    raise TestError(report)


def booted_expected_cli(_name, tdata, result):
    """Check that the booted kernel command line equals tdata["expected"].

    result must provide kernel_cmdline().  Returns None on a match and
    raises TestError with both values when they differ.
    """
    found = result.kernel_cmdline()
    wanted = tdata["expected"]
    if found == wanted:
        return

    report = (
        "kernel command line differed from expected\n"
        " expected: '" + wanted + "'\n"
        " found   : '" + found + "'\n"
    )
    raise TestError(report)


def warned_cmdline(_name, _tdata, result):
    """Require that result.is_warned() is True; raise TestError otherwise."""
    if result.is_warned() is not True:
        raise TestError("boot did not warn")


def denied_boot(_name, _tdata, result):
    """Require that result.is_rejected() is True; raise TestError otherwise."""
    if result.is_rejected() is not True:
        raise TestError("boot was not denied")


def load_test_data(testfile):
    """Load the yaml test definitions in testfile and validate them.

    Every entry under the top-level 'tests' key must be a dict with a unique,
    non-empty string 'name'.  The fields 'runtime' and 'assert' are required;
    'builtin', 'sb' and 'shim' are filled in from the test name when absent.
    Each listed assertion must be a key of CHECKS.

    Returns the loaded dict, with defaults added to the test entries.
    Raises ValueError if the top level is not a dict, or with the full list of
    problems if any entry fails validation.
    """
    testdata = yaml.safe_load(read_file(testfile))

    if not isinstance(testdata, dict):
        raise ValueError(f"top level entry in test data from {testfile} is not a dict")

    # field name -> (required type, function deriving a default from the test
    # name, or None when the field must be given explicitly).
    fields = {
        "builtin": (str, lambda n: ""),
        "sb": (bool, lambda n: n.startswith("sb-")),
        "shim": (bool, lambda n: n.startswith("sb-shim") or n.startswith("ib-shim")),
        "runtime": (str, None),
        "assert": (list, None),
        "name": (str, None)
    }

    errors = []
    tests = []
    if 'tests' in testdata:
        tests = testdata['tests']
        if not isinstance(tests, list):
            found = tests.__class__.__name__
            errors.append(f"'tests' in {testfile} is type '{found}', expected a list")
            tests = []
    else:
        errors.append(f"No 'tests' key found in {testfile}")

    names = set()
    for tnum, td in enumerate(tests):
        if not isinstance(td, dict):
            errors.append(f"test number {tnum} is not a dict")
            continue
        name = td.get("name")
        if name == "" or name is None:
            errors.append(f"test number {tnum} has no name")
            continue
        if not isinstance(name, str):
            # the name-based defaults below call str methods on the name.
            found = name.__class__.__name__
            errors.append(f"test number {tnum}: name: is type '{found}', expected {str}")
            continue
        if name in names:
            errors.append(f"name {name} appears multiple times (num {tnum})")
        names.add(name)
        for fname, (ftype, namebased) in fields.items():
            if fname not in td:
                # allow defaults based on name for shorter yaml
                if namebased is None:
                    errors.append(f"{name}: missing field {fname}")
                else:
                    td[fname] = namebased(name)
            elif not isinstance(td.get(fname), ftype):
                found = td.get(fname).__class__.__name__
                errors.append(f"{name}: {fname}: is type '{found}', expected {ftype}")

        assertions = td.get("assert", [])
        if not isinstance(assertions, list):
            # wrong type was already reported by the field check above.
            assertions = []
        for num, assertion in enumerate(assertions):
            if not isinstance(assertion, str) or assertion not in CHECKS:
                errors.append(f"{name}: assertion {num} ({assertion}) is invalid")

    if errors:
        raise ValueError(
            ("Errors validating test data in " + testfile + ":\n  " +
              "\n  ".join(errors) + "\n"))

    return testdata


# Assertion name (as used in a test's 'assert' list) -> check function.
# Each function is called as fn(name, tdata, result) and raises TestError
# when the check fails.
CHECKS = {
    "booted-runtime-cli": booted_runtime_cli,
    "booted-expected-cli": booted_expected_cli,
    "warned-cmdline": warned_cmdline,
    "denied-boot": denied_boot,
}


def escape_ansi(data):
    """Return the bytes in data with terminal control sequences removed.

    Two literal substitutions run first, then every match of ANSI_ESCAPE is
    deleted.
    """
    # the sequence ESC[09;01H, seen before 'Press any key', becomes a newline
    cleaned = data.replace(b'\x1b\x5b\x30\x39\x3b\x30\x31\x48', b'\n')
    # drop ctrl-M / carriage return
    cleaned = cleaned.replace(b'\x0d', b'')
    return ANSI_ESCAPE.sub(b'', cleaned)


def first_file(*files):
    """Return the first of the given paths that exists, or None if none do."""
    existing = (candidate for candidate in files if os.path.exists(candidate))
    return next(existing, None)


class SubpResult():
    """Outcome of one subprocess execution.

    Attributes (all may be passed as keyword arguments to the constructor):
      cmd: the command, a list of arguments or a single string.
      data: input associated with the command.
      duration: run time in seconds, or None if not recorded.
      exception: exception captured while running, or None.
      out, err: captured stdout / stderr.
      rc: exit code; -1 until a real exit code is recorded.
    """
    cmd, data, duration, exception = (None, None, None, None)
    out, err = (b'', b'')
    rc = -1

    def __init__(self, **kwargs):
        """Set the given attributes; raise AttributeError for unknown names."""
        bad = ','.join([k for k in kwargs if not hasattr(self, k)])
        if bad:
            raise AttributeError(
                f"{self.__class__.__name__} object has no attributes: {bad}")
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __str__(self):
        """Render a multi-line, human readable report of the result."""
        def indent(name, data):
            if data is None:
                return f"{name}:"

            if hasattr(data, 'decode'):
                data = data.decode('utf-8', errors='ignore')

            return "%s: %s" % (
                name,
                ''.join(["  " + line + "\n" for line in data.splitlines()]))

        if not self.cmd:
            cmd = ''
        elif isinstance(self.cmd, str):
            # a string command is shown as given; quoting it item-by-item
            # would quote each character.
            cmd = self.cmd
        else:
            cmd = shell_quote(self.cmd)

        # duration stays None until the caller records it.
        duration = "n/a" if self.duration is None else "%.2fs" % self.duration

        return '\n'.join([
            "cmd: %s" % cmd,
            f"rc: {self.rc}",
            "duration: %s" % duration,
            "exc: %s" % self.exception,
            indent("stdout", self.out),
            indent("stderr", self.err),
            ""])


class SubpError(Exception):
    """Error carrying the SubpResult of a failed subprocess.

    result is the SubpResult-like object (it must expose .exception);
    desc is an optional free-form description.
    """

    def __init__(self, result, desc=None):
        super().__init__(desc)
        self.result = result
        self.desc = desc

    def __str__(self):
        """Return the optional desc and exception lines followed by str(result)."""
        pieces = []
        if self.desc:
            pieces.append("desc: %s\n" % self.desc)
        if self.result.exception:
            pieces.append("exception: %s\n" % self.result.exception)
        pieces.append(str(self.result))
        return "".join(pieces)


def shell_quote(cmd):
    """Return the items of cmd shell-quoted and joined by single spaces."""
    quoted = map(shlex.quote, cmd)
    return ' '.join(quoted)


def subp(cmd, capture=True, data=None, rcs=(0,), timeout=None, ksignal=signal.SIGTERM,
         cwd=None, stdout=None, stderr=None):
    """Execute a subprocess.

    cmd: string or list to execute.  string is not split.
    capture: capture and return stdout and stderr in the SubpResult.
            When true this overrides the stdout and stderr arguments.
    data: stdin to the process (str is encoded as utf-8).  When None the
            process reads from the null device so a read never blocks.
    rcs: exit code or list of exit codes that should not raise exception.
            None accepts any exit code.
    timeout: maximum time in seconds for process to finish.
    ksignal: the signal to send to the process on timeout. Standard Popen
            interface always sends SIGKILL, which does not allow process to
            clean up.
    cwd: working directory for the process.
    stdout, stderr: passed to Popen when capture is false.

    Return is a SubpResult.  Its rc stays -1 when communicating with the
    process raised (including a timeout); the exception is stored in
    result.exception.
    If rcs is provided and exit code is not in the list of rcs, then
    a SubpError is raised.
    """
    if isinstance(rcs, int):
        rcs = (rcs,)

    # allow user to pass in a string as data
    if hasattr(data, 'encode'):
        data = data.encode('utf-8')

    if capture:
        stdout, stderr = (subprocess.PIPE, subprocess.PIPE)

    # without input, attach the null device so any read gets EOF at once.
    stdin = subprocess.DEVNULL if data is None else subprocess.PIPE

    result = SubpResult(cmd=cmd, rc=-1)
    start = time.time()
    try:
        logging.debug("Executing: %s",
                      cmd if isinstance(cmd, str) else shell_quote(cmd))
        sp = subprocess.Popen(
            cmd, stdout=stdout, stderr=stderr, stdin=stdin, cwd=cwd)

        result.out, result.err, result.exception = do_communicate(
            proc=sp, data=data, timeout=timeout, sig=ksignal, delay=5)

        if result.exception is None:
            result.rc = sp.returncode

        # communicate() yields None for streams that were not piped.
        result.out = result.out or b''
        result.err = result.err or b''
    finally:
        result.duration = time.time() - start

    logging.debug("returned %d took %.2fs", result.rc, result.duration)
    if rcs is None or result.rc in rcs:
        return result
    raise SubpError(result)


def _check_run_args(cliargs):
    """Fill in unset input-file arguments on cliargs and report what is missing.

    For every argument still None, the file <inputs-dir>/<default name> is
    used if it exists, otherwise the first existing system path from the
    argument's search list (if it has one).  signing_pass is resolved from
    <inputs-dir>/signing.password, then from the table of known key
    passwords, and finally defaults to "".

    cliargs is modified in place.  Returns a list of error strings, one per
    argument for which no file was found.
    """
    # (attribute on cliargs, file name inside inputs dir, fallback search)
    checks = (
        ("kernel", "kernel", None),
        ("initrd", "initrd", None),
        ("stubby", "stubby.efi", None),
        ("sbat", "sbat.csv", None),
        ("shim", "shim.efi",
            functools.partial(first_file, "/usr/lib/shim/shimx64.efi")),
        ("signing_key", "signing.key",
            functools.partial(first_file, "/usr/share/ovmf/PkKek-1-snakeoil.key")),
        ("signing_cert", "signing.pem",
            functools.partial(first_file, "/usr/share/ovmf/PkKek-1-snakeoil.pem")),
        ("ovmf_secure_code", "ovmf-secure-code.fd",
            functools.partial(
                first_file,
                "/usr/share/OVMF/OVMF_CODE_4M.snakeoil.fd",
                "/usr/share/OVMF/OVMF_CODE.fd")),
        ("ovmf_secure_vars", "ovmf-secure-vars.fd",
            functools.partial(
                first_file,
                "/usr/share/OVMF/OVMF_VARS_4M.snakeoil.fd",
                "/usr/share/OVMF/OVMF_VARS.secboot.fd")),
        ("ovmf_insecure_code", "ovmf-insecure-code.fd",
            functools.partial(
                first_file,
                "/usr/share/OVMF/OVMF_CODE_4M.fd",
                "/usr/share/OVMF/OVMF_CODE.fd")),
        ("ovmf_insecure_vars", "ovmf-insecure-vars.fd",
            functools.partial(
                first_file,
                "/usr/share/OVMF/OVMF_VARS_4M.fd",
                "/usr/share/OVMF/OVMF_VARS.fd")),
    )

    inputs_d = cliargs.inputs_dir
    if inputs_d is not None:
        inputs_d = os.path.normpath(inputs_d)

    def locate(fname, searcher):
        # inputs dir wins over the system-wide search paths.
        if inputs_d is not None:
            candidate = path_join(inputs_d, fname)
            if os.path.exists(candidate):
                return candidate
        if searcher is not None:
            return searcher()
        return None

    errors = []
    for argname, fname, searcher in checks:
        if getattr(cliargs, argname) is not None:
            continue
        located = locate(fname, searcher)
        if located is None:
            errors.append("did not find value for " + fname)
        else:
            setattr(cliargs, argname, located)

    # for these paths, we know passwords.
    known_passwords = {
        "/usr/share/ovmf/PkKek-1-snakeoil.key": "snakeoil",
    }
    # signing_pass is either a file in <idir>/signing.password or password as a string.
    if cliargs.signing_pass is None:
        password = None
        if inputs_d is not None:
            pwfile = path_join(inputs_d, "signing.password")
            if os.path.exists(pwfile):
                password = read_file(pwfile).strip()
        if password is None:
            password = known_passwords.get(cliargs.signing_key)
        if password is None:
            password = ""
        cliargs.signing_pass = password

    return errors


def _add_run_args(s):
    """Register the options of the 'run' subcommand on the argparse parser s.

    The bracketed part of each help text names the file used from the inputs
    directory when the option is not given.
    """
    runargs = (
        (("-I", "--inputs-dir"),
         {"action": "store", "help": "test inputs directory"}),
        (("-M", "--boot-mode",),
         {"action": "store", "default": MODE_NVRAM,
          # reject unknown modes at parse time instead of failing every test.
          "choices": (MODE_UEFI_SHELL, MODE_NVRAM),
          "help": "The boot mode: %s or %s" % (MODE_UEFI_SHELL, MODE_NVRAM)}),
        (("-k", "--kernel"),
         {"action": "store", "help": "linux kernel [<inputs>/kernel]"}),
        (("-i", "--initrd"),
         {"action": "store", "help": "linux initramfs [<inputs>/initrd]"}),
        (("-B", "--sbat"),
         {"action": "store", "help": "sbat [<inputs>/sbat.csv]"}),
        (("-S", "--shim"),
         {"action": "store", "help": "shim.efi [<inputs>/shim.efi]"}),
        (("-s", "--stubby"),
         {"action": "store", "help": "stubby [<inputs>/stubby.efi]"}),
        (("--signing-key",),
         {"action": "store", "help": "signing key [<inputs>/signing.key]"}),
        (("--signing-cert",),
         {"action": "store",
          "help": "signing certificate [<inputs>/signing.pem]"}),
        (("--signing-pass",),
         {"action": "store",
          "help": "password for the signing key [<inputs>/signing.password]"}),
        (("--ovmf-secure-code",),
         {"action": "store", "help": "ovmf-secure-code [<inputs>/ovmf-secure-code.fd]"}),
        (("--ovmf-insecure-code",),
         {"action": "store", "help": "ovmf-insecure-code [<inputs>/ovmf-insecure-code.fd]"}),
        (("--ovmf-secure-vars",),
         {"action": "store",
          "help": (
              "ovmf-vars for secure boot. Must allow execution of code signed "
              "by provided signing-key [<inputs>/ovmf-secure-vars.fd]")}),
        (("--ovmf-insecure-vars",),
         {"action": "store",
          "help": ("ovmf-vars for insecure boot. [<inputs>/ovmf-insecure-vars.fd]")}),
        (("--timeout",),
         {"action": "store", "type": float, "default": 120.0,
          "help": "time in seconds to allow for boot operation to complete"}),
        (("--num-threads",),
         {"action": "store", "type": int, "default": multiprocessing.cpu_count(),
          "help": "number of parallel qemu processes to run"}),
        (("--work-dir",),
         {"action": "store", "default": None,
          "help": "Use provided dir for working directory"}),
    )

    for args, kwargs in runargs:
        s.add_argument(*args, **kwargs)


def decrypt_key(encrypted, secret, output):
    """Run `openssl rsa` reading key file encrypted and writing output.

    secret is fed to openssl on stdin, where `-passin stdin` reads the
    passphrase.  Raises SubpError (from subp) on a non-zero exit.
    """
    openssl_cmd = [
        "openssl", "rsa",
        "-in", encrypted,
        "-out", output,
        "-passin", "stdin",
    ]
    subp(openssl_cmd, data=secret)


def sign_efi(key, cert, unsigned, signed):
    """Run sbsign on unsigned with key and cert, writing the result to signed.

    Raises SubpError (from subp) on a non-zero exit.
    """
    sbsign_cmd = [
        "sbsign",
        "--key=" + key,
        "--cert=" + cert,
        "--output=" + signed,
        unsigned,
    ]
    subp(sbsign_cmd)


def smash(output, stubefi, kernel, initrd, sbat, cmdline):
    """Combine stub, kernel, initrd, sbat and command line into one EFI binary.

    objcopy copies stubefi to output, adding the sections .cmdline (the
    utf-8 encoded cmdline string), .sbat, .linux and .initrd from the given
    files at fixed addresses.  The command line goes through a temporary
    file which is always removed.

    Raises SubpError (from subp) if objcopy exits non-zero.
    """
    cfp, cmdline_f = tempfile.mkstemp(prefix=TMP_PREFIX)

    try:
        # the with-block closes the descriptor even if the write fails, and
        # the finally below then still removes the file.
        with os.fdopen(cfp, "wb") as fp:
            fp.write(cmdline.encode("utf-8"))

        cmd = [
            "objcopy",
            f"--add-section=.cmdline={cmdline_f}",
            "--change-section-vma=.cmdline=0x30000",
            f"--add-section=.sbat={sbat}",
            "--change-section-vma=.sbat=0x50000",
            "--set-section-alignment=.sbat=512",
            f"--add-section=.linux={kernel}",
            "--change-section-vma=.linux=0x1000000",
            f"--add-section=.initrd={initrd}",
            "--change-section-vma=.initrd=0x3000000",
            stubefi, output]

        subp(cmd)
    finally:
        os.unlink(cmdline_f)


def ensure_dirs(dlist):
    """Call ensure_dir() on every path in dlist, in order."""
    for path in dlist:
        ensure_dir(path)


def ensure_dir(mdir):
    """Create directory mdir, including parents, unless it already exists.

    Safe to call concurrently for the same path: creation and the existence
    check happen in one call rather than as separate check-then-create
    steps.  Raises FileExistsError if mdir exists but is not a directory.
    """
    os.makedirs(mdir, exist_ok=True)


def read_file(path):
    """Return the entire content of path, decoded as utf-8 text."""
    with open(path, "r", encoding="utf-8") as handle:
        content = handle.read()
    return content


def write_file(path, content):
    """Write the string content to path as utf-8 text, replacing the file."""
    handle = open(path, "w", encoding="utf-8")
    with handle:
        handle.write(content)


def gen_esp(esp, run_d, mode, kernel, initrd, shim, stubefi, sbat, signing_key, signing_cert,
            cmdline_builtin, runtime_cli):
    """Build the ESP image esp for one test.

    esp: path of the image to create (via `gen-esp create`).
    run_d: directory receiving the intermediate files.
    mode: MODE_NVRAM or MODE_UEFI_SHELL; selects which shell scripts are
          written and how runtime_cli reaches the booted binary.
    kernel, initrd, stubefi, sbat: inputs combined by smash() together with
          cmdline_builtin as the built-in command line.
    shim: path to a shim to include and boot through, or a false value for
          booting kernel.efi directly.
    signing_key, signing_cert: used to sign the combined binary.
    runtime_cli: command line passed at boot time.

    Raises ValueError for an unknown mode and SubpError when one of the
    external tools fails.
    """
    unsigned_kernel = path_join(run_d, "kernel-unsigned.efi")
    signed_kernel = path_join(run_d, "kernel.efi")
    startup_nsh = path_join(run_d, "startup.nsh")

    # entries are <source path>:<name inside the image>
    files = [
        f"{signed_kernel}:kernel.efi",
        f"{signed_kernel}:grubx64.efi",
        f"{startup_nsh}:startup.nsh",
    ]
    if shim:
        files.append(f"{shim}:shim.efi")

    modes = [MODE_NVRAM, MODE_UEFI_SHELL]
    if mode not in modes:
        raise ValueError("Bad mode '%s': expected one of %s" % (mode, ",".join(modes)))
    if mode == MODE_NVRAM:
        shutdown_nsh = path_join(run_d, "shutdown.nsh")
        bootopts = path_join(run_d, "boot.opts")
        files.extend([f"{shutdown_nsh}:shutdown.nsh", f"{bootopts}:boot.opts"])

        kefi = "kernel.efi"
        efi = kefi
        if shim:
            # boot shim and name kernel.efi as the first load option.
            efi = "shim.efi"
            runtime_cli = kefi + " " + runtime_cli

        write_file(startup_nsh, STARTUP_NSH_NVRAM.format(efi=efi, args=runtime_cli))
        write_file(shutdown_nsh, SHUTDOWN_NSH_CONTENT)
        with codecs.open(bootopts, mode="w", encoding="utf-16le") as fp:
            fp.write(runtime_cli)
    elif mode == MODE_UEFI_SHELL:
        launch_nsh = path_join(run_d, "launch.nsh")
        files.extend([f"{launch_nsh}:launch.nsh"])
        write_file(startup_nsh, STARTUP_NSH_UEFI_SHELL)
        # kernel.efi must be launched even when there are no runtime
        # arguments; only the argument part is optional.
        lcontent = "kernel.efi" + (f" {runtime_cli}" if runtime_cli else "")
        if shim:
            lcontent = "shim.efi " + lcontent
        write_file(launch_nsh, lcontent + "\n")

    smash(unsigned_kernel, stubefi, kernel, initrd, sbat, cmdline_builtin)
    sign_efi(signing_key, signing_cert, unsigned_kernel, signed_kernel)

    subp(["gen-esp", "create", esp] + files)


def path_join(*args):
    """Join the given path components with os.path.sep between them.

    This is a plain string join: components are not inspected, so a
    component that starts with the separator does not reset the result.
    """
    separator = os.path.sep
    return separator.join(args)


def get_env_boolean(name, default):
    """Read environment variable name as a boolean.

    Returns default when the variable is unset, True for the value 'true'
    and False for 'false'.  Raises TypeError if default is not a boolean
    and ValueError for any other value of the variable.
    """
    if default not in (True, False):
        raise TypeError(f"Default value for env var {name} was '{default}', not boolean")

    raw = os.environ.get(name)
    if raw is None:
        return default

    recognized = {"true": True, "false": False}
    if raw in recognized:
        return recognized[raw]
    raise ValueError(f"environment var {name} had value '{raw}'. Expected 'true' or 'false'")


def do_communicate(proc, data=None, timeout=None, sig=signal.SIGTERM, delay=5):
    """Run proc.communicate() and stop the process if it exceeds timeout.

    On timeout the process is sent sig and given delay seconds to exit; if
    it is still running and sig was not SIGKILL, it is sent SIGKILL and
    waited for without a time limit.

    Returns (stdout, stderr, exception): exception is the TimeoutExpired of
    the first wait, any other Exception raised by the first communicate(),
    or None.  stdout and stderr are None when nothing was collected.
    """
    out = err = failure = None
    try:
        out, err = proc.communicate(input=data, timeout=timeout)
    except subprocess.TimeoutExpired as timed_out:
        failure = timed_out
        # ask the process to stop, then give it 'delay' seconds to comply.
        proc.send_signal(sig)
        try:
            out, err = proc.communicate(timeout=delay)
        except subprocess.TimeoutExpired:
            if sig != signal.SIGKILL:
                proc.send_signal(signal.SIGKILL)
                out, err = proc.communicate()
    except Exception as other:
        failure = other

    return out, err, failure


class swtpm2:
    """Run a `swtpm socket --tpm2` process in the background.

    start() launches swtpm and returns once its control socket exists; a
    helper thread waits on the process and records the outcome in
    self.result (a SubpResult).  stop() terminates the process and joins
    the thread.
    """
    # seconds start() waits for the control socket to appear
    socket_timeout = 10
    # seconds to wait after the first signal before sending SIGKILL
    sigterm_timeout = 5

    result, sp, runthread, start_time = (None, None, None, None)

    def __init__(self, state_d, socket=None, log=None, timeout=None):
        """state_d: tpm state directory.  socket and log default to files
        named 'socket' and 'log' inside state_d.  timeout: maximum run time
        in seconds of the swtpm process (None for no limit)."""
        self.state_d = state_d
        self.socket = socket if socket else path_join(self.state_d, "socket")
        self.log = log if log else path_join(self.state_d, "log")
        self.timeout = timeout

    def cmd(self):
        """Return the swtpm command line as a list."""
        return ["swtpm", "socket", "--tpmstate=dir=" + self.state_d,
                "--ctrl=type=unixio,path=" + self.socket,
                "--log=level=20,file=" + self.log,
                "--tpm2"]

    def _communicate(self):
        """Thread body: wait for the process and fill in self.result."""
        res = self.result
        # store into 'exception', the attribute SubpResult and SubpError read.
        res.out, res.err, res.exception = do_communicate(
            self.sp, timeout=self.timeout, delay=self.sigterm_timeout)
        res.duration = time.time() - self.start_time
        res.rc = self.sp.returncode
        self.sp = None

    def _send_signal(self, sig):
        """Signal the process if it is still tracked."""
        # snapshot: the run thread sets self.sp to None when the process ends.
        sp = self.sp
        if sp is not None:
            sp.send_signal(sig)

    def start(self):
        """Launch swtpm and wait for its control socket.

        Raises SubpError if swtpm cannot be executed or exits before the
        socket appears, TimeoutError if the socket does not appear within
        socket_timeout seconds (the process is stopped first).
        """
        ensure_dirs([
            os.path.dirname(self.socket),
            os.path.dirname(self.log),
            self.state_d])

        # a stale socket would make the readiness check below pass at once.
        try:
            os.remove(self.socket)
        except FileNotFoundError:
            pass

        self.start_time = time.time()
        self.result = SubpResult(cmd=self.cmd(), rc=-1)
        try:
            self.sp = subprocess.Popen(
                self.result.cmd,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError as e:
            self.result.duration = time.time() - self.start_time
            self.result.exception = e
            raise SubpError(self.result) from e

        self.runthread = threading.Thread(target=self._communicate)
        self.runthread.start()

        # wait for the socket to appear to signal system ready
        while True:
            if os.path.exists(self.socket):
                break
            if not self.runthread.is_alive():
                raise SubpError(self.result)
            if time.time() - self.start_time > self.socket_timeout:
                self.stop()
                logging.warning(
                    "swtpm process did not create socket %s within %ds "
                    "after starting", self.socket, self.socket_timeout)
                raise TimeoutError("Timeout waiting for swtpm to create socket")

            time.sleep(.1)

    def stop(self, sig=signal.SIGTERM, delay=None):
        """Stop the process with sig, escalating to SIGKILL after delay
        seconds (default sigterm_timeout).  No-op if not started."""
        if delay is None:
            delay = self.sigterm_timeout

        if self.runthread is None:
            logging.debug("Called stop, but no runthread")
            return
        if self.runthread.is_alive():
            self._send_signal(sig)
            self.runthread.join(delay)
            if self.runthread.is_alive():
                try:
                    signame = signal.Signals(sig).name
                except ValueError:
                    signame = str(sig)
                logging.debug(
                    "swtpm process did not exit within %s seconds after %s"
                    ", sending SIGKILL", delay, signame)
                self._send_signal(signal.SIGKILL)
                self.runthread.join()

        self.runthread = None

    def __repr__(self):
        sp = self.sp
        return " ".join([
            "swtpm(",
            "pid = %d" % (-1 if sp is None else sp.pid),
            ")",
        ])


class Runner:
    """Prepare the shared inputs once and execute the tests in qemu.

    The constructor copies all inputs into a work directory, prepares the
    signing key, signs the shim and stores a copy of the test definitions in
    the results directory.  run_all() then executes every test, each in its
    own sub-directory of the work and results directories.
    """

    def __init__(self, cliargs, workdir=None):
        """cliargs: parsed arguments of the 'run' subcommand.
        workdir: directory to work in; it is kept after the run.  When None
        a temporary directory is created and removed by cleanup().

        Raises RuntimeError when required inputs cannot be found.
        """
        self.cleanup_workdir = True
        if workdir is None:
            self.workdir = tempfile.mkdtemp(prefix=TMP_PREFIX)
        else:
            self.cleanup_workdir = False
            self.workdir = os.path.abspath(workdir)
            ensure_dir(self.workdir)

        self.verbose = cliargs.verbose
        self.timeout = cliargs.timeout
        self.testinfo = load_test_data(cliargs.testdata)
        errors = _check_run_args(cliargs)
        if len(errors) != 0:
            raise RuntimeError("Errors in cli input:\n  " + "\n  ".join(errors))

        self.results_d = os.path.abspath(cliargs.results)
        self.boot_mode = cliargs.boot_mode
        self.kernel = self._to_workd(cliargs.kernel, "kernel")
        self.initrd = self._to_workd(cliargs.initrd, "initrd")
        self.stubby = self._to_workd(cliargs.stubby, "stubby.efi")
        self.shim = self._to_workd(cliargs.shim, "shim.efi")
        self.sbat = self._to_workd(cliargs.sbat, "sbat.csv")
        self.ovmf_secure_vars = self._to_workd(
                cliargs.ovmf_secure_vars, "ovmf-secure-vars.fd")
        self.ovmf_insecure_vars = self._to_workd(
                cliargs.ovmf_insecure_vars, "ovmf-insecure-vars.fd")
        self.ovmf_secure_code = self._to_workd(
                cliargs.ovmf_secure_code, "ovmf-secure-code.fd")
        self.ovmf_insecure_code = self._to_workd(
                cliargs.ovmf_insecure_code, "ovmf-insecure-code.fd")

        self.signing_key_in = cliargs.signing_key
        self.signing_cert = self._to_workd(cliargs.signing_cert, "signing.pem")
        self.signing_pass = cliargs.signing_pass

        if self.signing_pass is None:
            self.signing_key = self._to_workd(self.signing_key_in, "signing.key")
        else:
            self.signing_key = path_join(self.workdir, "signing.key")
            decrypt_key(self.signing_key_in, self.signing_pass, self.signing_key)

        self.signed_shim = path_join(self.workdir, "shim-signed.efi")
        sign_efi(self.signing_key, self.signing_cert, self.shim, self.signed_shim)

        self.tests = self.testinfo["tests"]

        # keep the definitions next to the results so 'validate' can use them.
        ensure_dir(self.results_d)
        test_contents = read_file(cliargs.testdata)
        write_file(path_join(self.results_d, "tests.yaml"), test_contents)

        # github actions do not support kvm.
        # https://github.com/actions/runner-images/issues/183
        self.kvm = get_env_boolean("KVM", True)
        self.num_threads = min(cliargs.num_threads, len(self.tests))

    def _to_workd(self, fpath, name):
        """Copy fpath into the work directory as name; return the new path."""
        newpath = path_join(self.workdir, name)
        shutil.copyfile(fpath, newpath)
        return newpath

    def run_all(self):
        """Run every test using num_threads worker threads; return when all
        tests have been attempted."""
        jobq = queue.Queue()

        def worker():
            while True:
                testdata = jobq.get()
                try:
                    self.run(testdata)
                except Exception:
                    # run() re-raises after recording the failure.  The
                    # worker must survive it: a dead worker would leave
                    # queued tests unprocessed and jobq.join() waiting.
                    logging.exception(
                        "unhandled error running test %r", testdata.get("name"))
                finally:
                    jobq.task_done()

        # Turn-on the worker thread.
        for _ in range(self.num_threads):
            threading.Thread(target=worker, daemon=True).start()

        for testdata in self.tests:
            jobq.put(testdata)

        jobq.join()

    def run_name(self, name):
        """Run the single test called name.  Raises KeyError if unknown."""
        td = None
        for testdata in self.tests:
            if testdata.get("name") == name:
                td = testdata
                break
        if td is None:
            raise KeyError(f"No test named {name}")

        return self.run(td)

    def run(self, testdata):
        """Execute one test and record how the execution went.

        Writes run.log and execution-result into the test's results
        directory.  execution-result holds TEST_EXECUTED, or a line starting
        with TIMEOUT:, SUBPROCESS: or EXCEPTION:.  Subprocess failures other
        than timeouts and unexpected exceptions are re-raised.
        """
        name = testdata["name"]
        result = TEST_EXECUTED
        run_d = path_join(self.workdir, name)
        ensure_dir(run_d)
        results_d = path_join(self.results_d, name)
        ensure_dir(results_d)

        rlogfp = open(path_join(results_d, "run.log"), "w")
        start = time.time()

        def runlog(msg):
            rlogfp.write("[%7.3f] " % (time.time() - start) + msg + "\n")
            rlogfp.flush()

        qemu_log = path_join(results_d, "qemu.log")
        serial_log = path_join(results_d, "serial.log")
        print(f"Starting {name} in {run_d}\n qemu_log: {qemu_log}\n serial: {serial_log}.raw")

        try:
            self._run(testdata, run_d, results_d, qemu_log, serial_log, runlog)
        except SubpError as e:
            if isinstance(e.result.exception, subprocess.TimeoutExpired):
                result = "TIMEOUT: " + str(e)
            else:
                result = "SUBPROCESS: " + str(e)
                raise
        except Exception as e:
            result = "EXCEPTION: " + str(e)
            # re-raise the unexpected exception
            raise
        finally:
            rlogfp.close()
            write_file(path_join(results_d, "execution-result"), result + "\n")
            res = result.split(":")[0]
            print("Finished %s [%.2fs]: %s" % (name, time.time() - start, res))
            if self.verbose > 0:
                print("\n --- TEST [%s] serial.log ---\n" % name)
                # the log is absent when the test failed before qemu wrote
                # any serial output; do not mask the real error here.
                if os.path.exists(serial_log):
                    with open(serial_log, encoding="utf-8", errors="replace") as fh:
                        print("%s\n" % fh.read())
                else:
                    print("(no serial log)\n")
                print(" --- TEST [%s] end log ---\n" % name)

    def _run(self, testdata, run_d, results_d, qemu_log, serial_log, log):
        """Build the ESP for testdata, boot it in qemu and collect the logs.

        run_d / results_d: per-test work and results directories.
        qemu_log: file receiving qemu's stdout and stderr.
        serial_log: cleaned serial output; the raw data is in serial_log.raw.
        log: callable taking a message for the run log.

        Raises SubpError if qemu exits non-zero or times out.
        """
        esp = path_join(run_d, "esp.raw")

        log("generating esp")
        gen_esp(
            esp, run_d=run_d, mode=self.boot_mode,
            kernel=self.kernel, initrd=self.initrd,
            shim=self.signed_shim if testdata["shim"] else None,
            stubefi=self.stubby, sbat=self.sbat,
            signing_key=self.signing_key, signing_cert=self.signing_cert,
            cmdline_builtin=testdata["builtin"], runtime_cli=testdata["runtime"])

        ocode_src = self.ovmf_secure_code
        ovars_src = self.ovmf_secure_vars
        ovars = path_join(run_d, "ovmf-vars.fd")
        if not testdata["sb"]:
            ocode_src = self.ovmf_insecure_code
            ovars_src = self.ovmf_insecure_vars

        shutil.copyfile(ovars_src, ovars)
        # qemu runs with cwd=run_d, so the command uses paths relative to it;
        # the firmware code lives one level up in the work directory.
        rel_ovars = os.path.basename(ovars)
        rel_ocode_src = path_join("..", os.path.basename(ocode_src))
        rel_esp = os.path.basename(esp)

        tpmd = "./tpm"
        cmd_base = [
            "qemu-system-x86_64",
            "-M", "q35,smm=on" + (",accel=kvm" if self.kvm else ""),
            "-m", "256",
            # need to disable s3 for ubuntu ovmf older than 22.04.
            "-global", "ICH9-LPC.disable_s3=1",
            "-global", "driver=cfi.pflash01,property=secure,value=on",
            "-nic", "none",
            "-drive", f"if=pflash,format=raw,file={rel_ocode_src},readonly=on",
            # snapshot=on {rel_ovars} so debug with 'boot' will take the full path
            # rather than shortcutting out the setting of nvram in MODE_NVRAM
            "-drive", f"if=pflash,format=raw,file={rel_ovars},snapshot=on",
            "-drive", f"file={rel_esp},id=disk00,if=none,format=raw,index=0,snapshot=on",
            "-device", "virtio-blk,drive=disk00,serial=esp-image",
            "-chardev", "socket,id=chrtpm,path=" + path_join(tpmd, "socket"),
            "-tpmdev", "emulator,id=tpm0,chardev=chrtpm",
            "-device", "tpm-tis,tpmdev=tpm0"]

        cmd = cmd_base + [
             "-serial", f"file:{serial_log}.raw",
             "-vnc", "none",
        ]

        # helper script left in run_d to reproduce the boot by hand.
        bootscript = textwrap.dedent("""\
            #!/bin/sh
            mkdir -p {tpmd}
            rm -f {tpmd}/socket
            {tpmcmd} &
            while ! [ -e "{tpmd}/socket" ]; do sleep .1; done
            {qemucmd}
            r=$?
            wait
            exit $r
            """)

        write_file(path_join(run_d, "boot"),
            bootscript.format(
                tpmd=tpmd,
                tpmcmd=shell_quote(swtpm2(tpmd).cmd()),
                qemucmd=shell_quote(cmd_base + ["-serial", "mon:stdio"])))

        tpm = swtpm2(
            state_d=path_join(run_d, "tpm"),
            log=path_join(results_d, "tpm.log"),
            timeout=self.timeout)

        log("starting tpm")
        tpm.start()

        try:
            log("tpm ready, starting qemu")
            with open(qemu_log, "wb") as qfp:
                qfp.write(b"# " + b" ".join([s.encode("utf-8") for s in cmd]) + b"\n")
                qfp.flush()
                ret = subp(cmd, capture=False, rcs=None, cwd=run_d,
                           stderr=subprocess.STDOUT, stdout=qfp, timeout=self.timeout)
                qfp.write(b"\nrc=%d\n" % ret.rc)
            log("qemu exited %d" % ret.rc)

            if os.path.exists(serial_log + ".raw"):
                with open(serial_log + ".raw", "rb") as rfp:
                    with open(serial_log, "wb") as wfp:
                        rawlog = rfp.read()
                        wfp.write(escape_ansi(rawlog))
                        if not rawlog.endswith(b"\n"):
                            wfp.write(b"\n")
        finally:
            # never leave swtpm (and its waiting thread) behind, even when
            # launching qemu raised.
            tpm.stop()

        log("tpm returned %d" % tpm.result.rc)
        if tpm.result.rc != 0:
            logging.warning("TPM errored: %s", tpm.result)
            log("%s" % tpm.result)

        log("finished")
        if ret.rc != 0:
            raise SubpError(ret)

    def cleanup(self):
        """Remove the work directory if this Runner created it."""
        if self.cleanup_workdir:
            shutil.rmtree(self.workdir)


class TestResult:
    """Read-only view of the files a test left in its results directory.

    Attributes:
      rdir: the results directory.
      serial: content of serial.log.
      execution_result: content of execution-result, trailing whitespace removed.
    """
    # marks "not computed yet"; None is avoided as it could be a real value.
    _unset = object()
    re_cmdline = re.compile(r"KERNEL COMMAND LINE: (?P<value>[^\n]*)\n")
    re_rejected = re.compile(r"Custom kernel command line rejected\n")
    re_warned = re.compile(r"Custom kernel would be rejected in secure mode\n")

    # per-instance memo slots.  Results are cached on the instance rather
    # than with functools.lru_cache, whose module-level cache would keep
    # every TestResult alive.
    _kcmdline = _unset
    _is_rejected = _unset
    _is_warned = _unset

    def __init__(self, rdir):
        """Load serial.log and execution-result from rdir.

        Raises the error from open() (e.g. FileNotFoundError) if either
        file cannot be read.
        """
        self.rdir = rdir
        self.serial = read_file(path_join(rdir, "serial.log"))
        self.execution_result = read_file(path_join(rdir, "execution-result")).rstrip()

    def serial_raw(self):
        """Return the bytes of serial.log.raw."""
        with open(path_join(self.rdir, "serial.log.raw"), "rb") as fp:
            return fp.read()

    def kernel_cmdline(self):
        """Return the text after the first 'KERNEL COMMAND LINE: ' up to the
        end of that line in the serial log, or "" if there is none."""
        if self._kcmdline is self._unset:
            m = self.re_cmdline.search(self.serial)
            self._kcmdline = "" if m is None else m.group("value")
        return self._kcmdline

    def is_rejected(self):
        """Return True if the serial log contains the rejection message."""
        if self._is_rejected is self._unset:
            self._is_rejected = self.re_rejected.search(self.serial) is not None
        return self._is_rejected

    def is_warned(self):
        """Return True if the serial log contains the warning message."""
        if self._is_warned is self._unset:
            self._is_warned = self.re_warned.search(self.serial) is not None
        return self._is_warned


class TestError(Exception):
    """Raised by a check function when a test assertion does not hold."""


class Validator:
    """Evaluate the assertions of every test against a results directory.

    The test definitions are read from tests.yaml inside the results
    directory.  validate_all() fills a table that print_results() prints.
    """

    def __init__(self, args):
        """args: parsed command line; args.results is the results directory.
        args.show_config_column is optional and defaults to False."""
        self.args = args
        self.results_d = args.results
        self.testinfo = load_test_data(path_join(self.results_d, "tests.yaml"))
        # not every subcommand defines the option, so do not require it.
        self._show_config = getattr(args, "show_config_column", False)
        self._table = prettytable.PrettyTable()
        columns = ["Name", "Passed", "Errors"]
        if self._show_config:
            columns.append("Config")
        self._table.field_names = columns
        self._table.align = "l"

    def tests(self):
        """Return a dict mapping test name to its definition."""
        return {testdata["name"]: testdata for testdata in self.testinfo["tests"]}

    def _check(self, chkname, tdata, result):
        CHECKS[chkname](tdata["name"], tdata, result)

    @staticmethod
    def _execution_error(execution_result):
        """Return the error type at the start of an execution-result string
        (e.g. 'TIMEOUT'), or 'NORESULT' when there is none."""
        words = execution_result.split()
        etype = words[0].rstrip(":") if words else ""
        return etype or "NORESULT"

    def validate_test(self, tdata):
        """Check one test.

        Returns (passed, errors) where errors maps 'execution' and/or the
        failed assertion names to a description.
        """
        name = tdata["name"]
        try:
            result = TestResult(path_join(self.results_d, name))
        except FileNotFoundError as e:
            # the run died before producing its logs; nothing can be checked.
            return False, {"execution": f"NORESULT ({e})"}

        errors = {}
        if result.execution_result != TEST_EXECUTED:
            errors["execution"] = self._execution_error(result.execution_result)

        for ast in tdata["assert"]:
            try:
                self._check(ast, tdata, result)
            except TestError as e:
                errors[ast] = e

        return len(errors) == 0, errors

    def validate_all(self):
        """Validate every test that has a results directory and add a table
        row for it.  Returns a dict of test name -> errors for failed tests."""
        errors = {}
        for name, tdata in self.tests().items():
            trdir = path_join(self.results_d, name)
            if not os.path.isdir(trdir):
                print(f"no results for test {name}")
                continue
            result, errs = self.validate_test(tdata)
            if errs:
                errors[name] = errs

            row = [name, result, ",".join(errs.keys())]
            if self._show_config:
                row.append(tdata)
            self._table.add_row(row)

        return errors

    def print_results(self):
        """Print the results table."""
        print("== Results ==")
        print(self._table)


def validate_results(cliargs):
    """Validate all results, print the table and any error details.

    Returns 0 when every validated test passed, 1 otherwise.
    """
    vdator = Validator(cliargs)
    failures = vdator.validate_all()
    vdator.print_results()
    if not failures:
        return 0

    print("== Errors ==")
    indent = "  "
    separator = "\n" + indent
    details = (
        (testname, assertname, err)
        for testname, errdict in failures.items()
        for assertname, err in errdict.items()
    )
    for testname, assertname, err in details:
        body = separator.join(str(err).splitlines())
        print("%s/%s\n %s%s\n" % (testname, assertname, indent, body))
    return 1


def main_run(cliargs):
    """Handler of the 'run' subcommand.

    Runs all tests, then validates the results unless --no-validate was
    given.  Returns 0 on success and 1 if validation found errors.
    """
    runner = Runner(cliargs, workdir=cliargs.work_dir)
    # report the effective thread count, which is capped at the test count.
    print("There are %d tests to run in %d threads" %
          (len(runner.tests), runner.num_threads))
    try:
        runner.run_all()
    finally:
        runner.cleanup()

    if getattr(cliargs, "no_validate", False):
        return 0

    errors = validate_results(cliargs)
    if errors:
        return 1
    return 0


def main_validate(args):
    """Handler of the 'validate' subcommand.

    Returns the status from validate_results() (0 for success, 1 when errors
    were found) so that it becomes the exit code.
    """
    return validate_results(args)


def main():
    """Parse the command line and dispatch to the chosen subcommand.

    Returns the handler's return value, or 1 (after printing help) when no
    subcommand was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', '-v', action='count', default=0)

    subcmds = parser.add_subparsers(dest="subcmd")
    s = subcmds.add_parser("run", help="run the tests")
    _add_run_args(s)
    s.add_argument(
        "--no-validate", action="store_true", default=False,
        help="Only run tests and collect output. do not validate.")
    s.add_argument(
        "-r", "--results", action="store", default="./results",
        help="Write results output to dir")
    s.add_argument(
        "-C", "--show-config-column", action="store_true", default=False,
        help="Display the 'Config' column in validate results")
    s.add_argument("testdata", help="The yaml formatted test definition file")
    s.set_defaults(handler=main_run)

    s = subcmds.add_parser("validate", help="validate test results")
    # validation reads this option, so 'validate' has to offer it as well.
    s.add_argument(
        "-C", "--show-config-column", action="store_true", default=False,
        help="Display the 'Config' column in validate results")
    s.add_argument("results", help="Results dir from 'run'")
    s.set_defaults(handler=main_validate)

    args = parser.parse_args()

    # let helper programs that live next to this script be found first.
    mydir = os.path.dirname(os.path.abspath(__file__))
    os.environ["PATH"] = mydir + os.pathsep + os.environ.get("PATH", os.defpath)

    level = (logging.ERROR, logging.INFO, logging.DEBUG)[min(args.verbose, 2)]
    logging.basicConfig(level=level)
    if hasattr(args, 'handler'):
        return args.handler(args)
    parser.print_help()
    return 1


if __name__ == "__main__":
    ret = 0
    try:
        ret = main()
    except Exception:
        traceback.print_exc()
        # Open the post-mortem debugger only when someone can interact with
        # it; in automated runs there is no terminal to drive it from.
        if (sys.stdin is not None and sys.stdin.isatty()
                and sys.stdout is not None and sys.stdout.isatty()):
            pdb.post_mortem()
        ret = 1
    sys.exit(ret)
