#!/usr/bin/env python

# Copyright 2011 OpenStack Foundation
# Copyright 2011 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

# NOTE: XenServer still only supports Python 2.4 in its dom0 userspace
# which means the Nova xenapi plugins must use only Python 2.4 features

# TODO(sfinucan): Resolve all 'noqa' items once the above is no longer true

#
# XenAPI plugin for host operations
#

try:
    import json
except ImportError:
    import simplejson as json
import logging
import re
import sys
import time

import six
import utils

import dom0_pluginlib as pluginlib
import XenAPI
import XenAPIPlugin

try:
    import xmlrpclib
except ImportError:
    import six.moves.xmlrpc_client as xmlrpclib


# Configure plugin logging; "xenhost" is the name handed to the pluginlib
# helper.
pluginlib.configure_logging("xenhost")


# Matches lines shaped like "<name> (<text>) : <value>" and captures the name
# and the value; parse_response() applies it line by line.
host_data_pattern = re.compile(r"\s*(\S+) \([^\)]+\) *: ?(.*)")
# JSON file read by _get_config_dict() and rewritten by _write_config_dict().
config_file_path = "/usr/etc/xenhost.conf"
# Retry budget used by _resume_compute(): how many `xe vm-start` attempts are
# made, and how many seconds to sleep after each failed attempt.
DEFAULT_TRIES = 23
DEFAULT_SLEEP = 10


def jsonify(fnc):
    """Decorator that JSON-encodes the return value of ``fnc``.

    The wrapped callable accepts the same positional and keyword arguments
    as ``fnc`` and returns ``json.dumps(fnc(...))``. ``functools.wraps`` is
    applied so the wrapper keeps the original name and docstring, which
    keeps log output and introspection meaningful.
    """
    # Imported locally so the block carries its own dependency.
    import functools

    @functools.wraps(fnc)
    def wrapper(*args, **kwargs):
        return json.dumps(fnc(*args, **kwargs))
    return wrapper


class TimeoutError(Exception):
    """Timeout exception type declared by this plugin.

    NOTE(review): nothing in the code visible here raises it.
    """


def _run_command(cmd, cmd_input=None):
    """Run ``cmd`` through utils.run_command and return its result.

    A utils.SubprocessException is translated into pluginlib.PluginError
    carrying the exception's ``err`` attribute.
    """
    try:
        output = utils.run_command(cmd, cmd_input=cmd_input)
    except utils.SubprocessException as exc:  # noqa
        raise pluginlib.PluginError(exc.err)
    return output


def _resume_compute(session, compute_ref, compute_uuid):
    """Start the compute VM again on the slave host after a pool join.

    Must run whether the join succeeded or failed. The VM is first started
    through ``session``; if that raises XenAPI.Failure, ``xe vm-start`` is
    retried up to DEFAULT_TRIES times, sleeping DEFAULT_SLEEP seconds after
    each failure, and PluginError is raised once the attempts run out.
    """
    try:
        # The session is still usable when the join operation has failed.
        session.xenapi.VM.start(compute_ref, False, True)
    except XenAPI.Failure:
        # An unusable session (e.g. xapi restarted) means the pool join
        # went through; keep trying via the CLI until xapi is back.
        uuid_arg = "uuid=%s" % compute_uuid
        attempt = 0
        while attempt < DEFAULT_TRIES:
            try:
                _run_command(["xe", "vm-start", uuid_arg])
                return
            except pluginlib.PluginError:
                waited = attempt * DEFAULT_SLEEP
                logging.exception('Waited %d seconds for the slave to '
                                  'become available.' % waited)
                time.sleep(DEFAULT_SLEEP)
            attempt += 1
        reported = DEFAULT_SLEEP * (DEFAULT_TRIES + 1)
        raise pluginlib.PluginError('Unrecoverable error: the host has '
                                    'not come back for more than %d seconds'
                                    % reported)


