
|
# -*- coding: utf-8 -*-
import base64
from hashlib import sha1
from email.parser import BytesHeaderParser
import io
import asyncio
# Pick the coroutine-scheduling function available on this interpreter:
# ``asyncio.ensure_future`` on newer versions, the legacy ``asyncio.async``
# alias otherwise (fetched via getattr because ``async`` is a keyword).
if hasattr(asyncio, 'ensure_future'):  # pragma: no cover
    asyncio_ensure_future = asyncio.ensure_future  # Python >= 3.5
else:  # pragma: no cover
    asyncio_ensure_future = getattr(asyncio, 'async')  # Python < 3.5
from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket
LF = b'\n'
CRLF = b'\r\n'
SPACE = b' '
EMPTY = b''
__all__ = ['WebSocketProtocol']
class WebSocketProtocol(asyncio.StreamReaderProtocol):
    """
    asyncio protocol that performs a minimal RFC 6455 opening handshake
    over raw HTTP and then delegates the connection to a websocket
    handler instance created from ``handler_cls``.
    """

    def __init__(self, handler_cls):
        asyncio.StreamReaderProtocol.__init__(self, asyncio.StreamReader(),
                                              self._pseudo_connected)
        # The handler is constructed with this protocol instance.
        self.ws = handler_cls(self)
        # Becomes True once the 101 response has been written. From then on
        # the stream carries websocket frames, so no HTTP error may be sent.
        self._upgraded = False

    def _pseudo_connected(self, reader, writer):
        # A client-connected callback is handed to the base class only so
        # that it builds the StreamWriter exposed by ``writer``; the
        # handshake itself is started from ``connection_made``.
        pass

    def connection_made(self, transport):
        """
        A peer is now connected and we receive an instance of the
        underlying :class:`asyncio.Transport`.

        The base class associates the transport with the
        :class:`asyncio.StreamReader`. The HTTP handshake is then scheduled
        as its own task so other connections keep being served
        concurrently. Its outcome is handled by :meth:`terminated`.
        """
        asyncio.StreamReaderProtocol.connection_made(self, transport)
        # ``asyncio_ensure_future`` is this module's compatibility alias
        # (ensure_future / legacy asyncio.async). It is not an attribute of
        # the asyncio module itself.
        f = asyncio_ensure_future(self.handle_initial_handshake())
        f.add_done_callback(self.terminated)

    @property
    def writer(self):
        """The :class:`asyncio.StreamWriter` built by the base class."""
        return self._stream_writer

    @property
    def reader(self):
        """The :class:`asyncio.StreamReader` fed by the transport."""
        return self._stream_reader

    def terminated(self, f):
        """
        Done-callback of the handshake/websocket task.

        If the task failed before the connection was upgraded, a bare
        ``400 Bad Request`` response is written. On any failure the
        websocket handler is then asked to close the connection.
        Cancelled or unfinished futures are ignored.

        :param f: the task created in :meth:`connection_made`.
        """
        if not f.done() or f.cancelled():
            return
        if f.exception() is None:
            return
        writer = self.writer
        # After the 101 response, an HTTP reply would corrupt the websocket
        # stream. Some Python versions also clear the writer once the
        # connection has been lost.
        if not getattr(self, '_upgraded', False) and writer is not None:
            response = [b'HTTP/1.0 400 Bad Request',
                        b'Content-Length: 0',
                        b'Connection: close',
                        b'',
                        b'']
            writer.write(CRLF.join(response))
        self.ws.close_connection()

    def close(self):
        """
        Initiate the websocket closing handshake, which will eventually
        lead to the underlying transport being closed.
        """
        self.ws.close()

    def timeout(self):
        """
        Close the handler's connection and, if the handler had started,
        report the closure with code 1002.
        """
        self.ws.close_connection()
        if self.ws.started:
            self.ws.closed(1002, "Peer connection timed-out")

    def connection_lost(self, exc):
        """
        The peer connection is now lost, so the closing handshake can no
        longer be performed. When ``exc`` is set, the websocket handler
        closes its side and, if it had started, is notified through its
        ``closed`` method.

        The base class is always notified afterwards so the stream reader
        sees EOF (or ``exc``). Otherwise a handshake waiting in
        ``readline()`` would never resume.
        """
        if exc is not None:
            self.ws.close_connection()
            if self.ws.started:
                self.ws.closed(1002, "Peer connection was lost")
        asyncio.StreamReaderProtocol.connection_lost(self, exc)

    @staticmethod
    def _accepted_tokens(offered, supported):
        """
        Return the comma-separated tokens of the ``offered`` header value
        that also appear in ``supported``, in the client's order.
        """
        if not offered:
            return []
        tokens = (token.strip() for token in offered.split(','))
        return [token for token in tokens if token in supported]

    @asyncio.coroutine
    def handle_initial_handshake(self):
        """
        Performs the HTTP handshake described in :rfc:`6455`. Note that
        this implementation is really basic and it is strongly advised
        against using it in production. It would probably break for
        most clients. If you want a better support for HTTP, please
        use a more reliable HTTP server implemented using asyncio.

        On success the 101 response is written and the websocket handler
        is run until the exchange ends.

        :raises HandshakeError: on a non-GET request, missing or invalid
            Upgrade/Connection headers, or an unsupported version or
            invalid ``Sec-WebSocket-Key``.
        :raises ValueError: on a malformed request line or header line, or
            an HTTP/1.1 request without a Host header.
        """
        request_line = yield from self.next_line()
        method, _uri, req_protocol = request_line.strip().split(SPACE, 2)

        # GET required
        if method.upper() != b'GET':
            raise HandshakeError('HTTP method must be a GET')

        headers = yield from self.read_headers()
        if req_protocol == b'HTTP/1.1' and 'Host' not in headers:
            raise ValueError("Missing host header")

        for key, expected_value in [('Upgrade', 'websocket'),
                                    ('Connection', 'upgrade')]:
            actual_value = headers.get(key, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % str(key))
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (key, actual_value))

        # Protocol version: must parse as an int listed in WS_VERSION.
        version = headers.get('Sec-WebSocket-Version')
        try:
            version = int(version)
        except (TypeError, ValueError):
            raise HandshakeError('Unhandled or missing WebSocket version')
        if version not in WS_VERSION:
            raise HandshakeError('Unhandled or missing WebSocket version')

        # Client nonce: required, base64 encoding of exactly 16 bytes.
        key = headers.get('Sec-WebSocket-Key')
        if not key:
            raise HandshakeError('Missing Sec-WebSocket-Key header')
        try:
            ws_key = base64.b64decode(key.encode('utf-8'))
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise HandshakeError("WebSocket key is not valid base64") from exc
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        # This server supports no subprotocols or extensions, so nothing
        # the client offers is accepted.
        ws_protocols = self._accepted_tokens(
            headers.get('Sec-WebSocket-Protocol'), [])
        ws_extensions = self._accepted_tokens(
            headers.get('Sec-WebSocket-Extensions'), [])

        self.ws.protocols = ws_protocols
        self.ws.extensions = ws_extensions
        self.ws.headers = headers

        accept = base64.b64encode(sha1(key.encode('utf-8') + WS_KEY).digest())
        # A 101 response has no body. RFC 7230 forbids Content-Length on 1xx.
        response = [req_protocol + b' 101 Switching Protocols',
                    b'Upgrade: websocket',
                    b'Connection: Upgrade',
                    b'Sec-WebSocket-Version: ' + str(version).encode('utf-8'),
                    b'Sec-WebSocket-Accept: ' + accept]
        if ws_protocols:
            # Negotiated tokens are str; encode before joining with bytes.
            response.append(b'Sec-WebSocket-Protocol: ' +
                            ', '.join(ws_protocols).encode('utf-8'))
        if ws_extensions:
            response.append(b'Sec-WebSocket-Extensions: ' +
                            ','.join(ws_extensions).encode('utf-8'))
        response.append(b'')
        response.append(b'')
        self.writer.write(CRLF.join(response))
        self._upgraded = True

        yield from self.handle_websocket()

    @asyncio.coroutine
    def handle_websocket(self):
        """
        Starts the websocket process until the
        exchange is completed and terminated.
        """
        yield from self.ws.run()

    @asyncio.coroutine
    def read_headers(self):
        """
        Read HTTP header lines up to and including the blank CRLF line.

        :returns: an :class:`email.message.Message` holding the headers
            (dict-like access with case-insensitive names).
        :raises ValueError: if a line does not end with CRLF.
        """
        lines = []
        while True:
            line = yield from self.next_line()
            lines.append(line)
            if line == CRLF:
                break
        return BytesHeaderParser().parsebytes(EMPTY.join(lines))

    @asyncio.coroutine
    def next_line(self):
        """
        Read one line from the stream and return it, including its
        trailing CRLF.

        :raises ValueError: if the line does not end with CRLF, which also
            covers EOF (empty or partial line).
        """
        line = yield from self.reader.readline()
        if not line.endswith(CRLF):
            raise ValueError("Missing mandatory trailing CRLF")
        return line
if __name__ == '__main__':
    # Demo entry point: serve ``EchoWebSocket`` handlers on port 9007,
    # on all interfaces (empty host string).
    from ws4py.async_websocket import EchoWebSocket

    loop = asyncio.get_event_loop()

    def start_server():
        # Returns the (not yet awaited) create_server coroutine.
        return loop.create_server(
            lambda: WebSocketProtocol(EchoWebSocket), '', 9007)

    server = loop.run_until_complete(start_server())
    print('serving on', server.sockets[0].getsockname())
    loop.run_forever()
|