# -*- test-case-name: tx_xmpp.test.test.test_server -*-
#
# Copyright (c) Ralph Meijer.
# See LICENSE for details.

"""
XMPP Server-to-Server protocol.

This module implements several aspects of XMPP server-to-server communications
as described in XMPP Core (RFC 3920). Refer to that document for the meaning
of the used terminology.
"""

from __future__ import division, absolute_import

import binascii
from hashlib import sha256
import hmac

from zope.interface import implementer

from twisted.internet import defer, reactor
from twisted.names.srvconnect import SRVConnector
from twisted.python import log, randbytes
from twisted.words.protocols.jabber import error, ijabber, jid, xmlstream
from twisted.words.xish import domish

from .generic import DeferredXmlStreamFactory, XmlPipe

# XML namespace of the Server Dialback elements (db:result, db:verify) that
# this module sends and observes; bound to the "db" prefix on our streams.
NS_DIALBACK = "jabber:server:dialback"


def generateKey(secret, receivingServer, originatingServer, streamID):
    """
    Generate a dialback key for server-to-server XMPP Streams.

    The dialback key is generated using the algorithm described in
    U{XEP-0185<http://xmpp.org/extensions/xep-0185.html>}: the hex-encoded
    HMAC-SHA256 of C{"<receivingServer> <originatingServer> <streamID>"},
    keyed with the hex-encoded SHA-256 digest of the secret. The used
    terminology for the parameters is described in RFC-3920.

    @param secret: the shared secret known to the Originating Server and
                   Authoritative Server. Text is encoded as UTF-8 before
                   hashing; bytes are used as-is.
    @type secret: L{str} or L{bytes}
    @param receivingServer: the Receiving Server host name.
    @type receivingServer: L{str}
    @param originatingServer: the Originating Server host name.
    @type originatingServer: L{str}
    @param streamID: the Stream ID as generated by the Receiving Server.
    @type streamID: L{str}
    @return: hexadecimal digest of the generated key.
    @rtype: L{str}
    """
    if not isinstance(secret, bytes):
        # Text secrets are the common case; bytes (e.g. the output of
        # binascii.hexlify on Python 3) are hashed directly.
        secret = secret.encode("utf-8")
    hashedSecret = sha256(secret).hexdigest()

    message = " ".join([receivingServer, originatingServer, streamID])
    # UTF-8 equals ASCII for ASCII input, so existing keys are unchanged,
    # while a peer-supplied non-ASCII stream ID no longer raises.
    mac = hmac.HMAC(
        hashedSecret.encode("ascii"), message.encode("utf-8"), digestmod=sha256
    )
    return mac.hexdigest()


def trapStreamError(xs, observer):
    """
    Trap stream errors.

    This wraps an observer to catch exceptions. In case of a
    L{error.StreamError}, it is sent over the given XML stream. All other
    exceptions derived from L{Exception} yield a C{'internal-server-error'}
    stream error, which is sent over the stream, while the exception is
    logged. Exceptions outside that hierarchy (such as L{KeyboardInterrupt}
    and L{SystemExit}) are deliberately left to propagate.

    @param xs: The XML stream that stream errors are sent over.
    @param observer: Callable invoked with each observed element.
    @return: Wrapped observer
    """

    def wrappedObserver(element):
        try:
            observer(element)
        except error.StreamError as exc:
            xs.sendStreamError(exc)
        except Exception:
            log.err()
            xs.sendStreamError(error.StreamError("internal-server-error"))

    return wrappedObserver


class XMPPServerConnector(SRVConnector):
    """
    SRV-record based connector for the C{xmpp-server} service of a domain.

    When no SRV records are available, the connection falls back to the
    domain itself on port 5269.
    """

    def __init__(self, reactor, domain, factory):
        SRVConnector.__init__(self, reactor, "xmpp-server", domain, factory)

    def pickServer(self):
        """
        Pick the next server to try, substituting port 5269 on fallback.
        """
        host, port = SRVConnector.pickServer(self)

        # Both lists empty after picking means there were no SRV records.
        usingFallback = not (self.servers or self.orderedServers)
        if usingFallback:
            port = 5269

        return host, port