@jsonify
def set_host_enabled(self, arg_dict):
    """Enable or disable this host for new instances and report the result.

    ``arg_dict`` must carry 'enabled' ("true" or "false") and 'host_uuid'.
    The host keeps operating normally otherwise. Returns (JSON-encoded by
    the decorator) ``{"status": "enabled"}`` or ``{"status": "disabled"}``
    based on the host's 'enabled' parameter read back afterwards.

    Raises pluginlib.PluginError when 'enabled' is missing or not one of
    the two accepted strings, or when the xe command prints any output.
    """
    enabled = arg_dict.get("enabled")
    if enabled is None:
        raise pluginlib.PluginError(
            "Missing 'enabled' argument to set_host_enabled")

    uuid_arg = "uuid=%s" % arg_dict['host_uuid']
    if enabled == "true":
        subcommand = "host-enable"
    elif enabled == "false":
        subcommand = "host-disable"
    else:
        raise pluginlib.PluginError("Illegal enabled status: %s" % enabled)

    # xe is expected to print nothing; any output is treated as an error.
    output = _run_command(["xe", subcommand, uuid_arg])
    if output:
        raise pluginlib.PluginError(output)

    # Read the flag back so the caller sees the state that actually applies.
    current = _run_command(
        ["xe", "host-param-get", uuid_arg, "param-name=enabled"])
    if current == "true":
        return {"status": "enabled"}
    return {"status": "disabled"}


def _write_config_dict(dct):
    """Overwrite the plugin config file with ``dct`` serialized as JSON.

    The file at ``config_file_path`` is opened in write mode (truncating any
    previous content). The ``with`` block guarantees the handle is closed
    even when serialization fails part-way through.
    """
    with open(config_file_path, 'w') as conf_file:
        json.dump(dct, conf_file)


def _get_config_dict():
    """Return the key/values stored in the plugin config file as a dict.

    If the file cannot be opened or read (IOError, typically because it
    does not exist yet), an empty config is written out and an empty dict
    is returned. Content that is not valid JSON is not handled here; the
    decoder's exception propagates to the caller.
    """
    try:
        # ``with`` closes the handle even if json.load raises.
        with open(config_file_path, 'r') as conf_file:
            config_dct = json.load(conf_file)
    except IOError:
        # File is missing or unreadable: start from an empty config and
        # create the file so later reads succeed.
        config_dct = {}
        _write_config_dict(config_dct)
    return config_dct


@jsonify
def get_config(self, arg_dict):
    """Look up one key in the stored config.

    ``arg_dict["params"]`` is either a JSON document or an already-decoded
    mapping containing "key". Returns the stored value, or the string
    "None" when the key has no value; the result is JSON-encoded by the
    decorator.
    """
    stored = _get_config_dict()
    raw_params = arg_dict["params"]
    try:
        lookup = json.loads(raw_params)
    except Exception:
        # Not a JSON string; use the params object as it was passed.
        lookup = raw_params
    value = stored.get(lookup["key"])
    if value is not None:
        return value
    # A missing key is reported to the caller as the literal string "None".
    return "None"


@jsonify
def set_config(self, arg_dict):
    """Store, replace or delete one key in the config file.

    ``arg_dict["params"]`` is either a JSON document or an already-decoded
    mapping containing "key" and "value". A value of None removes the key
    (if present); any other value overwrites what was stored. The updated
    config is written back to disk.
    """
    stored = _get_config_dict()
    raw_params = arg_dict["params"]
    try:
        request = json.loads(raw_params)
    except Exception:
        # Not a JSON string; use the params object as it was passed.
        request = raw_params
    key = request["key"]
    val = request["value"]
    if val is not None:
        stored[key] = val
    else:
        # None means "delete"; a key that is absent is not an error.
        stored.pop(key, None)
    _write_config_dict(stored)


