1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
|
#!/usr/bin/env python
from __future__ import absolute_import, division, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler
import os
import shutil
import socket
import sys
import tempfile
try:
import ssl
except ImportError:
ssl = None
class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    """Base for test cases that serve a single handler at ``/``.

    The handler class is read from the ``Handler`` attribute of the
    concrete test class, so subclasses are expected to define one.
    """

    def get_app(self):
        """Return an application that routes ``/`` to ``Handler``."""
        handler_cls = self.__class__.Handler
        routes = [('/', handler_cls)]
        return Application(routes)

    def fetch_json(self, *args, **kwargs):
        """Fetch with ``self.fetch`` and return the JSON-decoded body.

        All arguments are forwarded to ``self.fetch``.  ``rethrow()`` is
        called on the response before its body is decoded.
        """
        reply = self.fetch(*args, **kwargs)
        reply.rethrow()
        return json_decode(reply.body)
class HelloWorldRequestHandler(RequestHandler):
    """Small handler: fixed greeting on GET, body size report on POST."""

    def initialize(self, protocol="http"):
        """Store the protocol that GET requests are expected to use."""
        self.expected_protocol = protocol

    def get(self):
        """Assert the request protocol matches, then send the greeting."""
        seen_protocol = self.request.protocol
        assert seen_protocol == self.expected_protocol
        self.finish("Hello world")

    def post(self):
        """Reply with the length of the request body."""
        body_size = len(self.request.body)
        self.finish("Got %d bytes in POST" % body_size)
class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Common scaffolding for tests that talk to an SSL-enabled server.

    Subclasses choose the protocol by overriding ``get_ssl_version``.
    """

    def get_ssl_version(self):
        """Return the ``ssl`` protocol constant to use; must be overridden."""
        raise NotImplementedError()

    def setUp(self):
        """Run the parent setup, then install a ``SimpleAsyncHTTPClient``."""
        super(BaseSSLTest, self).setUp()
        # Replace the client defined in the parent class.
        # Some versions of libcurl have deadlock bugs with ssl,
        # so always run these tests with SimpleAsyncHTTPClient.
        self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
                                                 force_instance=True)

    def get_app(self):
        """Serve ``HelloWorldRequestHandler`` at ``/``, expecting https."""
        return Application([('/', HelloWorldRequestHandler,
                             dict(protocol="https"))])

    def get_httpserver_options(self):
        """Return server options holding the test certificate and key.

        The files ``test.crt`` and ``test.key`` are looked up next to this
        module; ``ssl_version`` comes from ``get_ssl_version``.
        """
        # Testing keys were generated with:
        # openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
        test_dir = os.path.dirname(__file__)
        return dict(ssl_options=dict(
            certfile=os.path.join(test_dir, 'test.crt'),
            keyfile=os.path.join(test_dir, 'test.key'),
            ssl_version=self.get_ssl_version()))

    def fetch(self, path, **kwargs):
        """Fetch ``path`` over https without certificate validation.

        Extra keyword arguments are passed on to ``http_client.fetch``.
        Returns whatever ``self.wait()`` yields (the value handed to
        ``self.stop`` by the client).
        """
        plain_scheme = 'http://'
        url = self.get_url(path)
        # Rewrite only the leading scheme.  A blanket replace of 'http'
        # would also mangle any path or query string containing "http".
        if url.startswith(plain_scheme):
            url = 'https://' + url[len(plain_scheme):]
        self.http_client.fetch(url,
                               self.stop,
                               validate_cert=False,
                               **kwargs)
        return self.wait()
class SSLTestMixin(object):
    """Test methods shared by the per-protocol SSL test cases.

    The methods rely on ``fetch``, ``get_url``, ``http_client``, ``stop``
    and ``wait`` being supplied by the class this is mixed into.
    """

    def test_ssl(self):
        """A plain GET returns the greeting."""
        reply = self.fetch('/')
        self.assertEqual(reply.body, b("Hello world"))

    def test_large_post(self):
        """A 5000 byte POST body is received in full."""
        payload = 'A' * 5000
        reply = self.fetch('/', method='POST', body=payload)
        self.assertEqual(reply.body, b("Got 5000 bytes in POST"))

    def test_non_ssl_request(self):
        """A non-ssl request to the ssl server fails with code 599."""
        # Make sure the server closes the connection when it gets a non-ssl
        # connection, rather than waiting for a timeout or otherwise
        # misbehaving.
        plain_url = self.get_url("/")
        self.http_client.fetch(plain_url, self.stop,
                               request_timeout=3600,
                               connect_timeout=3600)
        reply = self.wait()
        self.assertEqual(reply.code, 599)
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
# of SSLv23 allows it.
class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests with the protocol set to ``PROTOCOL_SSLv23``."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_SSLv23``."""
        protocol = ssl.PROTOCOL_SSLv23
        return protocol
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests with the protocol set to ``PROTOCOL_SSLv3``."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_SSLv3``."""
        protocol = ssl.PROTOCOL_SSLv3
        return protocol
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests with the protocol set to ``PROTOCOL_TLSv1``."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_TLSv1``."""
        protocol = ssl.PROTOCOL_TLSv1
        return protocol
