1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
|
# -*- coding: utf-8 -*-
import base64
from hashlib import sha1
from email.parser import BytesHeaderParser
import io
import asyncio
# Prefer asyncio.ensure_future (Python >= 3.5); older releases only offered
# the same function as asyncio.async, which must be fetched via getattr
# because `async` cannot be written as an attribute name.
asyncio_ensure_future = getattr(asyncio, 'ensure_future', None)
if asyncio_ensure_future is None:  # pragma: no cover
    asyncio_ensure_future = getattr(asyncio, 'async')  # Python < 3.5
from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket
# Raw byte tokens used when reading and writing the HTTP handshake.
LF, CRLF, SPACE, EMPTY = b'\n', b'\r\n', b' ', b''
# The protocol class is the only public name exported by this module.
__all__ = ["WebSocketProtocol"]
class WebSocketProtocol(asyncio.StreamReaderProtocol):
    """
    asyncio protocol that performs a minimal server-side :rfc:`6455`
    opening handshake on a freshly accepted connection, then hands the
    stream over to a websocket handler.

    The handler is created as ``handler_cls(self)`` and is available as
    :attr:`ws`.

    NOTE(review): ``@asyncio.coroutine`` was removed in Python 3.11, so
    this generator-based implementation needs an older interpreter.
    """

    def __init__(self, handler_cls):
        """
        :param handler_cls: websocket handler class, instantiated with this
            protocol as its single argument.
        """
        asyncio.StreamReaderProtocol.__init__(self, asyncio.StreamReader(),
                                              self._pseudo_connected)
        # Extra raw header lines (bytes) to send along with a 400 response,
        # e.g. the supported versions when version negotiation fails.
        self._error_headers = []
        # Becomes True once the 101 response has been written. From then on
        # the stream carries websocket frames, so no HTTP error response
        # may be written to it.
        self._upgraded = False
        self.ws = handler_cls(self)

    def _pseudo_connected(self, reader, writer):
        """
        Intentionally empty client-connected callback given to the base
        class. The handshake is started from :meth:`connection_made`
        instead.
        """

    def connection_made(self, transport):
        """
        A peer is now connected, and we receive an instance of the
        underlying :class:`asyncio.Transport`.

        The base class associates the transport with the
        :class:`asyncio.StreamReader` created in the constructor. The
        initial HTTP handshake is then scheduled as its own task, so other
        connections can be served concurrently. :meth:`terminated` runs
        when that task finishes.
        """
        asyncio.StreamReaderProtocol.connection_made(self, transport)
        # asyncio_ensure_future is the module-level compatibility shim,
        # not an attribute of the asyncio module.
        f = asyncio_ensure_future(self.handle_initial_handshake())
        f.add_done_callback(self.terminated)

    @property
    def writer(self):
        """The stream writer managed by the base class."""
        return self._stream_writer

    @property
    def reader(self):
        """The stream reader the handshake and handler read from."""
        return self._stream_reader

    def terminated(self, f):
        """
        Done-callback for the handshake task created in
        :meth:`connection_made`.

        When the task failed, the connection is closed. If the failure
        happened before the ``101 Switching Protocols`` response was sent,
        a ``400 Bad Request`` is written first, including any header lines
        collected in ``_error_headers``. After the upgrade, the stream
        carries websocket frames, so no HTTP response is injected.

        :param f: the finished future/task.
        """
        if not f.done() or f.cancelled():
            return
        if f.exception() is None:
            return
        writer = self.writer
        if writer is None:
            # The writer may already be gone if the connection was lost;
            # there is nothing left to answer or close.
            return
        if not self._upgraded:
            response = [b'HTTP/1.0 400 Bad Request']
            response.extend(self._error_headers)
            response.extend([b'Content-Length: 0',
                             b'Connection: close',
                             b'',
                             b''])
            writer.write(CRLF.join(response))
        self.ws.close_connection()

    def close(self):
        """
        Initiate the websocket closing handshake, which will eventually
        lead to the underlying transport being closed.
        """
        self.ws.close()

    def timeout(self):
        """
        Close the connection after a timeout. If the websocket was
        started, notify the handler with status code 1002.
        """
        self.ws.close_connection()
        if self.ws.started:
            self.ws.closed(1002, "Peer connection timed-out")

    def connection_lost(self, exc):
        """
        The peer connection is now lost, so the closing handshake won't
        work and is not even attempted. If the loss was caused by an error
        (``exc`` is not ``None``), the websocket handler is made aware of
        it by calling its ``closed`` method.

        Afterwards, the base :class:`asyncio.StreamReaderProtocol`
        implementation is always invoked. This way, pending reads on
        :attr:`reader` receive EOF or ``exc`` instead of waiting forever.

        :param exc: the exception that caused the loss, or ``None`` on a
            regular EOF/close.
        """
        if exc is not None:
            self.ws.close_connection()
            if self.ws.started:
                self.ws.closed(1002, "Peer connection was lost")
        # Notify the handler first: the base class may drop the writer.
        asyncio.StreamReaderProtocol.connection_lost(self, exc)

    @staticmethod
    def _negotiate(offered, supported):
        """
        Return the tokens of the comma-separated header value *offered*,
        stripped of surrounding whitespace, that also appear in *supported*,
        in the client's order. ``None`` or ``''`` yields ``[]``.
        """
        if not offered:
            return []
        tokens = (token.strip() for token in offered.split(','))
        return [token for token in tokens if token in supported]

    @asyncio.coroutine
    def handle_initial_handshake(self):
        """
        Performs the HTTP handshake described in :rfc:`6455`. Note that
        this implementation is really basic and it is strongly advised
        against using it in production. It would probably break for
        most clients. If you want better HTTP support, please use a
        more reliable HTTP server implemented using asyncio.

        On success, a ``101 Switching Protocols`` response is written and
        :meth:`handle_websocket` runs until the exchange ends.

        :raises HandshakeError: the method is not GET; the ``Upgrade`` or
            ``Connection`` headers are missing or wrong;
            ``Sec-WebSocket-Version`` is missing or unsupported; or
            ``Sec-WebSocket-Key`` is missing or of the wrong length.
        :raises ValueError: the request line is malformed, ``Host`` is
            missing on HTTP/1.1, or a line does not end with CRLF.
        """
        request_line = yield from self.next_line()
        method, _uri, req_protocol = request_line.strip().split(SPACE, 2)
        # GET required
        if method.upper() != b'GET':
            raise HandshakeError('HTTP method must be a GET')

        headers = yield from self.read_headers()
        if req_protocol == b'HTTP/1.1' and 'Host' not in headers:
            raise ValueError("Missing host header")

        for key, expected_value in [('Upgrade', 'websocket'),
                                    ('Connection', 'upgrade')]:
            actual_value = headers.get(key, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % str(key))
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (key, actual_value))

        version = headers.get('Sec-WebSocket-Version')
        try:
            version = int(version)
        except (TypeError, ValueError):
            # The header is missing (None), empty, or not a number.
            version_is_valid = False
        else:
            version_is_valid = version in WS_VERSION
        if not version_is_valid:
            # RFC 6455 section 4.4: tell the client which versions we speak.
            supported_versions = ', '.join(str(v) for v in WS_VERSION)
            self._error_headers.append(b'Sec-WebSocket-Version: ' +
                                       supported_versions.encode('ascii'))
            raise HandshakeError('Unhandled or missing WebSocket version')

        key = headers.get('Sec-WebSocket-Key')
        if not key:
            # Without a key, no Sec-WebSocket-Accept value can be computed.
            raise HandshakeError('Header Sec-WebSocket-Key is not defined')
        ws_key = base64.b64decode(key.encode('utf-8'))
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        # This server currently supports no subprotocol and no extension,
        # so both negotiations yield empty lists.
        ws_protocols = self._negotiate(
            headers.get('Sec-WebSocket-Protocol'), [])
        ws_extensions = self._negotiate(
            headers.get('Sec-WebSocket-Extensions'), [])

        self.ws.protocols = ws_protocols
        self.ws.extensions = ws_extensions
        self.ws.headers = headers

        accept = base64.b64encode(sha1(key.encode('utf-8') + WS_KEY).digest())
        # No Content-Length here: RFC 7230 section 3.3.2 forbids it on
        # 1xx responses.
        response = [req_protocol + b' 101 Switching Protocols',
                    b'Upgrade: websocket',
                    b'Content-Type: text/plain',
                    b'Connection: Upgrade',
                    b'Sec-WebSocket-Version:' + str(version).encode('utf-8'),
                    b'Sec-WebSocket-Accept:' + accept]
        # Header values are str; encode them before joining with bytes.
        if ws_protocols:
            response.append(b'Sec-WebSocket-Protocol:' +
                            ', '.join(ws_protocols).encode('utf-8'))
        if ws_extensions:
            response.append(b'Sec-WebSocket-Extensions:' +
                            ','.join(ws_extensions).encode('utf-8'))
        response.extend([b'', b''])
        self.writer.write(CRLF.join(response))
        self._upgraded = True

        yield from self.handle_websocket()

    @asyncio.coroutine
    def handle_websocket(self):
        """
        Run the websocket handler until the exchange is completed and
        terminated.
        """
        yield from self.ws.run()

    @asyncio.coroutine
    def read_headers(self):
        """
        Read all HTTP header lines up to and including the blank CRLF line.

        :returns: the :class:`email.message.Message` produced by
            :class:`email.parser.BytesHeaderParser` from those lines.
        :raises ValueError: a line does not end with CRLF (see
            :meth:`next_line`).
        """
        lines = []
        while True:
            line = yield from self.next_line()
            lines.append(line)
            if line == CRLF:
                break
        return BytesHeaderParser().parsebytes(EMPTY.join(lines))

    @asyncio.coroutine
    def next_line(self):
        """
        Read one line from :attr:`reader` and return it, including its
        trailing CRLF.

        :raises ValueError: the line does not end with CRLF. This includes
            the empty result returned at EOF.
        """
        line = yield from self.reader.readline()
        if not line.endswith(CRLF):
            raise ValueError("Missing mandatory trailing CRLF")
        return line
if __name__ == '__main__':
    # Demo: serve EchoWebSocket on all interfaces, port 9007.
    from ws4py.async_websocket import EchoWebSocket

    # Use an explicit loop: asyncio.get_event_loop() without a running loop
    # is deprecated on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    def start_server():
        proto_factory = lambda: WebSocketProtocol(EchoWebSocket)
        return loop.create_server(proto_factory, '', 9007)

    s = loop.run_until_complete(start_server())
    print('serving on', s.sockets[0].getsockname())
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop the demo server.
        pass
    finally:
        # Stop accepting new connections and release the loop's resources.
        s.close()
        loop.close()
|