# -*- coding: utf-8 -*-
import base64
from hashlib import sha1
from email.parser import BytesHeaderParser
import io

import asyncio

try:                                                     # pragma: no cover
    asyncio_ensure_future = asyncio.ensure_future        # Python >= 3.5
except AttributeError:                                   # pragma: no cover
    asyncio_ensure_future = getattr(asyncio, 'async')    # Python < 3.5

from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket

# Byte-string building blocks used to parse and emit the HTTP handshake.
LF, CRLF = b'\n', b'\r\n'
SPACE, EMPTY = b' ', b''

# Only the protocol class is exported by ``from ... import *``.
__all__ = ['WebSocketProtocol']

class WebSocketProtocol(asyncio.StreamReaderProtocol):
    """
    :class:`asyncio.StreamReaderProtocol` that performs a minimal
    server-side :rfc:`6455` opening handshake on each accepted
    connection and then hands the stream over to a websocket handler.

    NOTE(review): the coroutines below use the generator-based
    ``asyncio.coroutine`` decorator, which was removed in Python 3.11.
    """
    def __init__(self, handler_cls):
        """
        Create the protocol with its own :class:`asyncio.StreamReader`.

        :param handler_cls: websocket handler class. It is instantiated
          right away with this protocol as its only argument and the
          instance is stored as ``self.ws``.
        """
        asyncio.StreamReaderProtocol.__init__(self, asyncio.StreamReader(),
                                              self._pseudo_connected)
        # Becomes True once the "101 Switching Protocols" response has been
        # written. From then on the connection speaks the websocket
        # protocol and an HTTP error response must never be sent on it.
        self._handshake_completed = False
        self.ws = handler_cls(self)

    def _pseudo_connected(self, reader, writer):
        """
        ``client_connected_cb`` handed to the base class.

        Supplying a callback is what makes
        :class:`asyncio.StreamReaderProtocol` create the
        :class:`asyncio.StreamWriter` exposed as :attr:`writer`. The
        handshake itself is started from :meth:`connection_made`, so
        there is nothing to do here.
        """
        pass

    def connection_made(self, transport):
        """
        A peer is now connected and we receive an instance
        of the underlying :class:`asyncio.Transport`.

        The base class associates the transport with our
        :class:`asyncio.StreamReader` (and builds the writer) before
        the initial HTTP handshake is scheduled as a separate task.
        That task's completion is reported to :meth:`terminated`.
        """
        asyncio.StreamReaderProtocol.connection_made(self, transport)
        # Run the handshake (and then the websocket exchange) as its own
        # task so the event loop stays free to serve other connections.
        f = asyncio_ensure_future(self.handle_initial_handshake())
        f.add_done_callback(self.terminated)

    @property
    def writer(self):
        """The :class:`asyncio.StreamWriter` set up by the base class."""
        return self._stream_writer

    @property
    def reader(self):
        """The :class:`asyncio.StreamReader` fed by the base class."""
        return self._stream_reader

    def terminated(self, f):
        """
        Done-callback of the handshake/websocket task.

        When the task failed before the handshake completed, a
        ``400 Bad Request`` response is written (if a writer is still
        available). When it failed after the ``101`` response went out,
        the connection has already been upgraded, so no HTTP response is
        written. In both failure cases the handler's connection is closed.

        :param f: the future wrapping :meth:`handle_initial_handshake`.
        """
        if not f.done() or f.cancelled():
            return
        # Retrieving the exception also keeps asyncio from logging it as
        # "never retrieved".
        ex = f.exception()
        if not ex:
            return
        writer = self.writer
        # The base class may already have dropped the writer if the
        # connection was lost in the meantime.
        if not self._handshake_completed and writer is not None:
            response = [b'HTTP/1.0 400 Bad Request',
                        b'Content-Length: 0',
                        b'Connection: close',
                        b'',
                        b'']
            writer.write(CRLF.join(response))
        self.ws.close_connection()

    def close(self):
        """
        Initiate the websocket closing handshake, which is expected
        to eventually lead to the underlying transport being closed.
        """
        self.ws.close()

    def timeout(self):
        """
        Drop the connection because the peer timed out. The handler is
        notified through ``closed`` if its exchange had started.
        """
        self.ws.close_connection()
        if self.ws.started:
            self.ws.closed(1002, "Peer connection timed-out")

    def connection_lost(self, exc):
        """
        The peer connection is now lost, so the closing handshake
        won't work and we don't even try. On an abnormal loss
        (``exc`` is not ``None``) the websocket handler is made aware
        of it by calling its `closed` method.

        The base class implementation is always invoked. It feeds EOF
        (or ``exc``) to the :class:`asyncio.StreamReader`, so pending
        reads in the handshake or websocket loop return instead of
        waiting forever.
        """
        try:
            if exc is not None:
                self.ws.close_connection()
                if self.ws.started:
                    self.ws.closed(1002, "Peer connection was lost")
        finally:
            asyncio.StreamReaderProtocol.connection_lost(self, exc)

    @asyncio.coroutine
    def handle_initial_handshake(self):
        """
        Performs the HTTP handshake described in :rfc:`6455`. Note that
        this implementation is really basic and it is strongly advised
        against using it in production. It would probably break for
        most clients. If you want a better support for HTTP, please
        use a more reliable HTTP server implemented using asyncio.

        On success a ``101 Switching Protocols`` response is written and
        :meth:`handle_websocket` runs until the exchange terminates.

        :raises HandshakeError: non-GET method, or a missing or invalid
          ``Upgrade``, ``Connection``, ``Sec-WebSocket-Version`` or
          ``Sec-WebSocket-Key`` header.
        :raises ValueError: malformed request line or key encoding,
          a line without trailing CRLF, or no ``Host`` header on an
          HTTP/1.1 request.
        """
        request_line = yield from self.next_line()
        method, uri, req_protocol = request_line.strip().split(SPACE, 2)

        # GET required
        if method.upper() != b'GET':
            raise HandshakeError('HTTP method must be a GET')

        headers = yield from self.read_headers()
        if req_protocol == b'HTTP/1.1' and 'Host' not in headers:
            raise ValueError("Missing host header")

        # Substring checks: e.g. "keep-alive, Upgrade" is accepted.
        for key, expected_value in [('Upgrade', 'websocket'),
                                     ('Connection', 'upgrade')]:
            actual_value = headers.get(key, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % str(key))
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (key, actual_value))

        ws_version = WS_VERSION
        version = headers.get('Sec-WebSocket-Version')
        version_is_valid = False
        if version:
            try:
                version = int(version)
            except (TypeError, ValueError):
                # Non-numeric value (TypeError covers the Header object the
                # email parser returns for non-ASCII values).
                pass
            else:
                version_is_valid = version in ws_version

        if not version_is_valid:
            raise HandshakeError('Unhandled or missing WebSocket version')

        key = headers.get('Sec-WebSocket-Key')
        if not key:
            # The Sec-WebSocket-Accept value is derived from this key, so
            # the handshake cannot be answered without it.
            raise HandshakeError('Missing Sec-WebSocket-Key header')
        ws_key = base64.b64decode(key.encode('utf-8'))
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        # No subprotocols/extensions are supported yet: `protocols` and
        # `exts` are empty, so nothing is ever negotiated below.
        protocols = []
        ws_protocols = []
        subprotocols = headers.get('Sec-WebSocket-Protocol')
        if subprotocols:
            for s in subprotocols.split(','):
                s = s.strip()
                if s in protocols:
                    ws_protocols.append(s)

        exts = []
        ws_extensions = []
        extensions = headers.get('Sec-WebSocket-Extensions')
        if extensions:
            for ext in extensions.split(','):
                ext = ext.strip()
                if ext in exts:
                    ws_extensions.append(ext)

        self.ws.protocols = ws_protocols
        self.ws.extensions = ws_extensions
        self.ws.headers = headers

        accept = base64.b64encode(sha1(key.encode('utf-8') + WS_KEY).digest())
        response = [req_protocol + b' 101 Switching Protocols']
        response.append(b'Upgrade: websocket')
        response.append(b'Content-Type: text/plain')
        response.append(b'Content-Length: 0')
        response.append(b'Connection: Upgrade')
        response.append(b'Sec-WebSocket-Version:' + bytes(str(version), 'utf-8'))
        response.append(b'Sec-WebSocket-Accept:' + accept)
        # Negotiated values come from parsed headers and are str, so they
        # are encoded before being appended to the bytes response.
        if ws_protocols:
            response.append(b'Sec-WebSocket-Protocol:' +
                            ', '.join(ws_protocols).encode('utf-8'))
        if ws_extensions:
            response.append(b'Sec-WebSocket-Extensions:' +
                            ','.join(ws_extensions).encode('utf-8'))
        response.append(b'')
        response.append(b'')
        self.writer.write(CRLF.join(response))
        self._handshake_completed = True
        yield from self.handle_websocket()

    @asyncio.coroutine
    def handle_websocket(self):
        """
        Starts the websocket process until the
        exchange is completed and terminated.
        """
        yield from self.ws.run()

    @asyncio.coroutine
    def read_headers(self):
        """
        Read HTTP header lines up to and including the empty line that
        ends them, then parse them.

        :returns: an :class:`email.message.Message` whose ``get`` and
          ``in`` lookups are case-insensitive on header names.
        :raises ValueError: propagated from :meth:`next_line`.
        """
        lines = []
        while True:
            line = yield from self.next_line()
            lines.append(line)
            if line == CRLF:
                break
        return BytesHeaderParser().parsebytes(b''.join(lines))

    @asyncio.coroutine
    def next_line(self):
        """
        Read one line from the peer and return it, CRLF included.

        :raises ValueError: the line does not end with CRLF. This
          includes EOF, where ``readline`` returns ``b''``.
        """
        line = yield from self.reader.readline()
        if not line.endswith(CRLF):
            raise ValueError("Missing mandatory trailing CRLF")
        return line
        
if __name__ == '__main__':
    from ws4py.async_websocket import EchoWebSocket

    # Create and install an explicit loop. WebSocketProtocol builds its
    # StreamReader without a loop argument, so it must find this one as
    # the current loop. This avoids relying on asyncio.get_event_loop()
    # implicitly creating one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    def start_server():
        proto_factory = lambda: WebSocketProtocol(EchoWebSocket)
        return loop.create_server(proto_factory, '', 9007)

    s = loop.run_until_complete(start_server())
    print('serving on', s.sockets[0].getsockname())
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop this demo server.
        pass
    finally:
        # Stop listening and release the loop's resources on exit.
        s.close()
        loop.close()