class DialbackFailed(Exception):
    """
    Server dialback did not succeed; used as the errback value of the
    dialback initializers.
    """


@implementer(ijabber.IInitiatingInitializer)
class OriginatingDialbackInitializer(object):
    """
    Server Dialback Initializer for the Originating Server.

    Sends a dialback key to the Receiving Server in a C{<db:result/>}
    element and waits for the Receiving Server's C{<db:result/>} reply.

    @ivar xmlstream: The outgoing stream to the Receiving Server.
    @ivar thisHost: The domain of the Originating Server (this server).
    @ivar otherHost: The domain of the Receiving Server.
    @ivar secret: The shared secret used to generate the dialback key.
    """

    _deferred = None

    def __init__(self, xs, thisHost, otherHost, secret):
        self.xmlstream = xs
        self.thisHost = thisHost
        self.otherHost = otherHost
        self.secret = secret

    def initialize(self):
        """
        Send the dialback key and wait for the outcome.

        @return: Deferred that fires with C{None} on a C{'valid'} result,
            errbacks with L{DialbackFailed} on any other result, or
            errbacks with the stream error failure if that arrives first.
        """
        self._deferred = defer.Deferred()
        self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onStreamError)
        self.xmlstream.addObserver("/result[@xmlns='%s']" % NS_DIALBACK, self.onResult)

        key = generateKey(self.secret, self.otherHost, self.thisHost, self.xmlstream.sid)

        result = domish.Element((NS_DIALBACK, "result"))
        result["from"] = self.thisHost
        result["to"] = self.otherHost
        result.addContent(key)

        self.xmlstream.send(result)

        return self._deferred

    def _stopObserving(self):
        """
        Remove both observers installed by L{initialize}, so that a later
        result or stream error cannot fire the deferred a second time.
        """
        self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT, self.onStreamError)
        self.xmlstream.removeObserver("/result[@xmlns='%s']" % NS_DIALBACK, self.onResult)

    def onResult(self, result):
        """
        Handle the Receiving Server's dialback result.
        """
        self._stopObserving()
        # getAttribute: a result without a type is a failed dialback, not a
        # KeyError raised inside the observer.
        if result.getAttribute("type") == "valid":
            self.xmlstream.otherEntity = jid.internJID(self.otherHost)
            self._deferred.callback(None)
        else:
            self._deferred.errback(DialbackFailed())

    def onStreamError(self, failure):
        """
        Fail dialback when a stream error is received first.
        """
        self._stopObserving()
        self._deferred.errback(failure)


@implementer(ijabber.IInitiatingInitializer)
class ReceivingDialbackInitializer(object):
    """
    Server Dialback Initializer for the Receiving Server.

    Sends a C{<db:verify/>} request carrying the key under verification to
    the Authoritative Server and waits for its C{<db:verify/>} reply.

    @ivar xmlstream: The outgoing stream to the Authoritative Server.
    @ivar thisHost: The domain of the Receiving Server (this server).
    @ivar otherHost: The domain of the Authoritative Server.
    @ivar originalStreamID: The stream ID of the incoming connection that
        is being verified.
    @ivar key: The dialback key to verify.
    """

    _deferred = None

    def __init__(self, xs, thisHost, otherHost, originalStreamID, key):
        self.xmlstream = xs
        self.thisHost = thisHost
        self.otherHost = otherHost
        self.originalStreamID = originalStreamID
        self.key = key

    def initialize(self):
        """
        Send the verification request and wait for the outcome.

        @return: Deferred that fires with C{None} on a C{'valid'} reply,
            errbacks with L{DialbackFailed} on a mismatching or non-valid
            reply, or errbacks with the stream error failure if that
            arrives first.
        """
        self._deferred = defer.Deferred()
        self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onStreamError)
        self.xmlstream.addObserver("/verify[@xmlns='%s']" % NS_DIALBACK, self.onVerify)

        verify = domish.Element((NS_DIALBACK, "verify"))
        verify["from"] = self.thisHost
        verify["to"] = self.otherHost
        verify["id"] = self.originalStreamID
        verify.addContent(self.key)

        self.xmlstream.send(verify)
        return self._deferred

    def _stopObserving(self):
        """
        Remove both observers installed by L{initialize}, so that a later
        reply or stream error cannot fire the deferred a second time.
        """
        self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT, self.onStreamError)
        self.xmlstream.removeObserver("/verify[@xmlns='%s']" % NS_DIALBACK, self.onVerify)

    def onVerify(self, verify):
        """
        Check the Authoritative Server's reply against our request.

        A missing attribute counts as a mismatch (getAttribute returns
        C{None}) instead of raising KeyError inside the observer.
        """
        self._stopObserving()
        if verify.getAttribute("id") != self.originalStreamID:
            self.xmlstream.sendStreamError(error.StreamError("invalid-id"))
            self._deferred.errback(DialbackFailed())
        elif verify.getAttribute("to") != self.thisHost:
            self.xmlstream.sendStreamError(error.StreamError("host-unknown"))
            self._deferred.errback(DialbackFailed())
        elif verify.getAttribute("from") != self.otherHost:
            self.xmlstream.sendStreamError(error.StreamError("invalid-from"))
            self._deferred.errback(DialbackFailed())
        elif verify.getAttribute("type") == "valid":
            self._deferred.callback(None)
        else:
            self._deferred.errback(DialbackFailed())

    def onStreamError(self, failure):
        """
        Fail verification when a stream error is received first.
        """
        self._stopObserving()
        self._deferred.errback(failure)


