# -*- coding: utf-8 -*-

# This file is part of Cockpit.
#
# Copyright (C) 2013 Red Hat, Inc.
#
# Cockpit is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or
# (at your option) any later version.
#
# Cockpit is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Cockpit; If not, see <http://www.gnu.org/licenses/>.

import contextlib
import errno
import fcntl
import libvirt
import libvirt_qemu
import os
import random
import re
import select
import signal
import string
import socket
import subprocess
import tempfile
import sys
import threading
import time

import xml.etree.ElementTree as etree

# Image name taken from the TEST_OS environment variable, falling back to
# "fedora-26" when that variable is not set.
DEFAULT_IMAGE = os.environ.get("TEST_OS", "fedora-26")

# NOTE(review): presumably the default amount of guest memory in MiB; it is
# not referenced in the part of the file visible here -- confirm against users.
MEMORY_MB = 1024

# Images which are Atomic based
ATOMIC_IMAGES = ["rhel-atomic", "fedora-atomic", "continuous-atomic"]

# Directory containing this file; paths to the identity file, images and
# upload/download targets are built relative to it.
LOCAL_DIR = os.path.dirname(__file__)

# based on http://stackoverflow.com/a/17753573
# we use this to quieten down calls
@contextlib.contextmanager
def stdchannel_redirected(stdchannel, dest_filename):
    """
    A context manager to temporarily redirect stdout or stderr
    e.g.:
    with stdchannel_redirected(sys.stderr, os.devnull):
        noisy_function()

    Arguments:
        stdchannel: a file-like object with flush() and fileno(), such as
                    sys.stdout or sys.stderr
        dest_filename: path of the file that receives the output while the
                       context is active (opened for writing, truncated)

    The redirection happens on the file descriptor level, so output of child
    processes and C libraries is redirected as well.
    """
    # Bind both names up front: if setup fails part-way the cleanup below
    # must not trip over unbound locals and mask the real error.
    saved_fd = None
    dest_file = None
    try:
        stdchannel.flush()
        saved_fd = os.dup(stdchannel.fileno())
        dest_file = open(dest_filename, 'w')
        os.dup2(dest_file.fileno(), stdchannel.fileno())
        yield
    finally:
        try:
            if saved_fd is not None:
                try:
                    # push out anything still buffered so that it lands in
                    # the destination and not in the restored channel
                    stdchannel.flush()
                    os.dup2(saved_fd, stdchannel.fileno())
                finally:
                    # the duplicate is only needed for restoring; closing it
                    # avoids leaking one descriptor per use
                    os.close(saved_fd)
        finally:
            if dest_file is not None:
                dest_file.close()

class Timeout:
    """ Add a timeout to an operation
        Specify machine to ensure that a machine's ssh operations are canceled when the timer expires.

        Used as a context manager. It is based on SIGALRM: entering installs
        a handler and arms the alarm, leaving disarms the alarm and puts the
        previously installed handler back.

        Arguments:
            seconds: number of seconds after which the alarm fires
            error_message: message of the Exception raised on expiry
            machine: optional machine whose ssh process is terminated and
                     which is disconnected before the exception is raised
    """
    def __init__(self, seconds=1, error_message='Timeout', machine=None):
        self.seconds = seconds
        self.error_message = error_message
        self.machine = machine
        # handler that was active before __enter__, restored in __exit__
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: cancel the machine's ssh operations, then raise."""
        if self.machine:
            if self.machine.ssh_process:
                self.machine.ssh_process.terminate()
            self.machine.disconnect()

        raise Exception(self.error_message)

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; such a handler cannot be put back from here
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None

class Failure(Exception):
    """Exception carrying a message in its `msg` attribute.

    str() of the exception yields that message.
    """
    def __init__(self, msg):
        # initialise the base class too, so that `args` is populated
        Exception.__init__(self, msg)
        self.msg = msg

    def __str__(self):
        # __str__ must return a string even if msg is some other object
        # (for example another exception)
        return str(self.msg)

class RepeatableFailure(Failure):
    """A Failure subclass without behavior of its own; it only provides a distinct type."""