if hasattr(ssl, 'PROTOCOL_SSLv2'):
    class SSLv2Test(BaseSSLTest):
        """Checks that a request to an SSLv2-only server does not succeed."""

        def get_ssl_version(self):
            """Return ``ssl.PROTOCOL_SSLv2``."""
            return ssl.PROTOCOL_SSLv2

        def test_sslv2_fail(self):
            """The fetch either raises ``ssl.SSLError`` or yields code 599."""
            # This is really more of a client test, but it lives here with
            # the rest of the ssl version tests.  Clients should have SSLv2
            # disabled by default.
            try:
                # The server simply closes the connection when it gets an
                # SSLv2 ClientHello packet.  request_timeout is needed
                # because on some platforms (cygwin, but not native windows
                # python) the close is not detected promptly.
                response = self.fetch('/', request_timeout=1)
            except ssl.SSLError:
                # In some python/ssl builds the PROTOCOL_SSLv2 constant
                # exists but SSLv2 support is still compiled out, which
                # results in an SSLError here (details vary with the python
                # version).  What matters is that SSLv2 requests don't
                # succeed, so the error is simply ignored.
                pass
            else:
                self.assertEqual(response.code, 599)
if ssl is None:
    # The ssl module could not be imported, so remove the SSL test classes
    # from the module namespace.
    del BaseSSLTest, SSLv23Test, SSLv3Test, TLSv1Test
elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
    # In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
    # ClientHello messages, which are rejected by SSLv3 and TLSv1
    # servers.  Note that while OPENSSL_VERSION_INFO was formally
    # introduced in python3.2, it was present but undocumented in
    # python 2.7.
    del SSLv3Test, TLSv1Test