class XMPPServerConnectAuthenticator(xmlstream.ConnectAuthenticator):
    """
    Authenticator for an outgoing XMPP server-to-server connection.

    Connects to C{otherHost} (the Receiving Server) and performs dialback
    as C{thisHost} (the Originating Server) through a single
    L{OriginatingDialbackInitializer}.

    @ivar thisHost: The domain this server connects from (the Originating
                    Server).
    @ivar otherHost: The domain of the server this server connects to (the
                     Receiving Server).
    @ivar secret: The shared secret used to generate the dialback key for
                  this new connection.
    """

    namespace = "jabber:server"

    def __init__(self, thisHost, otherHost, secret):
        self.thisHost, self.otherHost, self.secret = thisHost, otherHost, secret
        xmlstream.ConnectAuthenticator.__init__(self, otherHost)

    def connectionMade(self):
        stream = self.xmlstream
        stream.thisEntity = jid.internJID(self.thisHost)
        stream.prefixes = {xmlstream.NS_STREAMS: "stream", NS_DIALBACK: "db"}
        xmlstream.ConnectAuthenticator.connectionMade(self)

    def associateWithStream(self, xs):
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
        xs.initializers = [
            OriginatingDialbackInitializer(
                xs, self.thisHost, self.otherHost, self.secret
            )
        ]


class XMPPServerVerifyAuthenticator(xmlstream.ConnectAuthenticator):
    """
    Authenticator for an outgoing connection to verify an incoming connection.

    Connects to C{otherHost} (the Authoritative Server) and asks it, as
    C{thisHost} (the Receiving Server), to verify a dialback key through a
    single L{ReceivingDialbackInitializer}.

    @ivar thisHost: The domain this server connects from (the Receiving
                    Server).
    @ivar otherHost: The domain of the server this server connects to (the
                     Authoritative Server).
    @ivar originalStreamID: The stream ID of the incoming connection that is
                            being verified.
    @ivar key: The dialback key that the Originating Server presented on the
               incoming connection, to be checked by the Authoritative
               Server.
    """

    namespace = "jabber:server"

    def __init__(self, thisHost, otherHost, originalStreamID, key):
        self.thisHost, self.otherHost = thisHost, otherHost
        self.originalStreamID, self.key = originalStreamID, key
        xmlstream.ConnectAuthenticator.__init__(self, otherHost)

    def connectionMade(self):
        stream = self.xmlstream
        stream.thisEntity = jid.internJID(self.thisHost)
        stream.prefixes = {xmlstream.NS_STREAMS: "stream", NS_DIALBACK: "db"}
        xmlstream.ConnectAuthenticator.connectionMade(self)

    def associateWithStream(self, xs):
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
        xs.initializers = [
            ReceivingDialbackInitializer(
                xs, self.thisHost, self.otherHost, self.originalStreamID, self.key
            )
        ]