def iptables_config(session, args):
    """Run an iptables/ip6tables save or restore command.

    ``args['cmd_args']`` is a JSON-encoded argument list whose first element
    must be one of iptables-save, iptables-restore, ip6tables-save or
    ip6tables-restore. ``args['process_input']`` (optional) is fed to the
    command's stdin and is meant to be used only with the restore commands.

    Returns a JSON string ``{"out": <command output>, "err": ""}``.
    Raises pluginlib.PluginError for an empty or non-whitelisted command.
    """
    logging.debug("iptables_config:enter")
    logging.debug("iptables_config: args=%s", args)
    cmd_args = pluginlib.exists(args, 'cmd_args')
    logging.debug("iptables_config: cmd_args=%s", cmd_args)
    process_input = pluginlib.optional(args, 'process_input')
    logging.debug("iptables_config: process_input=%s", process_input)
    # Build a real list: on Python 3 map() yields a one-shot iterator that
    # can be neither indexed nor re-read after its length has been taken.
    cmd = [str(part) for part in json.loads(cmd_args)]

    # Only these four executables may be run through this entry point.
    allowed = ('iptables-save',
               'iptables-restore',
               'ip6tables-save',
               'ip6tables-restore')
    if not cmd or cmd[0] not in allowed:
        # Don't do anything and return an error
        raise pluginlib.PluginError("Invalid iptables command")

    result = _run_command(cmd, process_input)
    ret_str = json.dumps(dict(out=result, err=''))
    logging.debug("iptables_config:exit")
    return ret_str


def _ovs_add_patch_port(args):
    """Recreate a port on a bridge as an OVS patch port.

    Reads 'bridge_name', 'port_name' and 'peer_port_name' from ``args`` and
    runs a single ovs-vsctl call that deletes the port if it exists, adds it
    to the bridge and sets its interface type to patch with the given peer.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    port = pluginlib.exists(args, 'port_name')
    peer = pluginlib.exists(args, 'peer_port_name')
    return _run_command([
        'ovs-vsctl',
        '--', '--if-exists', 'del-port', port,
        '--', 'add-port', bridge, port,
        '--', 'set', 'interface', port,
        'type=patch', 'options:peer=%s' % peer,
    ])


def _ovs_del_port(args):
    """Delete a port from an OVS bridge, ignoring a port that is absent.

    Reads 'bridge_name' and 'port_name' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    port = pluginlib.exists(args, 'port_name')
    return _run_command(
        ['ovs-vsctl', '--', '--if-exists', 'del-port', bridge, port])


def _ovs_del_br(args):
    """Delete an OVS bridge, ignoring a bridge that is absent.

    Reads 'bridge_name' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    return _run_command(
        ['ovs-vsctl', '--', '--if-exists', 'del-br', bridge])


def _ovs_set_if_external_id(args):
    """Set one external-ids entry on an OVS interface.

    Reads 'interface', 'extneral_id' and 'value' from ``args``. The
    misspelled 'extneral_id' key is the name callers send, so it is kept
    exactly as is.
    """
    iface = pluginlib.exists(args, 'interface')
    id_name = pluginlib.exists(args, 'extneral_id')
    id_value = pluginlib.exists(args, 'value')
    setting = 'external-ids:%s=%s' % (id_name, id_value)
    return _run_command(['ovs-vsctl', 'set', 'Interface', iface, setting])


def _ovs_add_port(args):
    """Recreate a port on an OVS bridge.

    Reads 'bridge_name' and 'port_name' from ``args`` and runs a single
    ovs-vsctl call that deletes the port if it exists, then adds it to the
    bridge.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    port = pluginlib.exists(args, 'port_name')
    return _run_command([
        'ovs-vsctl',
        '--', '--if-exists', 'del-port', port,
        '--', 'add-port', bridge, port,
    ])


def _ovs_create_port(args):
    """Recreate a port on an OVS bridge and tag its interface.

    Reads 'bridge', 'port', 'iface-id', 'mac' and 'status' from ``args``.
    One ovs-vsctl call deletes the port if it exists, adds it to the bridge
    and sets the iface-id, iface-status, attached-mac and xs-vif-uuid
    external_ids (xs-vif-uuid receives the same value as iface-id).
    """
    bridge = pluginlib.exists(args, 'bridge')
    port = pluginlib.exists(args, 'port')
    iface_id = pluginlib.exists(args, 'iface-id')
    mac = pluginlib.exists(args, 'mac')
    status = pluginlib.exists(args, 'status')
    external_ids = [
        'external_ids:iface-id=%s' % iface_id,
        'external_ids:iface-status=%s' % status,
        'external_ids:attached-mac=%s' % mac,
        'external_ids:xs-vif-uuid=%s' % iface_id,
    ]
    return _run_command(
        ['ovs-vsctl',
         '--', '--if-exists', 'del-port', port,
         '--', 'add-port', bridge, port,
         '--', 'set', 'Interface', port] + external_ids)