class MultipartTestHandler(RequestHandler):
    """Echoes selected pieces of a multipart POST back as a dict."""

    def post(self):
        """Report one header, one argument and the first uploaded file."""
        header_value = self.request.headers["X-Header-Encoding-Test"]
        argument_value = self.get_argument("argument")
        upload = self.request.files["files"][0]
        summary = {
            "header": header_value,
            "argument": argument_value,
            "filename": upload.filename,
            "filebody": _unicode(upload["body"]),
        }
        self.finish(summary)
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
    """Connection that writes a caller-supplied raw request on connect."""

    def set_request(self, request):
        """Store the raw bytes to be written once the connection is up."""
        self.__next_request = request

    def _on_connect(self, parsed, parsed_hostname):
        """Write the stored request, clear it, and start reading headers."""
        pending = self.__next_request
        self.stream.write(pending)
        self.__next_request = None
        header_terminator = b("\r\n\r\n")
        self.stream.read_until(header_terminator, self._on_headers)
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Tests that drive the server with hand-built HTTP requests."""

    def get_handlers(self):
        """Return the URL routes served by the test application."""
        return [("/multipart", MultipartTestHandler),
                ("/hello", HelloWorldRequestHandler)]

    def get_app(self):
        """Build the application from ``get_handlers``."""
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body):
        """Send a raw request and return the response.

        ``headers`` is a list of byte strings (request line first) and
        ``body`` a byte string.  A ``Content-Length`` header for ``body``
        is appended before sending.  ``rethrow()`` is called on the
        response, so an error response raises instead of being returned.
        """
        client = SimpleAsyncHTTPClient(self.io_loop)
        try:
            conn = RawRequestHTTPConnection(self.io_loop, client,
                                            httpclient.HTTPRequest(self.get_url("/")),
                                            None, self.stop,
                                            1024 * 1024)
            conn.set_request(
                b("\r\n").join(headers +
                               [utf8("Content-Length: %d\r\n" % len(body))]) +
                b("\r\n") + body)
            response = self.wait()
        finally:
            # Close the client even when wait() raises (e.g. on timeout).
            client.close()
        response.rethrow()
        return response

    def test_multipart_form(self):
        """Header, argument, filename and file body all decode correctly."""
        # Encodings here are tricky: Headers are latin1, bodies can be
        # anything (we use utf8 by default).
        response = self.raw_fetch([
            b("POST /multipart HTTP/1.0"),
            b("Content-Type: multipart/form-data; boundary=1234567890"),
            b("X-Header-encoding-test: \xe9"),
        ],
            b("\r\n").join([
                b("Content-Disposition: form-data; name=argument"),
                b(""),
                u"\u00e1".encode("utf-8"),
                b("--1234567890"),
                u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
                b(""),
                u"\u00fa".encode("utf-8"),
                b("--1234567890--"),
                b(""),
            ]))
        data = json_decode(response.body)
        self.assertEqual(u"\u00e9", data["header"])
        self.assertEqual(u"\u00e1", data["argument"])
        self.assertEqual(u"\u00f3", data["filename"])
        self.assertEqual(u"\u00fa", data["filebody"])

    def test_100_continue(self):
        """``Expect: 100-continue`` gets a 100 response, then the real one."""
        # Run through a 100-continue interaction by hand:
        # When given Expect: 100-continue, we get a 100 response after the
        # headers, and then the real response after the body.
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        try:
            stream.connect(("localhost", self.get_http_port()), callback=self.stop)
            self.wait()
            stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"),
                                         b("Content-Length: 1024"),
                                         b("Expect: 100-continue"),
                                         b("Connection: close"),
                                         b("\r\n")]), callback=self.stop)
            self.wait()
            stream.read_until(b("\r\n\r\n"), self.stop)
            data = self.wait()
            self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
            stream.write(b("a") * 1024)
            stream.read_until(b("\r\n"), self.stop)
            first_line = self.wait()
            self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
            stream.read_until(b("\r\n\r\n"), self.stop)
            header_data = self.wait()
            headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
            stream.read_bytes(int(headers["Content-Length"]), self.stop)
            body = self.wait()
            self.assertEqual(body, b("Got 1024 bytes in POST"))
        finally:
            # Close the stream even when an assertion or wait() fails.
            stream.close()
class EchoHandler(RequestHandler):
    """Writes the request's arguments back to the client."""

    def get(self):
        """Write ``request.arguments`` after passing it through ``recursive_unicode``."""
        echoed = recursive_unicode(self.request.arguments)
        self.write(echoed)
class TypeCheckHandler(RequestHandler):
    """Reports request attributes whose type differs from the expected one.

    ``prepare`` fills ``self.errors`` with one entry per mismatch, keyed
    by a field label.  ``get`` and ``post`` write that dict as the
    response, so an empty result means every check passed.
    """

    def prepare(self):
        """Check the types of the request fields, headers, cookies and arguments.

        The first header, cookie and argument are inspected, so the request
        is expected to carry at least one of each.
        """
        self.errors = {}
        fields = [
            ('method', str),
            ('uri', str),
            ('version', str),
            ('remote_ip', str),
            ('protocol', str),
            ('host', str),
            ('path', str),
            ('query', str),
        ]
        for field, expected_type in fields:
            self.check_type(field, getattr(self.request, field), expected_type)

        # keys()/values() are wrapped in list() so that taking the first
        # item works whether they return lists or non-subscriptable views.
        self.check_type('header_key', list(self.request.headers.keys())[0], str)
        self.check_type('header_value', list(self.request.headers.values())[0], str)

        self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
        self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)
        # secure cookies

        self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
        self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type)

    def post(self):
        """Also check the body type, then write the collected errors."""
        self.check_type('body', self.request.body, bytes_type)
        self.write(self.errors)

    def get(self):
        """Write the collected errors."""
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        """Record an error under ``name`` unless ``type(obj)`` is exactly ``expected_type``.

        The comparison is deliberately exact rather than ``isinstance``
        based, so a subclass of the expected type counts as a mismatch.
        """
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type,
                                                         actual_type)