class XMPPServerListenAuthenticator(xmlstream.ListenAuthenticator):
    """
    Authenticator for an incoming XMPP server-to-server connection.

    This authenticator handles two types of incoming connections. Regular
    server-to-server connections are from the Originating Server to the
    Receiving Server, where this server is the Receiving Server. These
    connections start out by receiving a dialback key, verifying the
    key with the Authoritative Server, and then accept normal XMPP stanzas.

    The other type of connections is from a Receiving Server to an
    Authoritative Server, where this server acts as the Authoritative Server.
    These connections are used to verify the validity of an outgoing connection
    from this server. In this case, this server receives a verification
    request, checks the key and then returns the result.

    @ivar service: The service that keeps the list of domains we accept
                   connections for.
    """

    namespace = "jabber:server"

    def __init__(self, service):
        xmlstream.ListenAuthenticator.__init__(self)
        self.service = service

    def streamStarted(self, rootElement):
        """
        Validate the incoming stream header and answer with our own.

        Sends an C{'invalid-namespace'} stream error when the stream or
        dialback namespace declarations are wrong, and C{'host-unknown'} when
        the addressed domain is not in C{service.domains}. Otherwise installs
        the dialback observers and sends the stream header, followed by an
        empty features element on version 1.0 streams.
        """
        xmlstream.ListenAuthenticator.streamStarted(self, rootElement)

        if self.xmlstream.thisEntity:
            targetDomain = self.xmlstream.thisEntity.host
        else:
            targetDomain = self.service.defaultDomain

        def prepareStream(domain):
            self.xmlstream.namespace = self.namespace
            self.xmlstream.prefixes = {xmlstream.NS_STREAMS: "stream", NS_DIALBACK: "db"}
            if domain:
                self.xmlstream.thisEntity = jid.internJID(domain)

        try:
            if (
                xmlstream.NS_STREAMS != rootElement.uri
                or self.namespace != self.xmlstream.namespace
                or ("db", NS_DIALBACK) not in rootElement.localPrefixes.items()
            ):
                raise error.StreamError("invalid-namespace")

            if targetDomain and targetDomain not in self.service.domains:
                raise error.StreamError("host-unknown")
        except error.StreamError as exc:
            prepareStream(self.service.defaultDomain)
            self.xmlstream.sendStreamError(exc)
            return

        # Both handlers raise StreamError for bad requests; trap those so
        # they are sent to the peer instead of escaping the observer.
        self.xmlstream.addObserver(
            "//verify[@xmlns='%s']" % NS_DIALBACK,
            trapStreamError(self.xmlstream, self.onVerify),
        )
        self.xmlstream.addObserver(
            "//result[@xmlns='%s']" % NS_DIALBACK,
            trapStreamError(self.xmlstream, self.onResult),
        )

        prepareStream(targetDomain)
        self.xmlstream.sendHeader()

        if self.xmlstream.version >= (1, 0):
            features = domish.Element((xmlstream.NS_STREAMS, "features"))
            self.xmlstream.send(features)

    def onVerify(self, verify):
        """
        Answer a verification request from a Receiving Server.

        @raise error.StreamError: C{'improper-addressing'} for missing or
            malformed addresses, C{'host-unknown'} when the addressed domain
            is not ours, C{'invalid-from'} when the sender does not match the
            stream's peer.
        """
        try:
            receivingServer = jid.JID(verify["from"]).host
            originatingServer = jid.JID(verify["to"]).host
        except (KeyError, jid.InvalidFormat):
            raise error.StreamError("improper-addressing")

        if originatingServer not in self.service.domains:
            raise error.StreamError("host-unknown")

        if (
            self.xmlstream.otherEntity
            and receivingServer != self.xmlstream.otherEntity.host
        ):
            raise error.StreamError("invalid-from")

        streamID = verify.getAttribute("id", "")
        key = str(verify)

        calculatedKey = generateKey(
            self.service.secret, receivingServer, originatingServer, streamID
        )
        # Constant-time comparison of a secret-derived value. Both sides are
        # compared as bytes because compare_digest rejects non-ASCII text and
        # the key comes from the peer.
        keyMatches = hmac.compare_digest(
            key.encode("utf-8"), calculatedKey.encode("ascii")
        )
        validity = "valid" if keyMatches else "invalid"

        reply = domish.Element((NS_DIALBACK, "verify"))
        reply["from"] = originatingServer
        reply["to"] = receivingServer
        reply["id"] = streamID
        reply["type"] = validity
        self.xmlstream.send(reply)

    def onResult(self, result):
        """
        Handle a dialback key sent by an Originating Server.

        The key is verified through C{service.validateConnection}. On success
        a C{'valid'} result is sent, the stream entities are set and
        C{STREAM_AUTHD_EVENT} is dispatched; on failure the error is logged
        and an C{'invalid'} result is sent.

        @raise error.StreamError: C{'improper-addressing'} for missing or
            malformed addresses, C{'host-unknown'} when the addressed domain
            is not one we serve.
        @return: The deferred from C{service.validateConnection}.
        """
        try:
            receivingServer = jid.JID(result["to"]).host
            originatingServer = jid.JID(result["from"]).host
        except (KeyError, jid.InvalidFormat):
            raise error.StreamError("improper-addressing")

        # Never dial back on behalf of a domain we do not serve.
        if receivingServer not in self.service.domains:
            raise error.StreamError("host-unknown")

        key = str(result)

        def reply(validity):
            element = domish.Element((NS_DIALBACK, "result"))
            element["from"] = receivingServer
            element["to"] = originatingServer
            element["type"] = validity
            self.xmlstream.send(element)

        def valid(_):
            reply("valid")
            if not self.xmlstream.thisEntity:
                self.xmlstream.thisEntity = jid.internJID(receivingServer)
            self.xmlstream.otherEntity = jid.internJID(originatingServer)
            self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)

        def invalid(failure):
            log.err(failure)
            reply("invalid")

        d = self.service.validateConnection(
            receivingServer, originatingServer, self.xmlstream.sid, key
        )
        d.addCallbacks(valid, invalid)
        return d