def _ip_link_get_dev(args):
    """Return the output of ``ip link show`` for args['device_name']."""
    device = pluginlib.exists(args, 'device_name')
    return _run_command(['ip', 'link', 'show', device])


def _ip_link_del_dev(args):
    """Run ``ip link delete`` on args['device_name']."""
    device = pluginlib.exists(args, 'device_name')
    return _run_command(['ip', 'link', 'delete', device])


def _ip_link_add_veth_pair(args):
    """Create a veth pair named args['dev1_name'] / args['dev2_name']."""
    first = pluginlib.exists(args, 'dev1_name')
    second = pluginlib.exists(args, 'dev2_name')
    return _run_command(
        ['ip', 'link', 'add', first, 'type', 'veth', 'peer', 'name', second])


def _ip_link_set_dev(args):
    """Run ``ip link set <device_name> <option>``.

    Reads 'device_name' and 'option' from ``args``; the option is passed
    through to ``ip`` as a single argument.
    """
    device = pluginlib.exists(args, 'device_name')
    setting = pluginlib.exists(args, 'option')
    return _run_command(['ip', 'link', 'set', device, setting])


def _ip_link_set_promisc(args):
    """Run ``ip link set <device_name> promisc <option>``.

    Reads 'device_name' and 'option' from ``args``.
    """
    device = pluginlib.exists(args, 'device_name')
    setting = pluginlib.exists(args, 'option')
    return _run_command(['ip', 'link', 'set', device, 'promisc', setting])


def _brctl_add_br(args):
    """Run ``brctl addbr`` for args['bridge_name']."""
    bridge = pluginlib.exists(args, 'bridge_name')
    return _run_command(['brctl', 'addbr', bridge])


def _brctl_del_br(args):
    """Run ``brctl delbr`` for args['bridge_name']."""
    bridge = pluginlib.exists(args, 'bridge_name')
    return _run_command(['brctl', 'delbr', bridge])


def _brctl_set_fd(args):
    """Run ``brctl setfd <bridge_name> <fd>``.

    Reads 'bridge_name' and 'fd' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    delay = pluginlib.exists(args, 'fd')
    return _run_command(['brctl', 'setfd', bridge, delay])


def _brctl_set_stp(args):
    """Run ``brctl stp <bridge_name> <option>``.

    Reads 'bridge_name' and 'option' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    setting = pluginlib.exists(args, 'option')
    return _run_command(['brctl', 'stp', bridge, setting])