class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Server tests for query decoding, attribute types and ``//`` paths."""

    def get_app(self):
        """Route the echo and type-check handlers."""
        routes = [
            ("/echo", EchoHandler),
            ("/typecheck", TypeCheckHandler),
            ("//doubleslash", EchoHandler),
        ]
        return Application(routes)

    def test_query_string_encoding(self):
        """A percent-encoded utf8 query value comes back as unicode."""
        reply = self.fetch("/echo?foo=%C3%A9")
        decoded = json_decode(reply.body)
        self.assertEqual(decoded, {u"foo": [u"\u00e9"]})

    def test_types(self):
        """The type-check handler reports no mismatches for GET or POST."""
        cookie_headers = {"Cookie": "foo=bar"}

        get_reply = self.fetch("/typecheck?foo=bar", headers=cookie_headers)
        get_errors = json_decode(get_reply.body)
        self.assertEqual(get_errors, {})

        post_reply = self.fetch("/typecheck", method="POST", body="foo=bar",
                                headers=cookie_headers)
        post_errors = json_decode(post_reply.body)
        self.assertEqual(post_errors, {})

    def test_double_slash(self):
        """A path beginning with ``//`` is routed like any other path."""
        # urlparse.urlsplit (which tornado.httpserver used to use
        # incorrectly) would parse paths beginning with "//" as
        # protocol-relative urls.
        reply = self.fetch("//doubleslash")
        self.assertEqual(200, reply.code)
        echoed = json_decode(reply.body)
        self.assertEqual(echoed, {})
class XHeaderTest(HandlerBaseTestCase):
    """Checks how ``X-Real-IP`` affects ``remote_ip`` with xheaders enabled."""

    class Handler(RequestHandler):
        """Reports the request's ``remote_ip``."""

        def get(self):
            self.write({"remote_ip": self.request.remote_ip})

    def get_httpserver_options(self):
        """Turn on ``xheaders`` for the test server."""
        return {"xheaders": True}

    def test_ip_headers(self):
        """Valid addresses are reported; invalid values leave 127.0.0.1."""
        # With no header at all the connection's own address is reported.
        self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")

        # (X-Real-IP value sent, remote_ip expected), in the order: valid
        # ipv4, valid ipv6, invalid characters, host name instead of an
        # address.
        cases = [
            ("4.4.4.4", "4.4.4.4"),
            ("2620:0:1cfe:face:b00c::3", "2620:0:1cfe:face:b00c::3"),
            ("4.4.4.4<script>", "127.0.0.1"),
            ("www.google.com", "127.0.0.1"),
        ]
        for header_value, expected_ip in cases:
            reply = self.fetch_json("/", headers={"X-Real-IP": header_value})
            self.assertEqual(reply["remote_ip"], expected_ip)
class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).
    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """

    def setUp(self):
        """Create a temporary directory to hold the socket file."""
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the temporary directory, then run the parent teardown."""
        try:
            shutil.rmtree(self.tmpdir)
        finally:
            # Run the parent teardown even when the directory removal fails.
            super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        """A GET sent by hand over a unix socket returns the greeting."""
        sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        server = HTTPServer(app, io_loop=self.io_loop)
        server.add_socket(sock)
        try:
            stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
            try:
                stream.connect(sockfile, self.stop)
                self.wait()
                stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
                stream.read_until(b("\r\n"), self.stop)
                response = self.wait()
                self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
                stream.read_until(b("\r\n\r\n"), self.stop)
                headers = HTTPHeaders.parse(self.wait().decode('latin1'))
                stream.read_bytes(int(headers["Content-Length"]), self.stop)
                body = self.wait()
                self.assertEqual(body, b("Hello world"))
            finally:
                # Close the client stream even when a check above fails.
                stream.close()
        finally:
            # Likewise always stop the server.
            server.stop()
# Drop the unix socket test where AF_UNIX is missing, and on cygwin.
if not (hasattr(socket, 'AF_UNIX') and sys.platform != 'cygwin'):
    del UnixSocketTest
|