class DeferredS2SClientFactory(DeferredXmlStreamFactory):
    """
    Deferred firing factory for initiating XMPP server-to-server connection.

    The deferred has its callbacks called upon successful authentication with
    the other server. If authentication or the connection fails, its
    errbacks are called instead.

    @ivar logTraffic: When true, raw traffic of every stream is logged.
    @ivar serial: Number assigned to the next connected stream.
    """

    logTraffic = False

    def __init__(self, authenticator):
        DeferredXmlStreamFactory.__init__(self, authenticator)
        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, self.onConnectionMade)
        self.serial = 0

    def onConnectionMade(self, xs):
        """
        Number the new stream and, if enabled, hook up traffic logging.
        """
        xs.serial = self.serial
        self.serial += 1

        if not self.logTraffic:
            return

        def logDataIn(buf):
            log.msg("RECV (%d): %r" % (xs.serial, buf))

        def logDataOut(buf):
            log.msg("SEND (%d): %r" % (xs.serial, buf))

        xs.rawDataInFn = logDataIn
        xs.rawDataOutFn = logDataOut


def initiateS2S(factory):
    """
    Start an outgoing server-to-server connection using C{factory}.

    The IDNA-encoded C{factory.authenticator.otherHost} is connected to
    through L{XMPPServerConnector}.

    @return: C{factory.deferred}.
    """
    connector = XMPPServerConnector(
        reactor, factory.authenticator.otherHost.encode("idna"), factory
    )
    connector.connect()
    return factory.deferred