def _brctl_add_if(args):
    """Run ``brctl addif <bridge_name> <interface_name>``.

    Reads 'bridge_name' and 'interface_name' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    iface = pluginlib.exists(args, 'interface_name')
    return _run_command(['brctl', 'addif', bridge, iface])


def _brctl_del_if(args):
    """Run ``brctl delif <bridge_name> <interface_name>``.

    Reads 'bridge_name' and 'interface_name' from ``args``.
    """
    bridge = pluginlib.exists(args, 'bridge_name')
    iface = pluginlib.exists(args, 'interface_name')
    return _run_command(['brctl', 'delif', bridge, iface])


# Whitelist consulted by network_config(): maps the command name sent by the
# caller to the private helper that builds and runs the actual command line.
# A name that is not a key here is rejected.
ALLOWED_NETWORK_CMDS = {
    # allowed cmds to config OVS bridge
    'ovs_add_patch_port': _ovs_add_patch_port,
    'ovs_add_port': _ovs_add_port,
    'ovs_create_port': _ovs_create_port,
    'ovs_del_port': _ovs_del_port,
    'ovs_del_br': _ovs_del_br,
    'ovs_set_if_external_id': _ovs_set_if_external_id,
    # allowed cmds that wrap `ip link`
    'ip_link_add_veth_pair': _ip_link_add_veth_pair,
    'ip_link_del_dev': _ip_link_del_dev,
    'ip_link_get_dev': _ip_link_get_dev,
    'ip_link_set_dev': _ip_link_set_dev,
    'ip_link_set_promisc': _ip_link_set_promisc,
    # allowed cmds that wrap `brctl`
    'brctl_add_br': _brctl_add_br,
    'brctl_add_if': _brctl_add_if,
    'brctl_del_br': _brctl_del_br,
    'brctl_del_if': _brctl_del_if,
    'brctl_set_fd': _brctl_set_fd,
    'brctl_set_stp': _brctl_set_stp
    }


def network_config(session, args):
    """Run one whitelisted network configuration command.

    ``args['cmd']`` names the command and must be a string key of
    ALLOWED_NETWORK_CMDS; ``args['args']`` is handed to the matching helper
    unchanged. Returns whatever that helper returns.

    Raises pluginlib.PluginError when the command is not a string or is not
    in the whitelist.
    """
    cmd = pluginlib.exists(args, 'cmd')
    if not isinstance(cmd, six.string_types):
        msg = "invalid command '%s'" % str(cmd)
        raise pluginlib.PluginError(msg)
    if cmd not in ALLOWED_NETWORK_CMDS:
        msg = "Dom0 execution of '%s' is not permitted" % cmd
        raise pluginlib.PluginError(msg)
    # The command's own arguments are only looked up once the command
    # itself has been accepted.
    cmd_args = pluginlib.exists(args, 'args')
    return ALLOWED_NETWORK_CMDS[cmd](cmd_args)


def _power_action(action, arg_dict):
    """Disable the host, shut down its VMs, then run a power command.

    ``action`` is "reboot", "startup" or "shutdown"; ``arg_dict`` must carry
    'host_uuid'. Each xe command is expected to print nothing, and any
    output is raised as pluginlib.PluginError. Returns
    ``{"power_action": action}``.
    """
    host_uuid = arg_dict['host_uuid']

    def run_quiet(xe_args):
        # Run `xe <xe_args>`; treat any output as a failure report.
        output = _run_command(["xe"] + xe_args)
        if output:
            raise pluginlib.PluginError(output)

    # Host must be disabled first
    run_quiet(["host-disable", "uuid=%s" % host_uuid])
    # All running VMs must be shutdown
    run_quiet(["vm-shutdown", "--multiple", "resident-on=%s" % host_uuid])

    xe_subcommands = {"reboot": "host-reboot",
                      "startup": "host-power-on",
                      "shutdown": "host-shutdown"}
    run_quiet([xe_subcommands[action], "uuid=%s" % host_uuid])
    return {"power_action": action}


@jsonify
def host_reboot(self, arg_dict):
    """Reboots the host.

    Delegates to _power_action with the "reboot" action and the caller's
    arg_dict; the result is JSON-encoded by the decorator.
    """
    return _power_action("reboot", arg_dict)


@jsonify
def host_shutdown(self, arg_dict):
    """Shuts down the host.

    Delegates to _power_action with the "shutdown" action and the caller's
    arg_dict; the result is JSON-encoded by the decorator.
    """
    return _power_action("shutdown", arg_dict)


@jsonify
def host_start(self, arg_dict):
    """Starts the host.

    Currently not feasible, since the host runs on the same machine as
    Xen.

    Delegates to _power_action with the "startup" action and the caller's
    arg_dict; the result is JSON-encoded by the decorator.
    """
    return _power_action("startup", arg_dict)


@jsonify
def host_join(self, arg_dict):
    """Join a remote host into a pool.

    The plugin is invoked on the pool's master. Constraints on the host
    being joined:

    - It must have no VMs running, except nova-compute, which is shut down
      here and restarted once the join attempt is over,
    - It must have no shared storage currently set up,
    - It must have the same license as the master,
    - It must have the same supplemental packs as the master.

    ``arg_dict`` supplies url/user/password for the remote host,
    compute_uuid, master_addr/master_user/master_pass, and an optional
    "force" flag (anything other than "false" selects a forced join).
    """
    session = XenAPI.Session(arg_dict.get("url"))
    session.login_with_password(arg_dict.get("user"),
                                arg_dict.get("password"))
    compute_uuid = arg_dict.get('compute_uuid')
    compute_ref = session.xenapi.VM.get_by_uuid(compute_uuid)
    session.xenapi.VM.clean_shutdown(compute_ref)
    try:
        forced = arg_dict.get("force", "false") != "false"
        if forced:
            join = session.xenapi.pool.join_force
        else:
            join = session.xenapi.pool.join
        join(arg_dict.get("master_addr"),
             arg_dict.get("master_user"),
             arg_dict.get("master_pass"))
    finally:
        # The compute VM is brought back whether or not the join worked.
        _resume_compute(session, compute_ref, compute_uuid)


@jsonify
def host_data(self, arg_dict):
    """Return the current status information for a host.

    Runs ``xe host-param-list`` for ``arg_dict['host_uuid']``, parses the
    output, keeps and converts the fields of interest, then overlays every
    key/value from the plugin config file. The result is JSON-encoded by
    the decorator.
    """
    xe_cmd = ["xe", "host-param-list", "uuid=%s" % arg_dict['host_uuid']]
    raw_params = parse_response(_run_command(xe_cmd))
    # Reduce the raw string pairs to the typed subset that is reported.
    host_info = cleanup(raw_params)
    # Config settings are added last, so they win on a key clash.
    host_info.update(_get_config_dict())
    return host_info


def parse_response(resp):
    """Turn multi-line command output into a dict of name -> value strings.

    Each non-empty line is stripped and matched against host_data_pattern;
    lines that do not match are skipped. When a name occurs more than once,
    the last occurrence wins.
    """
    parsed = {}
    for raw_line in resp.splitlines():
        if not raw_line:
            continue
        found = host_data_pattern.match(raw_line.strip())
        if found is None:
            # Not a valid line; skip it
            continue
        name, value = found.groups()
        parsed[name] = value
    return parsed


@jsonify
def host_uptime(self, arg_dict):
    """Returns the result of the uptime command on the xenhost.

    The output of ``uptime`` is returned under the "uptime" key and
    JSON-encoded by the decorator. ``arg_dict`` is not used.
    """
    return {"uptime": _run_command(['uptime'])}


def cleanup(dct):
    """Translate raw host parameters into the typed subset that is reported.

    ``dct`` maps parameter names to string values. Parameters not listed
    below are discarded. Returns a dict with:

    - "enabled": the raw "enabled" value, or True when it is absent,
    - "host_memory": total / overhead / free / free-computed as ints, with
      None for values that are not integers,
    - "host_uuid", "host_name-label", "host_name-description",
      "host_hostname", "host_ip_address": copied strings (uuid defaults to
      None, the others to ""),
    - "host_other-config": dict parsed from a "k: v; k: v" string,
    - "host_capabilities": list of capability names,
    - "host_cpu_info": dict parsed like other-config, with cpu_count,
      family, model and stepping converted to ints.

    Raises ValueError if an other-config or cpu_info field has no ":".
    """
    def safe_int(val):
        # Integer values will either be string versions of numbers,
        # or empty strings. Convert the latter to nulls.
        try:
            return int(val)
        except ValueError:
            return None

    def parse_kv_list(raw):
        # Parse "k1: v1; k2: v2" into a dict. Only the first ":" of each
        # field separates key from value, so values may contain colons.
        parsed = {}
        if raw:
            for field in raw.split("; "):
                key, value = [part.strip() for part in field.split(":", 1)]
                parsed[key] = value
        return parsed

    out = {}

    # Reported as the raw value rather than a bool: the original code first
    # computed a bool and then overwrote it with this expression, so only
    # this value ever reached callers. It stays first to keep key order.
    out["enabled"] = dct.get("enabled", True)
    out["host_memory"] = {
        "total": safe_int(dct.get("memory-total", "")),
        "overhead": safe_int(dct.get("memory-overhead", "")),
        "free": safe_int(dct.get("memory-free", "")),
        "free-computed": safe_int(dct.get("memory-free-computed", "")),
    }

    out["host_uuid"] = dct.get("uuid", None)
    out["host_name-label"] = dct.get("name-label", "")
    out["host_name-description"] = dct.get("name-description", "")
    out["host_hostname"] = dct.get("hostname", "")
    out["host_ip_address"] = dct.get("address", "")
    out["host_other-config"] = parse_kv_list(dct.get("other-config", ""))

    # Capabilities arrive as "a; b; c"; drop the separators and split on
    # whitespace.
    capabilities = dct.get("capabilities", "")
    out["host_capabilities"] = capabilities.replace(";", "").split()

    cpu_info = parse_kv_list(dct.get("cpu_info", ""))
    for int_key in ("cpu_count", "family", "model", "stepping"):
        if int_key in cpu_info:
            cpu_info[int_key] = safe_int(cpu_info[int_key])
    out["host_cpu_info"] = cpu_info
    return out


def query_gc(session, sr_uuid, vdi_uuid):
    """Report whether the SR cleanup script says it is currently running.

    Runs ``/opt/xensource/sm/cleanup.py -q -u <sr_uuid>`` and returns True
    only when the text after the leading label is "True". ``session`` and
    ``vdi_uuid`` are not used.
    """
    output = _run_command(
        ["/opt/xensource/sm/cleanup.py", "-q", "-u", sr_uuid])
    # Example output: "Currently running: True"; skip the 19-character label.
    label_len = len("Currently running: ")
    answer = output[label_len:].strip()
    return answer == "True"


def get_pci_device_details(session):
    """Returns a string listing the PCI devices with their details.

    This string is obtained by running the command lspci. With the -vmm
    option, it dumps PCI device data in machine-readable form. This verbose
    format displays a sequence of records separated by a blank line. We
    also use the "-n" option to get vendor_id and device_id as numeric
    values and the "-k" option to get the kernel driver used, if any.

    ``session`` is not used.
    """
    return _run_command(["lspci", "-vmmnk"])


def get_pci_type(session, pci_device):
    """Classify a PCI device as type-PCI, type-VF or type-PF.

    pci_device -- the address of the PCI device; the "0000:" domain is
    prepended when the address contains a single colon.

    The device's sysfs directory is listed: an entry containing "physfn"
    means type-VF, otherwise one containing "virtfn" means type-PF,
    otherwise the device is type-PCI. ``session`` is not used.
    """
    # We need to add the domain if it is missing
    if pci_device.count(':') == 1:
        pci_device = "0000:" + pci_device
    sysfs_dir = "/sys/bus/pci/devices/" + pci_device + "/"
    listing = _run_command(["ls", sysfs_dir])

    # Checked in order, so "physfn" takes precedence over "virtfn".
    for marker, pci_type in (("physfn", "type-VF"), ("virtfn", "type-PF")):
        if marker in listing:
            return pci_type
    return "type-PCI"


if __name__ == "__main__":
    # Support both serialized and non-serialized plugin approaches
    # The first command-line argument is an XML-RPC payload; only the method
    # name is needed here to decide which dispatch path applies.
    _, methodname = xmlrpclib.loads(sys.argv[1])
    if methodname in ['query_gc', 'get_pci_device_details', 'get_pci_type',
                      'network_config']:
        # These four take (session, ...) arguments and are registered
        # through the utils helper.
        # NOTE(review): whether register_plugin_calls returns or exits the
        # process is not visible here - confirm before relying on the
        # dispatch call below being skipped.
        utils.register_plugin_calls(query_gc,
                                    get_pci_device_details,
                                    get_pci_type,
                                    network_config)

    # Remaining methods use the (self, arg_dict) signature and are looked up
    # by name in this table.
    XenAPIPlugin.dispatch(
        {"host_data": host_data,
         "set_host_enabled": set_host_enabled,
         "host_shutdown": host_shutdown,
         "host_reboot": host_reboot,
         "host_start": host_start,
         "host_join": host_join,
         "get_config": get_config,
         "set_config": set_config,
         "iptables_config": iptables_config,
         "host_uptime": host_uptime})
