1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
|
# -*- coding: utf-8 -*-
from base64 import b64encode
from hashlib import sha1
import os
import socket
import ssl
from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket
from ws4py.compat import urlsplit
__all__ = ['WebSocketBaseClient']


class WebSocketBaseClient(WebSocket):
    def __init__(self, url, protocols=None, extensions=None,
                 heartbeat_freq=None, ssl_options=None, headers=None):
        """
        A websocket client that implements :rfc:`6455` and provides a simple
        interface to communicate with a websocket server.

        This class works on its own but will block if not run in
        its own thread.

        When an instance of this class is created, a :py:mod:`socket`
        is created. If the connection is a TCP socket,
        Nagle's algorithm is disabled.

        The address of the server will be extracted from the given
        websocket url.

        The websocket key is randomly generated; reset the
        `key` attribute (a base64-encoded ``bytes`` value) if you
        want to provide yours.

        For instance to create a TCP client:

        .. code-block:: python

           >>> from websocket.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws://localhost/ws')

        Here is an example for a TCP client over SSL:

        .. code-block:: python

           >>> from websocket.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('wss://localhost/ws')

        Finally an example of a Unix-domain connection:

        .. code-block:: python

           >>> from websocket.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')

        Note that in this case, the initial Upgrade request
        will be sent to ``/``. You may need to change this
        by setting the resource explicitly before connecting:

        .. code-block:: python

           >>> from websocket.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
           >>> ws.resource = '/ws'
           >>> ws.connect()

        You may provide extra headers by passing a list of tuples
        which must be unicode objects.

        ``ssl_options`` is a dict of keyword arguments used to wrap the
        socket in TLS for ``wss`` URLs (the parameters accepted by
        ``ssl.wrap_socket``). It is copied, so the caller's dict is
        never modified.

        :raises ValueError: if the URL has an invalid scheme or hostname.
        """
        self.url = url
        self.host = None
        self.scheme = None
        self.port = None
        self.unix_socket_path = None
        self.resource = None
        # Copy so that adding defaults below never mutates the caller's dict.
        self.ssl_options = dict(ssl_options or {})
        self.extra_headers = headers or []

        # The scheme is only known once the URL has been parsed, so the
        # wss-specific default below must come after this call.
        self._parse_url()

        if self.scheme == "wss":
            # Prevent check_hostname requires server_hostname (ref #187)
            if "cert_reqs" not in self.ssl_options:
                self.ssl_options["cert_reqs"] = ssl.CERT_NONE

        if self.unix_socket_path:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        else:
            # Let's handle IPv4 and IPv6 addresses
            # Simplified from CherryPy's code
            try:
                family, socktype, proto = socket.getaddrinfo(
                    self.host, self.port, socket.AF_UNSPEC,
                    socket.SOCK_STREAM, 0, socket.AI_PASSIVE)[0][:3]
            except socket.gaierror:
                # Resolution failed: guess the address family from the host
                # string; connect() will report any real resolution error.
                family = socket.AF_INET
                if self.host.startswith('::'):
                    family = socket.AF_INET6
                socktype = socket.SOCK_STREAM
                proto = 0

            sock = socket.socket(family, socktype, proto)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 and \
                    self.host.startswith('::'):
                try:
                    sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
                except (AttributeError, socket.error):
                    pass

        WebSocket.__init__(self, sock, protocols=protocols,
                           extensions=extensions,
                           heartbeat_freq=heartbeat_freq)

        # Client-to-server frames must be masked; server frames must not be.
        self.stream.always_mask = True
        self.stream.expect_masking = False
        self.key = b64encode(os.urandom(16))

    # Adapted from: https://github.com/liris/websocket-client/blob/master/websocket.py#L105
    def _parse_url(self):
        """
        Parses a URL which must have one of the following forms:

        - ws://host[:port][path]
        - wss://host[:port][path]
        - ws+unix:///path/to/my.socket

        In the first two cases, the ``host`` and ``port``
        attributes will be set to the parsed values. If no port
        is explicitly provided, it will be either 80 or 443
        based on the scheme. Also, the ``resource`` attribute is
        set to the path segment of the URL (alongside any querystring).

        In addition, if the scheme is ``ws+unix``, the
        ``unix_socket_path`` attribute is set to the path to
        the Unix socket while the ``resource`` attribute is
        set to ``/``.

        :raises ValueError: if the URL has no scheme, an unsupported
            scheme, or no hostname for a non-unix scheme.
        """
        # Python 2.6.1 and below don't parse ws or wss urls properly. netloc is empty.
        # See: https://github.com/Lawouach/WebSocket-for-Python/issues/59
        scheme, sep, url = self.url.partition(":")
        if not sep:
            raise ValueError("Invalid URL: %s" % self.url)

        parsed = urlsplit(url, scheme="http")
        if parsed.hostname:
            self.host = parsed.hostname
        elif '+unix' in scheme:
            self.host = 'localhost'
        else:
            raise ValueError("Invalid hostname from: %s" % self.url)

        if parsed.port:
            self.port = parsed.port

        if scheme == "ws":
            if not self.port:
                self.port = 80
        elif scheme == "wss":
            if not self.port:
                self.port = 443
        elif scheme in ('ws+unix', 'wss+unix'):
            pass
        else:
            raise ValueError("Invalid scheme: %s" % scheme)

        resource = parsed.path or "/"
        if '+unix' in scheme:
            # For unix sockets the URL path is the socket file itself.
            self.unix_socket_path = resource
            resource = '/'

        if parsed.query:
            resource += "?" + parsed.query

        self.scheme = scheme
        self.resource = resource

    @property
    def bind_addr(self):
        """
        Returns the Unix socket path, or a tuple
        ``(host, port)``, depending on the initial
        URL's scheme.
        """
        return self.unix_socket_path or (self.host, self.port)

    def close(self, code=1000, reason=''):
        """
        Initiate the closing handshake with the server.
        The close frame is only sent once.
        """
        if not self.client_terminated:
            self.client_terminated = True
            self._write(self.stream.close(code=code, reason=reason).single(mask=True))

    def _wrap_ssl_socket(self, sock):
        """
        Wrap ``sock`` for TLS using ``self.ssl_options``.

        ``ssl.wrap_socket`` is used when available. It was removed in
        Python 3.12; there, an equivalent :class:`ssl.SSLContext` is
        built from the same keyword options.
        """
        wrap = getattr(ssl, 'wrap_socket', None)
        if wrap is not None:
            return wrap(sock, **self.ssl_options)
        return self._context_wrap_socket(sock, **self.ssl_options)

    def _context_wrap_socket(self, sock, keyfile=None, certfile=None,
                             server_side=False, cert_reqs=None,
                             ssl_version=None, ca_certs=None,
                             do_handshake_on_connect=True,
                             suppress_ragged_eofs=True, ciphers=None):
        """
        Mirror ``ssl.wrap_socket``'s keyword interface on top of
        :class:`ssl.SSLContext`. Unknown options raise :exc:`TypeError`,
        as they did with ``ssl.wrap_socket``.
        """
        if server_side and not certfile:
            raise ValueError("certfile must be specified for server-side operations")
        if keyfile and not certfile:
            raise ValueError("certfile must be specified")
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE  # ssl.wrap_socket's default
        if ssl_version is None:
            ssl_version = (ssl.PROTOCOL_TLS_SERVER if server_side
                           else ssl.PROTOCOL_TLS_CLIENT)
        context = ssl.SSLContext(ssl_version)
        # ssl.wrap_socket never checked hostnames; keep that behaviour.
        # check_hostname must be off before verify_mode may be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = cert_reqs
        if ca_certs:
            context.load_verify_locations(ca_certs)
        if certfile:
            context.load_cert_chain(certfile, keyfile)
        if ciphers:
            context.set_ciphers(ciphers)
        return context.wrap_socket(
            sock, server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            # Send SNI so virtual-hosted TLS endpoints pick the right cert.
            server_hostname=None if server_side else self.host)

    def connect(self):
        """
        Connects this websocket and starts the upgrade handshake
        with the remote endpoint.

        :raises HandshakeError: if the server's response is empty or is
            not a valid upgrade response; the connection is closed first.
        """
        if self.scheme == "wss":
            # Wrap the not-yet-connected TCP socket in TLS; the TLS
            # handshake then happens as part of connect() below.
            self.sock = self._wrap_ssl_socket(self.sock)
            self._is_secure = True

        self.sock.connect(self.bind_addr)

        self._write(self.handshake_request)

        # Read until the end of the HTTP response headers (or EOF).
        response = b''
        double_crlf = b'\r\n\r\n'
        while True:
            chunk = self.sock.recv(128)
            if not chunk:
                break
            response += chunk
            if double_crlf in response:
                break

        if not response:
            self.close_connection()
            raise HandshakeError("Invalid response")

        headers, _, body = response.partition(double_crlf)
        response_line, _, headers = headers.partition(b'\r\n')

        try:
            self.process_response_line(response_line)
            self.protocols, self.extensions = self.process_handshake_header(headers)
        except HandshakeError:
            self.close_connection()
            raise

        self.handshake_ok()
        # Any bytes read past the headers already belong to websocket frames.
        if body:
            self.process(body)

    @property
    def handshake_headers(self):
        """
        List of headers appropriate for the upgrade
        handshake.
        """
        host = self.host
        if ':' in host:
            # Bare IPv6 literal (urlsplit strips the brackets); re-add them.
            host = '[%s]' % host
        if self.port is not None:
            host_header = '%s:%s' % (host, self.port)
        else:
            # Unix-socket URLs have no port.
            host_header = host

        headers = [
            ('Host', host_header),
            ('Connection', 'Upgrade'),
            ('Upgrade', 'websocket'),
            ('Sec-WebSocket-Key', self.key.decode('utf-8')),
            ('Sec-WebSocket-Version', str(max(WS_VERSION)))
        ]

        if self.protocols:
            headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))

        if self.extra_headers:
            headers.extend(self.extra_headers)

        if not any(x[0].lower() == 'origin' for x in headers):
            # Derive an Origin from the URL without touching instance state.
            scheme, _, url = self.url.partition(":")
            parsed = urlsplit(url, scheme="http")
            origin_host = parsed.hostname or 'localhost'
            if ':' in origin_host:
                origin_host = '[%s]' % origin_host
            origin = scheme + '://' + origin_host
            if parsed.port:
                origin = origin + ':' + str(parsed.port)
            headers.append(('Origin', origin))

        return headers

    @property
    def handshake_request(self):
        """
        Prepare the request to be sent for the upgrade handshake.
        """
        request = [("GET %s HTTP/1.1" % self.resource).encode('utf-8')]
        for header, value in self.handshake_headers:
            request.append(("%s: %s" % (header, value)).encode('utf-8'))
        # Joined with CRLF, this trailing element yields the blank line
        # that terminates the HTTP headers.
        request.append(b'\r\n')

        return b'\r\n'.join(request)

    def process_response_line(self, response_line):
        """
        Ensure that we received a HTTP `101` status code in
        response to our request and if not raises :exc:`HandshakeError`.
        A malformed status line also raises :exc:`HandshakeError`.
        """
        parts = response_line.split(b' ', 2)
        if len(parts) < 2:
            raise HandshakeError("Invalid response status line: %s" % response_line)
        code = parts[1]
        # The reason phrase is optional in practice.
        status = parts[2] if len(parts) > 2 else b''
        if code != b'101':
            raise HandshakeError("Invalid response status: %s %s" % (code, status))

    @staticmethod
    def _split_header_list(value):
        """
        Split a comma-separated header value (bytes) into a list of
        non-empty, whitespace-stripped text tokens.
        """
        text = value.decode('utf-8', 'replace')
        return [item.strip() for item in text.split(',') if item.strip()]

    def process_handshake_header(self, headers):
        """
        Read the upgrade handshake's response headers and
        validate them against :rfc:`6455`.

        :returns: ``(protocols, extensions)``, each a list of strings.
        :raises HandshakeError: on a malformed header line or an invalid
            Upgrade, Connection or Sec-WebSocket-Accept header.
        """
        protocols = []
        extensions = []

        headers = headers.strip()
        for header_line in headers.split(b'\r\n'):
            if not header_line.strip():
                continue
            header, sep, raw_value = header_line.partition(b':')
            if not sep:
                raise HandshakeError("Invalid header line: %s" % header_line)
            header = header.strip().lower()
            raw_value = raw_value.strip()
            value = raw_value.lower()

            if header == b'upgrade' and value != b'websocket':
                raise HandshakeError("Invalid Upgrade header: %s" % value)
            elif header == b'connection':
                # Connection is a token list, e.g. "Upgrade, keep-alive".
                tokens = [token.strip() for token in value.split(b',')]
                if b'upgrade' not in tokens:
                    raise HandshakeError("Invalid Connection header: %s" % value)
            elif header == b'sec-websocket-accept':
                match = b64encode(sha1(self.key + WS_KEY).digest())
                # Base64 is case-sensitive: compare the raw value exactly.
                if raw_value != match:
                    raise HandshakeError("Invalid challenge response: %s" % raw_value)
            elif header == b'sec-websocket-protocol':
                protocols = self._split_header_list(raw_value)
            elif header == b'sec-websocket-extensions':
                # The header may legitimately appear more than once.
                extensions.extend(self._split_header_list(raw_value))

        return protocols, extensions

    def handshake_ok(self):
        """
        Called once the upgrade handshake succeeded; delegates to
        :meth:`opened`.
        """
        self.opened()
|