1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
|
from __future__ import absolute_import, division, print_function, with_statement
import traceback
from tornado.concurrent import Future
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.web import Application, RequestHandler
try:
import tornado.websocket
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
try:
from tornado import speedups
except ImportError:
speedups = None
class TestWebSocketHandler(WebSocketHandler):
    """Websocket handler base class that reports when it has been closed.

    The test case supplies a future through the handler's initialization
    arguments; it is resolved from ``on_close`` so a test can wait until the
    server side of the connection has gone away.
    """

    def initialize(self, close_future):
        """Remember the future that ``on_close`` will resolve."""
        self.close_future = close_future

    def on_close(self):
        """Resolve the stored future (with ``None``) to signal the close."""
        pending = self.close_future
        pending.set_result(None)
class EchoHandler(TestWebSocketHandler):
    """Handler that sends every received message straight back."""

    def on_message(self, message):
        """Echo ``message``, as a binary frame when it arrived as bytes."""
        is_binary = isinstance(message, bytes)
        self.write_message(message, binary=is_binary)
class HeaderHandler(TestWebSocketHandler):
    """Handler that reports the ``X-Test`` request header to the client."""

    def open(self):
        """Send the ``X-Test`` header value, or an empty string if absent."""
        header_value = self.request.headers.get('X-Test', '')
        self.write_message(header_value)
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP handler; used as a target that is not a websocket."""

    def get(self):
        """Answer a GET request with the fixed body ``ok``."""
        body = 'ok'
        self.write(body)
class WebSocketTest(AsyncHTTPTestCase):
    """Tests of ``websocket_connect`` against handlers served locally.

    Each call to ``get_app`` creates a fresh ``close_future`` and passes it
    to the websocket handlers, which resolve it from ``on_close``.  Tests
    that open a connection wait on that future before returning so the
    server side of the socket is known to be closed.
    """

    def get_app(self):
        # The same future is shared by both websocket routes; a test is
        # expected to use only one of them.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
        ])

    @gen_test
    def test_websocket_gen(self):
        # Coroutine-style round trip through the echo handler.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        # Callback-style round trip through the echo handler, driven by
        # self.stop / self.wait instead of a coroutine.
        websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        # This method is not wrapped in @gen_test, so it must not contain a
        # ``yield``: that would turn it into a generator function whose body
        # never runs when the test runner calls it.  Wait for the server-side
        # close with the stop/wait mechanism instead.
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_websocket_http_fail(self):
        # No route matches /notfound, so the handshake fails with a 404.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                'ws://localhost:%d/notfound' % self.get_http_port(),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # /non_ws is served by a plain RequestHandler that writes 'ok', so
        # the request is answered but is not a websocket handshake.
        with self.assertRaises(WebSocketError):
            yield websocket_connect(
                'ws://localhost:%d/non_ws' % self.get_http_port(),
                io_loop=self.io_loop)

    @skipOnTravis
    @gen_test
    def test_websocket_network_timeout(self):
        # Grab an unused port and close the socket again so that nothing is
        # listening there, then connect with a very short timeout.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(HTTPError) as cm:
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=0.01)
        self.assertEqual(cm.exception.code, 599)

    @gen_test
    def test_websocket_network_fail(self):
        # Same setup as the timeout test, but with a timeout long enough
        # that the failure has to come from the connection attempt itself.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(HTTPError) as cm:
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=3600)
        self.assertEqual(cm.exception.code, 599)

    @gen_test
    def test_websocket_close_buffered_data(self):
        # Queue two messages and then close the underlying stream directly
        # (not via ws.close()); the handler must still see the close.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through
        # websocket_connect by giving it an HTTPRequest instead of a URL.
        ws = yield websocket_connect(
            HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future
class MaskFunctionMixin(object):
    """Shared checks for implementations of the websocket masking function.

    Subclasses should define ``self.mask(mask, data)`` and also provide
    ``assertEqual`` (normally by inheriting from ``unittest.TestCase``).
    """

    def test_mask(self):
        # Each entry is (mask, data, expected result of masking data).
        cases = [
            (b'abcd', b'', b''),
            (b'abcd', b'b', b'\x03'),
            (b'abcd', b'54321', b'TVPVP'),
            (b'ZXCV', b'98765432', b'c`t`olpd'),
            # Include test cases with \x00 bytes (to ensure that the C
            # extension isn't depending on null-terminated strings) and
            # bytes with the high bit set (to smoke out signedness issues).
            (b'\x00\x01\x02\x03',
             b'\xff\xfb\xfd\xfc\xfe\xfa',
             b'\xff\xfa\xff\xff\xfe\xfb'),
            (b'\xff\xfb\xfd\xfc',
             b'\x00\x01\x02\x03\x04\x05',
             b'\xff\xfa\xff\xff\xfb\xfe'),
        ]
        for key, payload, expected in cases:
            self.assertEqual(self.mask(key, payload), expected)
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the shared mask checks against ``_websocket_mask_python``."""

    def mask(self, mask, data):
        """Mask ``data`` with ``mask`` using the pure-Python function."""
        masked = _websocket_mask_python(mask, data)
        return masked
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the shared mask checks against ``speedups.websocket_mask``.

    Skipped when the ``tornado.speedups`` module could not be imported.
    """

    def mask(self, mask, data):
        """Mask ``data`` with ``mask`` using the speedups implementation."""
        masked = speedups.websocket_mask(mask, data)
        return masked
|