class XMPPS2SServerFactory(xmlstream.XmlStreamServerFactory):
    """
    XMPP Server-to-Server Server factory.

    Accepts incoming server-to-server connections. Each one gets its own
    L{XMPPServerListenAuthenticator}. Once a stream is authenticated, its
    elements that are not yet handled are passed to C{service.dispatch}.

    @ivar service: The server-to-server service.
    @ivar logTraffic: When true, raw traffic of every stream is logged.
    @ivar serial: Number assigned to the next incoming stream.
    """

    logTraffic = False

    def __init__(self, service):
        self.service = service
        xmlstream.XmlStreamServerFactory.__init__(
            self, lambda: XMPPServerListenAuthenticator(service)
        )
        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, self.onConnectionMade)
        self.addBootstrap(xmlstream.STREAM_AUTHD_EVENT, self.onAuthenticated)
        self.serial = 0

    def onConnectionMade(self, xs):
        """
        Called when a server-to-server connection was made.

        Numbers the stream, logs its stream errors and, if enabled, turns on
        raw traffic logging.
        """
        xs.serial = self.serial
        self.serial += 1
        xs.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onError)

        if not self.logTraffic:
            return

        def logDataIn(buf):
            log.msg("RECV (%d): %r" % (xs.serial, buf))

        def logDataOut(buf):
            log.msg("SEND (%d): %r" % (xs.serial, buf))

        xs.rawDataInFn = logDataIn
        xs.rawDataOutFn = logDataOut

    def onAuthenticated(self, xs):
        """
        Log the authenticated stream and start observing its elements.
        """
        thisHost, otherHost = xs.thisEntity.host, xs.otherEntity.host
        log.msg(
            "Incoming connection %d from %r to %r established"
            % (xs.serial, otherHost, thisHost)
        )
        xs.addObserver(xmlstream.STREAM_END_EVENT, self.onConnectionLost, 0, xs)
        xs.addObserver("/*", self.onElement, 0, xs)

    def onConnectionLost(self, xs, reason):
        """
        Log the end of an authenticated incoming stream.
        """
        thisHost, otherHost = xs.thisEntity.host, xs.otherEntity.host
        log.msg(
            "Incoming connection %d from %r to %r disconnected"
            % (xs.serial, otherHost, thisHost)
        )

    def onError(self, reason):
        """
        Log a stream error.
        """
        log.err(reason, "Stream Error")

    def onElement(self, xs, element):
        """
        Called when an element was received from one of the connected streams.

        Elements already marked as handled are ignored.
        """
        if not element.handled:
            self.service.dispatch(xs, element)