class Machine:
    def __init__(self, address=None, image=None, verbose=False, label=None, fetch=True):
        """Record how to reach the machine; no connection is made here.

        Arguments:
            address: address used for ssh/scp connections
            image: image name, "unknown" if not given
            verbose: whether message() prints anything
            label: label of the machine, "UNKNOWN" if not given
            fetch: stored as self.fetch
        """
        self.verbose = verbose
        self.fetch = fetch
        self.address = address
        self.label = label if label else "UNKNOWN"

        self.image = image if image else "unknown"
        self.atomic_image = self.image in ATOMIC_IMAGES

        # Currently all images are x86_64. When that changes we will have
        # an override file for those images that are not
        self.arch = "x86_64"

        # ssh connection parameters and state
        self.vm_username = "root"
        self.ssh_port = 22
        self.ssh_master = None
        self.ssh_process = None

    def disconnect(self):
        self._kill_ssh_master()

    def message(self, *args):
        """Prints args if in verbose mode"""
        if not self.verbose:
            return
        print " ".join(args)

    def start(self, maintain=False, macaddr=None, memory_mb=None, cpus=None, wait_for_ip=True):
        """Overridden by machine classes to start the machine"""
        self.message("Assuming machine is already running")

    def stop(self):
        """Overridden by machine classes to stop the machine"""
        self.message("Not shutting down already running machine")

    # wait until we can execute something on the machine. ie: wait for ssh
    # get_new_address is an optional function to acquire a new ip address for each try
    #   it is expected to raise an exception on failure and return a valid address otherwise
    def wait_execute(self, timeout_sec=120, get_new_address=None):
        """Try to connect to self.address on ssh port"""

        # If connected to machine, kill master connection
        self._kill_ssh_master()

        start_time = time.time()
        while (time.time() - start_time) < timeout_sec:
            if get_new_address:
                try:
                    self.address = get_new_address()
                except:
                    continue
            addrinfo = socket.getaddrinfo(self.address, self.ssh_port, 0, socket.SOCK_STREAM)
            (family, socktype, proto, canonname, sockaddr) = addrinfo[0]
            sock = socket.socket(family, socktype, proto)
            sock.settimeout(1)
            try:
                sock.connect(sockaddr)
                return True
            except:
                pass
            finally:
                sock.close()
            time.sleep(0.5)
        return False

    def wait_user_login(self):
        """Wait until logging in as non-root works.

           Most tests run as the "admin" user, so we make sure that
           user sessions are allowed (and cockit-ws will let "admin"
           in) before declaring a test machine as "booted".
        """
        tries_left = 60
        while (tries_left > 0):
            try:
                self.execute("! test -f /run/nologin")
                return
            except subprocess.CalledProcessError:
                pass
            tries_left = tries_left - 1
            time.sleep(1)
        raise Failure("Timed out waiting for /run/nologin to disappear")

    def wait_boot(self):
        """Wait for a machine to boot"""
        assert False, "Cannot wait for a machine we didn't start"

    def wait_poweroff(self):
        """Overridden by machine classes to wait for a machine to stop"""
        assert False, "Cannot wait for a machine we didn't start"

    def kill(self):
        """Overridden by machine classes to unconditionally kill the running machine"""
        assert False, "Cannot kill a machine we didn't start"

    def shutdown(self):
        """Overridden by machine classes to gracefully shutdown the running machine"""
        assert False, "Cannot shutdown a machine we didn't start"

    def _start_ssh_master(self):
        """Start an ssh ControlMaster process for this machine.

        Any existing master is shut down first. On success self.ssh_master
        holds the control socket path and self.ssh_process the ssh process.

        Raises:
            Failure: if ssh exits before reporting READY (after retrying
                     while its exit code is 255), or if the control socket
                     does not pass the check afterwards.
        """
        self._kill_ssh_master()

        # %h, %p and %r are left in the path for ssh to substitute; our pid is appended
        control = os.path.join(tempfile.gettempdir(), "ssh-%h-%p-%r-" + str(os.getpid()))

        # The remote command prints READY and then blocks reading stdin, so
        # the master stays up until our end of its stdin is closed.
        cmd = [
            "ssh",
            "-p", str(self.ssh_port),
            "-i", self._calc_identity(),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "BatchMode=yes",
            "-M", # ControlMaster, no stdin
            "-o", "ControlPath=" + control,
            "-o", "LogLevel=ERROR",
            "-l", self.vm_username,
            self.address,
            "/bin/bash -c 'echo READY; read a'"
        ]

        # Connection might be refused, so try this 10 times
        tries_left = 10;
        while tries_left > 0:
            tries_left = tries_left - 1
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            stdout_fd = proc.stdout.fileno()
            output = ""
            # Collect output until READY shows up or stdout hits end-of-file;
            # in the latter case stdout_fd is set to -1.
            while stdout_fd > -1 and "READY" not in output:
                ret = select.select([stdout_fd], [], [], 10)
                for fd in ret[0]:
                    if fd == stdout_fd:
                        data = os.read(fd, 1024)
                        if data == "":
                            stdout_fd = -1
                            proc.stdout.close()
                        output += data

            # stdout still open means READY was seen: the master is up
            if stdout_fd > -1:
                break

            # try again if the connection was refused, unless we've used up our tries
            proc.wait()
            if proc.returncode == 255 and tries_left > 0:
                self.message("ssh: connection refused, trying again")
                time.sleep(1)
                continue
            else:
                raise Failure("SSH master process exited with code: {0}".format(proc.returncode))

        self.ssh_master = control
        self.ssh_process = proc

        if not self._check_ssh_master():
            raise Failure("Couldn't launch an SSH master process")

    def _kill_ssh_master(self):
        if self.ssh_master:
            try:
                os.unlink(self.ssh_master)
            except OSError as e:
                if e.errno != errno.ENOENT:
                    raise
            self.ssh_master = None
        if self.ssh_process:
            self.ssh_process.stdin.close()
            with Timeout(seconds=90, error_message="Timeout while waiting for ssh master to shut down"):
                self.ssh_process.wait()
            self.ssh_process = None

    def _check_ssh_master(self):
        if not self.ssh_master:
            return False
        cmd = [
            "ssh",
            "-q",
            "-p", str(self.ssh_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "BatchMode=yes",
            "-S", self.ssh_master,
            "-O", "check",
            "-l", self.vm_username,
            self.address
        ]
        with open(os.devnull, 'w') as devnull:
            code = subprocess.call(cmd, stdin=devnull, stdout=devnull, stderr=devnull)
            if code == 0:
                return True
        return False

    def _ensure_ssh_master(self):
        if not self._check_ssh_master():
            self._start_ssh_master()

    def debug_shell(self):
        """Open an interactive ssh session to the machine.

        ssh is run directly (no master connection) with the test identity
        and inherits our terminal; its exit status is not looked at.
        """
        ssh_cmd = ["ssh", "-p", str(self.ssh_port)]
        ssh_cmd += ["-o", "StrictHostKeyChecking=no"]
        ssh_cmd += ["-o", "UserKnownHostsFile=/dev/null"]
        ssh_cmd += ["-i", self._calc_identity()]
        ssh_cmd += ["-l", self.vm_username]
        ssh_cmd.append(self.address)
        subprocess.call(ssh_cmd)

    def execute(self, command=None, script=None, input=None, environment={}, stdout=None, quiet=False, direct=False):
        """Execute a shell command in the test machine and return its output.

        Either specify @command or @script

        Arguments:
            command: The string to execute by /bin/sh. A list of strings is
                     accepted as well and is appended to the ssh command line.
            script: A multi-line script to execute in /bin/sh
            input: Input to send to the command; not supported with @script
            environment: Additional environment variables; only supported
                         with @script
            stdout: If given, it is used as stdout of the ssh process; the
                    call then returns None and the exit status is ignored
            quiet: Do not print the command and, unless in verbose mode, do
                   not pass on what the command writes to stderr
            direct: Connect directly with the identity file instead of going
                    through the ssh master connection
        Returns:
            The command/script output as a string.
        Raises:
            subprocess.CalledProcessError: if ssh exits with a non-zero
                status; the exception carries the output collected so far.
        """
        assert command or script
        assert self.address

        if not direct:
            self._ensure_ssh_master()

        # default to no translations; can be overridden in environment
        cmd = [
            "env", "-u", "LANGUAGE", "LC_ALL=C",
            "ssh",
            "-p", str(self.ssh_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "BatchMode=yes"
        ]

        if direct:
            cmd += [ "-i", self._calc_identity() ]
        else:
            cmd += [ "-o", "ControlPath=" + self.ssh_master ]

        cmd += [
            "-l", self.vm_username,
            self.address
        ]

        if command:
            assert not environment, "Not yet supported"
            if isinstance(command, basestring):
                cmd += [command]
                if not quiet:
                    self.message("+", command)
            else:
                cmd += command
                if not quiet:
                    self.message("+", *command)
        else:
            assert not input, "input not supported to script"
            cmd += ["sh", "-s"]
            if self.verbose:
                cmd += ["-x"]
            # The script is fed to "sh -s" on stdin, preceded by one
            # assignment and one export line per environment variable.
            input = ""
            for name, value in environment.items():
                input += "%s='%s'\n" % (name, value)
                input += "export %s\n" % name
            input += script
            # placeholder used as the command in a CalledProcessError
            command = "<script>"

        if stdout:
            subprocess.call(cmd, stdout=stdout)
            return

        # Pump all three pipes with select() so that neither side blocks on a full pipe.
        output = ""
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdin_fd = proc.stdin.fileno()
        stdout_fd = proc.stdout.fileno()
        stderr_fd = proc.stderr.fileno()
        rset = [stdout_fd, stderr_fd]
        wset = [stdin_fd]
        # A descriptor leaves its set once it is finished; the loop ends when both sets are empty.
        while len(rset) > 0 or len(wset) > 0:
            ret = select.select(rset, wset, [], 10)
            for fd in ret[0]:
                if fd == stdout_fd:
                    data = os.read(fd, 1024)
                    if data == "":
                        # end of file
                        rset.remove(stdout_fd)
                        proc.stdout.close()
                    else:
                        if self.verbose:
                            sys.stdout.write(data)
                        output += data
                elif fd == stderr_fd:
                    data = os.read(fd, 1024)
                    if data == "":
                        # end of file
                        rset.remove(stderr_fd)
                        proc.stderr.close()
                    elif not quiet or self.verbose:
                        sys.stderr.write(data)
            for fd in ret[1]:
                if fd == stdin_fd:
                    if input:
                        # os.write may accept only part of the data; keep the rest for the next round
                        num = os.write(fd, input)
                        input = input[num:]
                    if not input:
                        # nothing (left) to send: close stdin so the remote side sees end-of-file
                        wset.remove(stdin_fd)
                        proc.stdin.close()
        proc.wait()

        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, command, output=output)
        return output

    def upload(self, sources, dest):
        """Upload files into the test machine

        Arguments:
            sources: the list of paths of the files or directories to
                     upload; each one is joined onto the parent directory of
                     the directory containing this file
            dest: the file path in the machine to upload to
        Raises:
            subprocess.CalledProcessError: if scp fails
        """
        assert sources and dest
        assert self.address

        self._ensure_ssh_master()

        cmd = [
            "scp", "-B",
            "-r", "-p",
            # same port as the master connection, as download() does; without
            # it scp assumes the default port whatever self.ssh_port says
            "-P", str(self.ssh_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ControlPath=" + self.ssh_master,
            "-o", "BatchMode=yes",
          ]
        if not self.verbose:
            cmd += [ "-q" ]

        cmd += [ os.path.join(LOCAL_DIR, "..", path) for path in sources ]

        cmd += [ "%s@[%s]:%s" % (self.vm_username, self.address, dest) ]

        self.message("Uploading", ", ".join(sources))
        self.message(" ".join(cmd))
        subprocess.check_call(cmd)

    def download(self, source, dest):
        """Copy a single file from the test machine with scp.

        Arguments:
            source: path of the file in the machine
            dest: local path; it is joined onto the parent directory of the
                  directory containing this file
        Raises:
            subprocess.CalledProcessError: if scp fails
        """
        assert source and dest
        assert self.address

        self._ensure_ssh_master()
        local_target = os.path.join(LOCAL_DIR, "..", dest)
        remote_source = "%s@[%s]:%s" % (self.vm_username, self.address, source)

        scp_cmd = ["scp", "-B", "-P", str(self.ssh_port)]
        for option in ("StrictHostKeyChecking=no",
                       "UserKnownHostsFile=/dev/null",
                       "ControlPath=" + self.ssh_master,
                       "BatchMode=yes"):
            scp_cmd += ["-o", option]
        if not self.verbose:
            scp_cmd.append("-q")
        scp_cmd += [remote_source, local_target]

        self.message("Downloading", source)
        self.message(" ".join(scp_cmd))
        subprocess.check_call(scp_cmd)

    def download_dir(self, source, dest):
        """Download a directory from the test machine, recursively.

        After the copy all regular files below the downloaded directory get
        mode 0644. A failing scp or chmod run is only reported through
        message(); it does not raise.

        Arguments:
            source: path of the directory in the machine
            dest: local directory; it is joined onto the parent directory of
                  the directory containing this file
        """
        assert source and dest
        assert self.address

        self._ensure_ssh_master()
        dest = os.path.join(LOCAL_DIR, "..", dest)

        cmd = [
            "scp", "-B",
            "-P", str(self.ssh_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ControlPath=" + self.ssh_master,
            "-o", "BatchMode=yes",
            "-r",
          ]
        if not self.verbose:
            cmd += ["-q"]
        cmd += [ "%s@[%s]:%s" % (self.vm_username, self.address, source), dest ]

        self.message("Downloading", source)
        self.message(" ".join(cmd))
        try:
            subprocess.check_call(cmd)
            target = os.path.join(dest, os.path.basename(source))
            subprocess.check_call([ "find", target, "-type", "f", "-exec", "chmod", "0644", "{}", ";" ])
        except (subprocess.CalledProcessError, OSError):
            # Best effort: only failures of the two commands are ignored.
            # Interrupts and timeouts raised while they run must still
            # reach the caller.
            self.message("Error while downloading directory '{0}'".format(source))
    def write(self, dest, content):
        """Write a file into the test machine

        Arguments:
            content: Raw data to write to file
            dest: The file name in the machine to write to
        """
        assert dest
        assert self.address

        cmd = "cat > '%s'" % dest
        self.execute(command=cmd, input=content)

    def spawn(self, shell_cmd, log_id):
        """Spawn a process in the test machine.

        Arguments:
           shell_cmd: The string to execute by /bin/sh.
           log_id: The name of the file, realtive to /var/log on the test
              machine, that will receive stdout and stderr of the command.
        Returns:
            The pid of the /bin/sh process that executes the command.
        """
        return int(self.execute("{ (%s) >/var/log/%s 2>&1 & }; echo $!" % (shell_cmd, log_id)))

    def _calc_identity(self):
        """Return the path of the ssh identity file next to this file.

        The file mode is set to 0600 on every call.
        """
        identity_path = os.path.join(LOCAL_DIR, "identity")
        os.chmod(identity_path, 0o600)
        return identity_path

    def journal_messages(self, syslog_ids, log_level):
        """Return interesting journal messages"""

        # Journald does not always set trusted fields like
        # _SYSTEMD_UNIT or _EXE correctly for the last few messages of
        # a dying process, so we filter by the untrusted but reliable
        # SYSLOG_IDENTIFIER instead

        matches = " ".join(map(lambda id: "SYSLOG_IDENTIFIER=" + id, syslog_ids))

        # Some versions of journalctl terminate unsuccessfully when
        # the output is empty.  We work around this by ignoring the
        # exit status and including error messages from journalctl
        # itself in the returned messages.

        cmd = "journalctl 2>&1 -o cat -p %d %s || true" % (log_level, matches)
        messages = self.execute(cmd).splitlines()
        if len(messages) == 1 and ("Cannot assign requested address" in messages[0]
                                   or "-- No entries --" in messages[0]):
            # No messages
            return [ ]
        else:
            return messages

    def audit_messages(self, type_pref):
        cmd = "journalctl -o cat SYSLOG_IDENTIFIER=kernel 2>&1 | grep 'type=%s.*audit' || true" % (type_pref, )
        messages = self.execute(cmd).splitlines()
        if len(messages) == 1 and "Cannot assign requested address" in messages[0]:
            messages = [ ]
        return messages

    def get_admin_group(self):
        if "debian" in self.image or "ubuntu" in self.image:
            return "sudo"
        else:
            return "wheel"

    def start_cockpit(self, atomic_wait_for_host="localhost", tls=False):
        """Start Cockpit.

        Cockpit is not running when the test virtual machine starts up, to
        allow you to make modifications before it starts.
        """

        if self.atomic_image:
            # HACK: https://bugzilla.redhat.com/show_bug.cgi?id=1228776
            # we want to run:
            # self.execute("atomic run cockpit/ws --no-tls")
            # but atomic doesn't forward the parameter, so we use the resulting command
            # also we need to wait for cockpit to be up and running
            cmd = """#!/bin/sh
            systemctl start docker &&
            """
            if tls:
                cmd += "/usr/bin/docker run -d --privileged --pid=host -v /:/host cockpit/ws /container/atomic-run --local-ssh\n"
            else:
                cmd += "/usr/bin/docker run -d --privileged --pid=host -v /:/host cockpit/ws /container/atomic-run --local-ssh --no-tls\n"
            with Timeout(seconds=90, error_message="Timeout while waiting for cockpit/ws to start"):
                self.execute(script=cmd)
            if atomic_wait_for_host:
                self.wait_for_cockpit_running(atomic_wait_for_host)
        elif tls:
            self.execute(script="""#!/bin/sh
            rm -f /etc/systemd/system/cockpit.service.d/notls.conf &&
            systemctl daemon-reload &&
            systemctl start cockpit.socket
            """)
        else:
            self.execute(script="""#!/bin/sh
            mkdir -p /etc/systemd/system/cockpit.service.d/ &&
            rm -f /etc/systemd/system/cockpit.service.d/notls.conf &&
            systemctl daemon-reload &&
            printf \"[Service]\nExecStartPre=-/bin/sh -c 'echo 0 > /proc/sys/kernel/yama/ptrace_scope'\nExecStart=\n%s --no-tls\n\" `systemctl cat cockpit.service | grep ExecStart=` > /etc/systemd/system/cockpit.service.d/notls.conf &&
            systemctl daemon-reload &&
            systemctl start cockpit.socket
            """)

    def restart_cockpit(self):
        """Restart Cockpit.
        """
        if self.atomic_image:
            with Timeout(seconds=90, error_message="Timeout while waiting for cockpit/ws to restart"):
                self.execute("docker restart `docker ps | grep cockpit/ws | awk '{print $1;}'`")
            self.wait_for_cockpit_running()
        else:
            self.execute("systemctl restart cockpit")

    def stop_cockpit(self):
        """Stop Cockpit.
        """
        if self.atomic_image:
            with Timeout(seconds=60, error_message="Timeout while waiting for cockpit/ws to stop"):
                self.execute("docker kill `docker ps | grep cockpit/ws | awk '{print $1;}'`")
        else:
            self.execute("systemctl stop cockpit.socket")

class VirtEventHandler():
    """ VirtEventHandler registers event handlers (currently: boot, resume, reboot) for libvirt domain instances
        It requires an existing libvirt connection handle, because libvirt requires the domain
        references to be from the same connection instance!
        A thread in the background will run the libvirt event loop. Convenience functions wait_for_reboot,
        wait_for_running and wait_for_stopped exist (with timeouts).
        Access to the datastructures is mutex-protected, with an additional threading.Condition object
        for signaling new events (to avoid polling in the wait* convenience functions).
        It is expected for the caller to register new domains and if possible deregister them for the callbacks.
    """
    def __init__(self, libvirt_connection, verbose = False):
        self.eventLoopThread = None
        self.domain_status = { }
        self.domain_has_rebooted = { }
        self.verbose = verbose
        self.connection = libvirt_connection

        # register event handlers
        self.registered_callbacks = { }

        self.data_lock = threading.RLock()
        self.signal_condition = threading.Condition(self.data_lock)

        # only show debug messages for specific domains, since
        # we might have multiple event handlers at any given time
        self.debug_domains = []

        self.virEventLoopNativeStart()

    def allow_domain_debug_output(self, dom_name):
        with self.data_lock:
            if not dom_name in self.debug_domains:
                self.debug_domains.append(dom_name)

    def forbid_domain_debug_output(self, dom_name):
        with self.data_lock:
            if dom_name in self.debug_domains:
                self.debug_domains.remove(dom_name)

    # 'reboot' and domain lifecycle events are treated differently by libvirt
    # a regular reboot doesn't affect the started/stopped state of the domain
    @staticmethod
    def domain_event_reboot_callback(conn, dom, event_handler):
        key = (dom.name(), dom.ID())
        with event_handler.data_lock:
            if not key in event_handler.domain_has_rebooted or event_handler.domain_has_rebooted[key] != True:
                if event_handler.verbose and dom.name() in event_handler.debug_domains:
                    sys.stderr.write("[%s] REBOOT: Domain '%s' (ID %s)\n" % (str(time.time()), dom.name(), dom.ID()))
                event_handler.domain_has_rebooted[key] = True
                event_handler.signal_condition.notifyAll()

    @staticmethod
    def domain_event_callback(conn, dom, event, detail, event_handler):
        key = (dom.name(), dom.ID())
        value = { 'status': event_handler.dom_event_to_string(event),
                  'detail': event_handler.dom_detail_to_string(event, detail)
                }
        with event_handler.data_lock:
            if not key in event_handler.domain_status or event_handler.domain_status[key] != value:
                event_handler.domain_status[key] = value
                event_handler.signal_condition.notifyAll()
                if event_handler.verbose and dom.name() in event_handler.debug_domains:
                    sys.stderr.write("[%s] EVENT: Domain '%s' (ID %s) %s %s\n" % (
                            str(time.time()),
                            dom.name(),
                            dom.ID(),
                            event_handler.dom_event_to_string(event),
                            event_handler.dom_detail_to_string(event, detail))
                        )

    def register_handlers_for_domain(self):
        """(Re-)register the lifecycle and reboot callbacks on the connection.

        Callbacks registered earlier are dropped first. None is passed as
        the domain argument and this handler as the opaque argument.
        """
        self.deregister_handlers_for_domain()
        wanted = (
            (libvirt.VIR_DOMAIN_EVENT_ID_LIFECYCLE, VirtEventHandler.domain_event_callback),
            (libvirt.VIR_DOMAIN_EVENT_ID_REBOOT, VirtEventHandler.domain_event_reboot_callback),
        )
        self.registered_callbacks = [
                self.connection.domainEventRegisterAny(None, event_id, callback, self)
                for (event_id, callback) in wanted
            ]

    def deregister_handlers_for_domain(self):
        for cb in self.registered_callbacks:
            self.connection.domainEventDeregisterAny(cb)
        self.registered_callbacks = [ ]

    # mapping of event and detail ids to strings from
    # http://libvirt.org/git/?p=libvirt-python.git;a=blob_plain;f=examples/event-test.py;hb=HEAD
    def dom_event_to_string(self, event):
        domEventStrings = ( "Defined",
                            "Undefined",
                            "Started",
                            "Suspended",
                            "Resumed",
                            "Stopped",
                            "Shutdown",
                            "PMSuspended",
                            "Crashed"
                          )
        return domEventStrings[event]

    def dom_detail_to_string(self, event, detail):
        domEventStrings = (
            ( "Added", "Updated" ),
            ( "Removed", ),
            ( "Booted", "Migrated", "Restored", "Snapshot", "Wakeup" ),
            ( "Paused", "Migrated", "IOError", "Watchdog", "Restored", "Snapshot", "API error" ),
            ( "Unpaused", "Migrated", "Snapshot" ),
            ( "Shutdown", "Destroyed", "Crashed", "Migrated", "Saved", "Failed", "Snapshot"),
            ( "Finished", ),
            ( "Memory", "Disk" ),
            ( "Panicked", ),
            )
        return domEventStrings[event][detail]

    def reset_domain_status(self, domain):
        with self.data_lock:
            key = (domain.name(), domain.ID())
            if key in self.domain_status:
                del self.domain_status[key]

    def reset_domain_reboot_status(self, domain):
        with self.data_lock:
            key = (domain.name(), domain.ID())
            if key in self.domain_has_rebooted:
                del self.domain_has_rebooted[key]

    # reboot flag should have probably been reset before this
    # returns whether domain has rebooted
    def wait_for_reboot(self, domain, timeout_sec=120):
        start_time = time.time()
        end_time = start_time + timeout_sec
        key = (domain.name(), domain.ID())
        with self.data_lock:
            if key in self.domain_has_rebooted:
                return True
        remaining_time = end_time - time.time()
        with self.signal_condition:
            while remaining_time > 0:
                # wait for a domain event or our timeout
                self.signal_condition.wait(remaining_time)
                if key in self.domain_has_rebooted:
                    return True
                remaining_time = end_time - time.time()
        return False

    def has_rebooted(self, domain):
        key = (domain.name(), domain.ID())
        with self.data_lock:
            return key in self.domain_has_rebooted

    def domain_is_running(self, domain):
        key = (domain.name(), domain.ID())
        with self.data_lock:
            return key in self.domain_status and self.domain_status[key]['status'] in ["Started", "Resumed"]

    def domain_is_stopped(self, domain):
        key = (domain.name(), domain.ID())
        with self.data_lock:
            return key in self.domain_status and self.domain_status[key] in [{'status': 'Shutdown', 'detail': 'Finished'},
                                                                             {'status': 'Stopped', 'detail': 'Shutdown'}
                                                                            ]

    def wait_for_running(self, domain, timeout_sec=120):
        start_time = time.time()
        end_time = start_time + timeout_sec
        if self.domain_is_running(domain):
            return True
        remaining_time = end_time - time.time()
        with self.signal_condition:
            while remaining_time > 0:
                # wait for a domain event or our timeout
                self.signal_condition.wait(remaining_time)
                if self.domain_is_running(domain):
                    return True
                remaining_time = end_time - time.time()
        return False

    def _domain_is_valid(self, uuid):
        """Look up the domain with the given UUID on our connection.

        Returns what the lookup returns, or False if it raised. stderr is
        sent to /dev/null while the lookup runs.
        """
        try:
            with stdchannel_redirected(sys.stderr, os.devnull):
                found = self.connection.lookupByUUID(uuid)
        except:
            return False
        return found

    def wait_for_stopped(self, domain, timeout_sec=120):
        start_time = time.time()
        end_time = start_time + timeout_sec
        uuid = domain.UUID()
        if self.domain_is_stopped(domain) or not self._domain_is_valid(uuid):
            return True
        remaining_time = end_time - time.time()
        with self.signal_condition:
            while remaining_time > 0:
                # wait for a domain event or our timeout
                self.signal_condition.wait(remaining_time)
                if self.domain_is_stopped(domain) or not self._domain_is_valid(uuid):
                    return True
                remaining_time = end_time - time.time()
        return False

    def virEventLoopNativeStart(self):
        """Start the daemon thread "libvirtEventLoop" that runs the libvirt event loop.

        The thread first registers our event callbacks and then keeps
        running the default libvirt event loop implementation. Any error in
        the loop ends the thread with a Failure.
        """
        def _run_event_loop():
            self.register_handlers_for_domain()
            try:
                # the loop also ends once the libvirt module global is gone
                while libvirt:
                    status = libvirt.virEventRunDefaultImpl()
                    if status < 0:
                        raise Failure("Error in libvirt event handler")
            except:
                raise Failure("error in libvirt event loop")

        loop_thread = threading.Thread(target=_run_event_loop, name="libvirtEventLoop")
        self.eventLoopThread = loop_thread
        loop_thread.setDaemon(True)
        loop_thread.start()

# Template of the libvirt domain definition of a test machine.
# %-style named placeholders: type, name, cpu, arch, memory_in_mib (integer),
# drive, mac and iso. cpu and mac are inserted verbatim between other
# elements -- NOTE(review): presumably XML fragments (or empty strings)
# prepared by the caller; confirm where the template is filled in.
TEST_DOMAIN_XML="""
<domain type='%(type)s'>
  <name>%(name)s</name>
  %(cpu)s
  <os>
    <type arch='%(arch)s'>hvm</type>
    <boot dev='hd'/>
  </os>
  <memory unit='MiB'>%(memory_in_mib)d</memory>
  <currentMemory unit='MiB'>%(memory_in_mib)d</currentMemory>
  <features>
    <acpi/>
  </features>
  <devices>
    <disk type='file' snapshot='external'>
      <driver name='qemu' type='qcow2'/>
      <source file='%(drive)s'/>
      <target dev='vda' bus='virtio'/>
      <serial>ROOT</serial>
    </disk>
    <controller type='scsi' model='virtio-scsi' index='0' id='hot'/>
    <interface type='bridge'>
      <source bridge='cockpit1'/>
      <model type='virtio'/>
      %(mac)s
    </interface>
    <console type='pty'>
      <target type='serial' port='0'/>
    </console>
    <disk type='file' device='cdrom'>
      <source file='%(iso)s'/>
      <target dev='hdb' bus='ide'/>
      <readonly/>
    </disk>
  </devices>
</domain>
"""

# Template of the libvirt definition of an additional raw disk on the scsi bus.
# %-style named placeholders: file, serial, unit (integer) and dev.
TEST_DISK_XML="""
<disk type='file'>
  <driver name='qemu' type='raw'/>
  <source file='%(file)s'/>
  <serial>%(serial)s</serial>
  <address type='drive' controller='0' bus='0' target='2' unit='%(unit)d'/>
  <target dev='%(dev)s' bus='scsi'/>
</disk>
"""

class VirtMachine(Machine):
    """A Machine that runs as a libvirt domain on the qemu:///session hypervisor."""

    # Class-wide defaults for the memory size (in MiB) and the number of vcpus;
    # _start_qemu() falls back to them when it is not given memory_mb / cpus.
    memory_mb = None
    cpus = None

    def __init__(self, image, **args):
        """Prepare a libvirt backed machine for the given image.

        image is either a path (anything containing a "/") or an image name.
        A name is looked up in LOCAL_DIR/../images and, if it does not exist
        there, in LOCAL_DIR/../../bots/images.  The image name handed on to
        Machine.__init__ is the base name without its extension; all other
        keyword arguments are passed to Machine.__init__ unchanged.

        Opens the libvirt connection and creates the event handler, but does
        not start a domain.
        """

        # The path to the image file to load, and parse an image name
        if "/" in image:
            self.image_file = image = os.path.abspath(image)
        else:
            self.image_file = os.path.join(LOCAL_DIR, "..", "images", image)
            if not os.path.lexists(self.image_file):
                self.image_file = os.path.join(LOCAL_DIR, "..", "..", "bots", "images", image)
        image = os.path.splitext(os.path.basename(image))[0]

        Machine.__init__(self, image=image, **args)

        base_dir = os.path.join(LOCAL_DIR, "..", "..")
        self.run_dir = os.path.join(os.environ.get("TEST_DATA", base_dir), "tmp", "run")

        # pass the file name rather than an open file object, so that
        # etree.parse() also closes the file again
        self._network_description = etree.parse(os.path.join(LOCAL_DIR, "network-cockpit.xml"))

        self.test_disk_desc_original = None

        # it is ESSENTIAL to register the default implementation of the event loop before opening a connection
        # otherwise messages may be delayed or lost
        libvirt.virEventRegisterDefaultImpl()
        self.virt_connection = self._libvirt_connection(hypervisor = "qemu:///session")
        self.event_handler = VirtEventHandler(libvirt_connection=self.virt_connection, verbose=self.verbose)

        # network names are currently hardcoded into network-cockpit.xml
        self.network_name = "cockpit1"

        # we can't see the network itself as non-root, create it using vm-prep as root

        # Unique identifiers for hostnet config
        self._hostnet = 8

        self._domain = None

        # init variables needed for running a vm
        self._cleanup()

    def _libvirt_connection(self, hypervisor, read_only = False):
        """Open a connection to hypervisor, retrying a few times on failure.

        Up to five attempts are made; a failed attempt is followed by a one
        second pause and its error is ignored.  If there is still no
        connection afterwards, one last attempt is made whose error, if any,
        propagates to the caller.

        read_only selects libvirt.openReadOnly instead of libvirt.open.
        Returns the connection object.
        """
        open_function = libvirt.openReadOnly if read_only else libvirt.open
        connection = None
        for _ in range(5):
            try:
                connection = open_function(hypervisor)
            except Exception:
                # wait a bit; catching Exception instead of everything keeps
                # KeyboardInterrupt and SystemExit working during the retries
                time.sleep(1)
            if connection:
                break
        if not connection:
            # try again, but if an error occurs, don't catch it
            connection = open_function(hypervisor)
        return connection

    def _resource_lockfile_path(self, resource):
        resources = os.path.join(tempfile.gettempdir(), ".cockpit-test-resources")
        resource = resource.replace("/", "_")
        return os.path.join(resources, resource)

    # The lock is open until the process calling this function exits
    def _lock_resource(self, resource, exclusive=True):
        resources = os.path.join(tempfile.gettempdir(), ".cockpit-test-resources")
        if not os.path.exists(resources):
            os.mkdir(resources, 0755)
        lockpath = self._resource_lockfile_path(resource)
        fd = os.open(lockpath, os.O_WRONLY | os.O_CREAT)
        try:
            flags = fcntl.LOCK_NB
            if exclusive:
                flags |= fcntl.LOCK_EX
            else:
                flags |= fcntl.LOCK_SH
            fcntl.flock(fd, flags)
        except IOError:
            os.close(fd)
            return False
        else:
            return True

    def reserve_macaddr(self):
        pid = os.getpid()
        for seed in range(1, 64 * 1024):
            mac = "9E%02x%02x%02x%04x" % ((pid >> 16) & 0xff, (pid >> 8) & 0xff, pid & 0xff, seed)
            mac = ":".join([mac[i:i+2] for i in range(0, len(mac), 2)])
            if self._lock_resource(mac):
                return mac

        raise Failure("Couldn't find unused mac address for '%s'" % (self.image))

    def _choose_macaddr(self):
        # Check if this has a forced mac address
        for h in self._network_description.find(".//dhcp"):
            image = h.get("{urn:cockpit-project.org:cockpit}image")
            if image == self.image:
                mac = h.get("mac")
                if mac:
                    return mac

        # If not, get a random one
        return self.reserve_macaddr()

    def _start_qemu(self, maintain=False, macaddr=None, wait_for_ip=True, memory_mb=None, cpus=None):
        """Create and start the libvirt domain for this machine.

        maintain:    use the image file directly, so that changes end up in
                     it; otherwise a qcow2 overlay backed by the image is
                     created in run_dir and used instead
        macaddr:     MAC address of the network interface; chosen with
                     _choose_macaddr() if not given
        wait_for_ip: look up the machine's IP address (see _ip_from_mac) and
                     store it in self.address before returning
        memory_mb:   memory in MiB; falls back to VirtMachine.memory_mb and
                     then MEMORY_MB
        cpus:        number of vcpus, only applied if /dev/kvm exists; falls
                     back to VirtMachine.cpus and then 1

        Raises RepeatableFailure if libvirt reports that the domain already
        exists, and Failure if the new domain has no MAC address.
        """
        self._cleanup()

        try:
            os.makedirs(self.run_dir, 0o750)
        except OSError as ex:
            if ex.errno != errno.EEXIST:
                raise

        image_to_use = self.image_file
        if not maintain:
            (fd, self._transient_image) = tempfile.mkstemp(suffix='.qcow2', prefix="", dir=self.run_dir)
            # only the unique file name is needed; qemu-img writes to the path
            # itself, and the descriptor would stay open if not closed here
            os.close(fd)
            subprocess.check_call([ "qemu-img", "create", "-q",
                                    "-f", "qcow2",
                                    "-o", "backing_file=%s" % self.image_file,
                                    self._transient_image ])
            image_to_use = self._transient_image

        if not macaddr:
            macaddr = self._choose_macaddr()

        # if we have a static ip, this implies a singleton instance
        # make sure we don't randomize the libvirt domain name in those cases
        static_domain_name = None
        if macaddr:
            lease = self._static_lease_from_mac(macaddr)
            if lease:
                static_domain_name = self.image + "_" + lease['name']
        mac_desc = ""
        cpu_desc = ""
        domain_type = "qemu"
        if os.path.exists("/dev/kvm"):
            domain_type = "kvm"
            cpu_desc += "<cpu mode='host-passthrough'/>\n"
            cpu_desc += "<vcpu>%(cpus)d</vcpu>\n" % { 'cpus': cpus or VirtMachine.cpus or 1 }
        else:
            sys.stderr.write("WARNING: Starting virtual machine with emulation due to missing KVM\n")
            sys.stderr.write("WARNING: Machine will run about 10-20 times slower\n")
        if macaddr:
            mac_desc = "<mac address='%(mac)s'/>" % {'mac': macaddr}
        if static_domain_name:
            domain_name = static_domain_name
        else:
            rand_extension = '-' + ''.join(random.choice(string.digits + string.ascii_lowercase) for i in range(4))
            domain_name = self.image + rand_extension

        test_domain_desc = TEST_DOMAIN_XML % {
                                        "name": domain_name,
                                        "type": domain_type,
                                        "arch": self.arch,
                                        "cpu": cpu_desc,
                                        "memory_in_mib": memory_mb or VirtMachine.memory_mb or MEMORY_MB,
                                        "drive": image_to_use,
                                        "mac": mac_desc,
                                        "iso": os.path.join(LOCAL_DIR, "cloud-init.iso")
                                      }

        # add the virtual machine
        try:
            # allow debug output for this domain
            self.event_handler.allow_domain_debug_output(domain_name)
            self._domain = self.virt_connection.createXML(test_domain_desc, libvirt.VIR_DOMAIN_START_AUTODESTROY)
        except libvirt.libvirtError as le:
            # remove the debug output
            self.event_handler.forbid_domain_debug_output(domain_name)
            # str() instead of the deprecated 'message' attribute
            error_message = str(le)
            if 'already exists with uuid' in error_message:
                raise RepeatableFailure("libvirt domain already exists: " + error_message)
            else:
                raise

        macs = self._qemu_network_macs()
        if not macs:
            raise Failure("no mac addresses found for created machine")
        self.message("available mac addresses: %s" % (", ".join(macs)))
        self.macaddr = macs[0]
        if wait_for_ip:
            self.address = self._ip_from_mac(self.macaddr)

    # start virsh console
    def qemu_console(self, maintain=True, macaddr=None, memory_mb=None, cpus=None):
        """Start the machine and run "virsh console" on it.

        Blocks until the console program exits.  Afterwards the machine is
        shut down gracefully if maintain is true and killed otherwise;
        libvirt errors during that step are only logged.  The remaining
        arguments are passed to _start_qemu(), which is told not to wait for
        an IP address.  _cleanup() runs in any case.
        """
        try:
            self._start_qemu(maintain=maintain, macaddr=macaddr, wait_for_ip=False, memory_mb=memory_mb, cpus=cpus)
            self.message("started machine %s with address %s" % (self._domain.name(), self.address))
            if maintain:
                self.message("Changes are written to the image file.")
                self.message("WARNING: Uncontrolled shutdown can lead to a corrupted image.")
            else:
                self.message("All changes are discarded, the image file won't be changed.")
            print("console, to quit use Ctrl+], Ctrl+5 (depending on locale)")
            # an argument list needs neither a shell nor any quoting
            subprocess.call(["virsh", "-c", "qemu:///session", "console", str(self._domain.ID())])
            try:
                if maintain:
                    self.shutdown()
                else:
                    self.kill()
            except libvirt.libvirtError as le:
                # the domain may have already been freed (shutdown) while the console was running
                self.message("libvirt error during shutdown: %s" % (le.get_error_message()))
        finally:
            self._cleanup()

    def start(self, maintain=False, macaddr=None, memory_mb=None, cpus=None, wait_for_ip=True):
        """Start the machine, downloading its image first if necessary.

        If self.fetch is set and the image file does not exist,
        "image-download" is run for it; a missing image-download program is
        silently ignored.  The domain is then created with _start_qemu().
        A RepeatableFailure leads to up to ten further attempts with a
        growing pause in between.  After every failed attempt the machine is
        killed; errors other than RepeatableFailure are re-raised right away.
        """
        image_missing = self.fetch and not os.path.exists(self.image_file)
        if image_missing:
            try:
                subprocess.check_call([ "image-download", self.image_file ])
            except OSError as err:
                if err.errno != errno.ENOENT:
                    raise

        retries = 0
        while True:
            try:
                self._start_qemu(maintain=maintain, macaddr=macaddr, wait_for_ip=wait_for_ip,
                                 memory_mb=memory_mb, cpus=cpus)
                if not self._domain.isActive():
                    self._domain.start()
                self._maintaining = maintain
            except RepeatableFailure:
                self.kill()
                if retries >= 10:
                    raise
                retries += 1
                time.sleep(retries)
            except:
                self.kill()
                raise
            else:
                # Normally only one pass
                break

    def _static_lease_from_mac(self, mac):
        for h in self._network_description.find(".//dhcp"):
            netmac = h.get("mac") or ""
            if netmac.lower() == mac.lower():
                return { "ip":   h.get("ip"), "name": h.get("name") }
        return None

    def _diagnose_no_address(self):
        """Run "ip addr" inside the machine through its virsh console.

        An expect script is fed to the "expect" program on stdin.  It opens
        the console of the current domain, logs in as root and sends
        "ip addr".  expect's own output is not captured by this method.
        """
        console_script = """
            spawn virsh -c qemu:///session console $argv
            set timeout 300
            expect "Escape character"
            send "\r"
            expect " login: "
            send "root\r"
            expect "Password: "
            send "foobar\r"
            expect " ~]# "
            send "ip addr\r\n"
            expect " ~]# "
            exit 0
        """
        # "-" makes expect read the script from stdin; the domain id ends up in $argv
        domain_id = str(self._domain.ID())
        expect_proc = subprocess.Popen(["expect", "--", "-", domain_id], stdin=subprocess.PIPE)
        expect_proc.communicate(console_script)

    def _ip_from_mac(self, mac, timeout_sec=300):
        """Return the IP address that /proc/net/arp lists for mac.

        The table is read once per second until an entry with this MAC
        address (compared case-insensitively) on the device named
        self.network_name shows up, or until timeout_sec seconds have passed.
        On timeout some diagnostics are written to stderr,
        _diagnose_no_address() is run and RepeatableFailure is raised.
        """
        # The arp table looks like this:
        #
        # IP address     HW type  Flags  HW address         Mask  Device
        # 10.111.118.45  0x1      0x0    9e:00:03:72:00:04  *     cockpit1
        # ...

        arp_table = ""
        began = time.time()
        while (time.time() - began) < timeout_sec:
            with open("/proc/net/arp", "r") as arp_file:
                arp_table = arp_file.read()
            for entry in arp_table.split("\n"):
                fields = re.split(' +', entry)
                if len(fields) <= 5:
                    continue
                if fields[3].lower() == mac.lower() and fields[5] == self.network_name:
                    return fields[0]
            time.sleep(1)

        known_macs = ", ".join(self._qemu_network_macs())
        sys.stderr.write("{0}: [{1}]\n{2}\n".format(mac, known_macs, arp_table))
        self._diagnose_no_address()
        raise RepeatableFailure("Can't resolve IP of " + mac)

    def reset_reboot_flag(self):
        self.event_handler.reset_domain_reboot_status(self._domain)

    def wait_reboot(self, wait_for_running_timeout=120):
        self.disconnect()
        if not self.event_handler.wait_for_reboot(self._domain):
            raise Failure("system didn't notify us about a reboot")
        # we may have to check for a new dhcp lease, but the old one can be active for a bit
        if not self.wait_execute(timeout_sec=wait_for_running_timeout, get_new_address=lambda: self._ip_from_mac(self.macaddr, timeout_sec=5)):
            raise Failure("system didn't reboot properly")
        self.wait_user_login()

    def wait_boot(self, wait_for_running_timeout = 120, allow_one_reboot=False):
        # we should check for selinux relabeling in progress here
        if not self.event_handler.wait_for_running(self._domain, timeout_sec=wait_for_running_timeout ):
            raise Failure("Machine %s didn't start." % (self.address))

        if not self.address:
            self.address = self._ip_from_mac(self.macaddr)

        # if we allow a reboot, the connection to test for a finished boot may be interrupted
        # by the reboot, causing an exception
        try:
            start_time = time.time()
            connected = False
            while (time.time() - start_time) < wait_for_running_timeout:
                if self.wait_execute(timeout_sec=15, get_new_address=lambda: self._ip_from_mac(self.macaddr, timeout_sec=3)):
                    connected = True
                    break
                if allow_one_reboot and self.event_handler.has_rebooted(self._domain):
                    self.reset_reboot_flag()
                    self.wait_boot(wait_for_running_timeout, allow_one_reboot=False)
            if not connected:
                self._diagnose_no_address()
                raise Failure("Unable to reach machine %s via ssh." % (self.address))
            self.wait_user_login()
        except:
            if allow_one_reboot:
                self.wait_reboot()
                self.reset_reboot_flag()
                self.wait_boot(wait_for_running_timeout, allow_one_reboot=False)
            else:
                raise

    def stop(self, timeout_sec=120):
        if self._maintaining:
            self.shutdown(timeout_sec=timeout_sec)
        else:
            self.kill()

    def _cleanup(self, quick=False):
        self.disconnect()
        try:
            if hasattr(self, '_disks'):
                for index in dict(self._disks):
                    self.rem_disk(index, quick)

            self._disks = { }

            if self._domain:
                # remove the debug output
                self.event_handler.forbid_domain_debug_output(self._domain.name())

            self._domain = None
            self.address = None
            self.macaddr = None
            if hasattr(self, '_transient_image') and self._transient_image and os.path.exists(self._transient_image):
                os.unlink(self._transient_image)
        except:
            (type, value, traceback) = sys.exc_info()
            print >> sys.stderr, "WARNING: Cleanup failed:", str(value)

    def kill(self):
        """Stop the machine immediately, with potential data loss.

        Use shutdown() for a graceful stop.  Errors while destroying the
        domain are ignored; stderr is redirected to os.devnull during that
        call.  Ends with a quick _cleanup().
        """
        self.disconnect()
        domain = self._domain
        if domain:
            try:
                # not graceful
                with stdchannel_redirected(sys.stderr, os.devnull):
                    domain.destroyFlags(libvirt.VIR_DOMAIN_DESTROY_DEFAULT)
            except:
                pass
        self._cleanup(quick=True)

    def wait_poweroff(self, timeout_sec=120):
        # shutdown must have already been triggered
        if self._domain:
            if not self.event_handler.wait_for_stopped(self._domain, timeout_sec=timeout_sec):
                self.message("waiting for machine poweroff timed out")

        self._cleanup(quick=True)

    def shutdown(self, timeout_sec=120):
        # shutdown the system gracefully
        # to stop it immediately, use kill()
        self.disconnect()
        try:
            if self._domain:
                self._domain.shutdown()
            self.wait_poweroff(timeout_sec=timeout_sec)
        finally:
            self._cleanup()

    def add_disk(self, size, serial=None):
        """Create a raw disk image and attach it to the running domain.

        size is handed to "qemu-img create" as is; serial defaults to
        "DISK<index>".  The image file is created in run_dir.

        Returns the index under which the disk is recorded in self._disks.
        Raises Failure if the disk could not be attached; in that case, as
        for any other error while attaching, the image file is removed again.
        """
        index = 1
        while index in self._disks:
            index += 1

        if not serial:
            serial = "DISK%d" % index

        try:
            os.makedirs(self.run_dir, 0o750)
        except OSError as ex:
            if ex.errno != errno.EEXIST:
                raise

        path = os.path.join(self.run_dir, "disk-%s-%d" % (self._domain.name(), index))
        if os.path.exists(path):
            os.unlink(path)

        subprocess.check_call(["qemu-img", "create", "-q", "-f", "raw", path, str(size)])

        dev = 'sd' + string.ascii_lowercase[index]
        disk_desc = TEST_DISK_XML % {
                          'file': path,
                          'serial': serial,
                          'unit': index,
                          'dev': dev
                        }

        try:
            if self._domain.attachDeviceFlags(disk_desc, libvirt.VIR_DOMAIN_AFFECT_LIVE) != 0:
                raise Failure("Unable to add disk to vm")
        except:
            # the disk is not recorded in self._disks yet, so nothing else
            # would ever remove the image file created above
            if os.path.exists(path):
                os.unlink(path)
            raise

        self._disks[index] = {
            "path": path,
            "serial": serial,
            "filename": path,
            "dev": dev
        }

        return index

    def set_disk_io_speed(self, disk_index, speed_in_bytes=0):
        """Run "virsh blkdeviotune" for one of the machine's added disks.

        disk_index is an index into self._disks; speed_in_bytes is passed as
        the value of --total-bytes-sec.  Raises CalledProcessError if virsh
        fails.
        """
        domain_id = str(self._domain.ID())
        device = self._disks[disk_index]["dev"]
        command = ["virsh", "-c", "qemu:///session", "blkdeviotune", "--current", domain_id, device]
        command += ["--total-bytes-sec", str(speed_in_bytes)]
        subprocess.check_call(command)

    def add_disk_path(self, main_index):
        """Attach the image file of an existing disk a second time.

        The new device uses the same file and serial as the disk recorded
        under main_index, but its own unit and device name.  Its entry in
        self._disks has no "path" key, so removing it leaves the image file
        alone.  Returns the index of the new entry; raises Failure if
        attaching fails.
        """
        index = 1
        while index in self._disks:
            index += 1

        main_disk = self._disks[main_index]
        filename = main_disk["path"]
        serial = main_disk["serial"]
        dev = 'sd' + string.ascii_lowercase[index]

        disk_desc = TEST_DISK_XML % dict(file=filename, serial=serial, unit=index, dev=dev)
        result = self._domain.attachDeviceFlags(disk_desc, libvirt.VIR_DOMAIN_AFFECT_LIVE )
        if result != 0:
            raise Failure("Unable to add disk to vm")

        self._disks[index] = dict(filename=filename, serial=serial, dev=dev)
        return index

    def rem_disk(self, index, quick=False):
        assert index in self._disks
        disk = self._disks.pop(index)

        if not quick:
            disk_desc = TEST_DISK_XML % {
                'file': disk["filename"],
                'serial': disk["serial"],
                'unit': index,
                'dev': disk["dev"]
            }

            if self._domain.detachDeviceFlags(disk_desc, libvirt.VIR_DOMAIN_AFFECT_LIVE ) != 0:
                raise Failure("Unable to remove disk from vm")

        # if this isn't just an additional path, clean up
        if "path" in disk and disk["path"] and os.path.exists(disk["path"]):
            os.unlink(disk["path"])

    def _qemu_monitor(self, command):
        """Send an HMP command to the domain's qemu monitor and return the reply.

        The command (prefixed with "& ") and the stripped reply are both
        logged through self.message().
        """
        self.message("& " + command)
        # the same can be done by hand with virsh:
        # virsh -c qemu:///session qemu-monitor-command [domain name/id] --hmp [command]
        hmp_flag = libvirt_qemu.VIR_DOMAIN_QEMU_MONITOR_COMMAND_HMP
        reply = libvirt_qemu.qemuMonitorCommand(self._domain, command, hmp_flag)
        self.message(reply.strip())
        return reply

    def _qemu_network_macs(self):
        macs = []
        for line in self._qemu_monitor("info network").split("\n"):
            x, y, mac = line.partition("macaddr=")
            mac = mac.strip()
            if mac:
                macs.append(mac)
        return macs

    def add_netiface(self, mac=None, vlan=0):
        if not mac:
            mac = self.reserve_macaddr()
        cmd = "device_add e1000,mac=%s" % mac
        if vlan == 0:
            # selinux can prevent the creation of the bridge
            # https://bugzilla.redhat.com/show_bug.cgi?id=1267217
            output = self._qemu_monitor("netdev_add bridge,id=hostnet%d,br=cockpit1" % self._hostnet)
            if "Device 'bridge' could not be initialized" in output:
                raise Failure("Unable to add bridge for virtual machine, possibly related to an selinux-denial")
            cmd += ",netdev=hostnet%d" % self._hostnet
            self._hostnet += 1
        else:
            cmd += ",vlan=%d" % vlan
        self._qemu_monitor(cmd)
        return mac

    def needs_writable_usr(self):
        # On atomic systems, we need a hack to change files in /usr/lib/systemd
        if self.atomic_image:
            self.execute(command="mount -o remount,rw /usr")

    def wait_for_cockpit_running(self, address="localhost", port=9090, seconds=30):
        WAIT_COCKPIT_RUNNING = """#!/bin/sh
until curl -s --connect-timeout 2 --max-time 3 http://%s:%s >/dev/null; do
    sleep 0.5;
done;
""" % (address, port)
        with Timeout(seconds=seconds, error_message="Timeout while waiting for cockpit to start"):
            self.execute(script=WAIT_COCKPIT_RUNNING)