class ServerService(object):
    """
    Service for managing XMPP server to server connections.

    Stanzas delivered by the router are sent over outgoing
    server-to-server streams, which are set up on demand; stanzas received
    on authenticated incoming streams are handed back via L{dispatch}.

    @ivar router: Router this service attaches to via C{addRoute(None, ...)}.
    @ivar defaultDomain: Domain used for incoming streams that do not name
        one.
    @ivar domains: Set of domains this service accepts connections for.
    @ivar secret: Shared secret used to generate dialback keys (text).
    """

    logTraffic = False

    def __init__(self, router, domain=None, secret=None):
        self.router = router

        self.defaultDomain = domain
        self.domains = set()
        if self.defaultDomain:
            self.domains.add(self.defaultDomain)

        if secret is not None:
            self.secret = secret
        else:
            # hexlify returns bytes on Python 3; keep the secret as text so
            # that generateKey can encode it.
            self.secret = binascii.hexlify(randbytes.secureRandom(16)).decode(
                "ascii"
            )

        # Authenticated outgoing streams, keyed by (thisHost, otherHost).
        self._outgoingStreams = {}
        # Stanzas waiting for an outgoing stream, same keys as above.
        self._outgoingQueues = {}
        # (thisHost, otherHost) pairs with a connection attempt in progress.
        self._outgoingConnecting = set()
        self.serial = 0

        pipe = XmlPipe()
        self.xmlstream = pipe.source
        self.router.addRoute(None, pipe.sink)
        # Every element the router delivers is forwarded by send().
        self.xmlstream.addObserver("/*", self.send)

    def outgoingInitialized(self, xs):
        """
        Register a newly authenticated outgoing stream and flush its queue.
        """
        thisHost = xs.thisEntity.host
        otherHost = xs.otherEntity.host
        key = (thisHost, otherHost)

        log.msg(
            "Outgoing connection %d from %r to %r established"
            % (xs.serial, thisHost, otherHost)
        )

        self._outgoingStreams[key] = xs
        xs.addObserver(
            xmlstream.STREAM_END_EVENT, lambda _: self.outgoingDisconnected(xs)
        )

        for element in self._outgoingQueues.pop(key, []):
            xs.send(element)

    def outgoingDisconnected(self, xs):
        """
        Forget an outgoing stream that has ended.
        """
        thisHost = xs.thisEntity.host
        otherHost = xs.otherEntity.host
        key = (thisHost, otherHost)

        log.msg(
            "Outgoing connection %d from %r to %r disconnected"
            % (xs.serial, thisHost, otherHost)
        )

        # Only drop the entry if it still refers to this stream, so that a
        # newer stream for the same pair is not forgotten.
        if self._outgoingStreams.get(key) is xs:
            del self._outgoingStreams[key]

    def initiateOutgoingStream(self, thisHost, otherHost):
        """
        Initiate an outgoing XMPP server-to-server connection.

        Does nothing if a connection attempt for this pair is already in
        progress. If the attempt fails, the failure is logged and the
        stanzas queued for the pair are dropped, because nothing else would
        ever flush them.

        @return: Deferred that fires with C{None} once the attempt has
            finished, or C{None} when an attempt was already in progress.
        """
        key = (thisHost, otherHost)

        if key in self._outgoingConnecting:
            return

        def connectionFailed(failure):
            dropped = self._outgoingQueues.pop(key, [])
            log.err(
                failure,
                "Outgoing connection from %r to %r failed, dropping %d "
                "queued stanza(s)" % (thisHost, otherHost, len(dropped)),
            )

        def resetConnecting(_):
            self._outgoingConnecting.discard(key)

        authenticator = XMPPServerConnectAuthenticator(thisHost, otherHost, self.secret)
        factory = DeferredS2SClientFactory(authenticator)
        factory.addBootstrap(xmlstream.STREAM_AUTHD_EVENT, self.outgoingInitialized)
        factory.logTraffic = self.logTraffic

        self._outgoingConnecting.add(key)

        d = initiateS2S(factory)
        d.addErrback(connectionFailed)
        d.addBoth(resetConnecting)
        return d

    def validateConnection(self, thisHost, otherHost, sid, key):
        """
        Validate an incoming XMPP server-to-server connection.

        Connects to the Authoritative Server C{otherHost} to verify C{key}
        for stream C{sid}. The verification connection is closed once it
        has been authenticated or its initialization has failed.

        @return: The deferred of the verification connection's factory.
        """

        def connected(xs):
            # Set up stream for immediate disconnection.
            def disconnect(_):
                xs.transport.loseConnection()

            xs.addObserver(xmlstream.STREAM_AUTHD_EVENT, disconnect)
            xs.addObserver(xmlstream.INIT_FAILED_EVENT, disconnect)

        authenticator = XMPPServerVerifyAuthenticator(thisHost, otherHost, sid, key)
        factory = DeferredS2SClientFactory(authenticator)
        factory.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, connected)
        factory.logTraffic = self.logTraffic

        return initiateS2S(factory)

    def send(self, stanza):
        """
        Send stanza to the proper XML Stream.

        This uses addressing embedded in the stanza to find the correct stream
        to forward the stanza to. Without an established stream, the stanza
        is queued and a connection is initiated; the queue is flushed on
        success and dropped on failure.
        """
        otherHost = jid.internJID(stanza["to"]).host
        thisHost = jid.internJID(stanza["from"]).host
        key = (thisHost, otherHost)

        stream = self._outgoingStreams.get(key)
        if stream is not None:
            stream.send(stanza)
        else:
            self._outgoingQueues.setdefault(key, []).append(stanza)
            self.initiateOutgoingStream(thisHost, otherHost)

    def dispatch(self, xs, stanza):
        """
        Send an element received on C{xs} to be routed within the server.

        A stanza missing C{from} or C{to} causes an C{'improper-addressing'}
        stream error. A stanza with a malformed JID is logged and dropped. A
        stanza whose sender domain differs from the stream's peer causes an
        C{'invalid-from'} stream error.
        """
        stanzaFrom = stanza.getAttribute("from")
        stanzaTo = stanza.getAttribute("to")

        if not stanzaFrom or not stanzaTo:
            xs.sendStreamError(error.StreamError("improper-addressing"))
            return

        try:
            sender = jid.internJID(stanzaFrom)
            jid.internJID(stanzaTo)
        except jid.InvalidFormat:
            # Without a parseable sender there is nothing to check or route.
            log.msg("Dropping stanza with malformed JID")
            return

        if sender.host != xs.otherEntity.host:
            xs.sendStreamError(error.StreamError("invalid-from"))
        else:
            self.xmlstream.send(stanza)
