1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
|
import socket
import struct
import time
import unittest
from http.client import HTTPConnection
from urllib import parse
from daphne.testing import DaphneTestingInstance, TestApplication
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port and sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
### Plain HTTP helpers
def run_daphne_http(
self,
method,
path,
params,
body,
responses,
headers=None,
timeout=1,
xff=False,
request_buffer_size=None,
):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
"""
with DaphneTestingInstance(
xff=xff, request_buffer_size=request_buffer_size
) as test_app:
# Add the response messages
test_app.add_send_messages(responses)
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
# Manually send over headers
if headers:
for header_name, header_value in headers:
conn.putheader(header_name, header_value)
# Send body if provided.
if body:
conn.putheader("Content-Length", str(len(body)))
conn.endheaders(message_body=body)
else:
conn.endheaders()
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError(
"Daphne timed out handling request, no exception found."
)
# Return scope, messages, response
return test_app.get_received() + (response,)
def run_daphne_raw(self, data, *, responses=None, timeout=1):
"""
Runs Daphne and sends it the given raw bytestring over a socket.
Accepts list of response messages the application will reply with.
Returns what Daphne sends back.
"""
assert isinstance(data, bytes)
with DaphneTestingInstance() as test_app:
if responses is not None:
test_app.add_send_messages(responses)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
s.send(data)
try:
return s.recv(1000000)
except socket.timeout:
raise RuntimeError(
"Daphne timed out handling raw request, no exception found."
)
def run_daphne_request(
self,
method,
path,
params=None,
body=None,
headers=None,
xff=False,
request_buffer_size=None,
):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
"""
scope, messages, _ = self.run_daphne_http(
method=method,
path=path,
params=params,
body=body,
headers=headers,
xff=xff,
request_buffer_size=request_buffer_size,
responses=[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": b"OK"},
],
)
return scope, messages
def run_daphne_response(self, response_messages):
"""
Convenience method for just testing response handling.
Returns (scope, messages)
"""
_, _, response = self.run_daphne_http(
method="GET", path="/", params={}, body=b"", responses=response_messages
)
return response
### WebSocket helpers
def websocket_handshake(
self,
test_app,
path="/",
params=None,
headers=None,
subprotocols=None,
timeout=1,
):
"""
Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol.
You'll need to inject an accept or reject message before this
to let it complete.
"""
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
# Do WebSocket handshake headers + any other headers
if headers is None:
headers = []
headers.extend(
[
(b"Host", b"example.com"),
(b"Upgrade", b"websocket"),
(b"Connection", b"Upgrade"),
(b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
(b"Sec-WebSocket-Version", b"13"),
(b"Origin", b"http://example.com"),
]
)
if subprotocols:
headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers:
for header_name, header_value in headers:
conn.putheader(header_name, header_value)
conn.endheaders()
# Read out the response
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError("Daphne timed out handling request, no exception found.")
# Check we got a good response code
if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching
response_headers = {n.lower(): v for n, v in response.getheaders()}
response.read()
assert not response.closed
# Return the raw socket and any subprotocol
return conn.sock, response_headers.get("sec-websocket-protocol", None)
def websocket_send_frame(self, sock, value):
"""
Sends a WebSocket text or binary frame. Cannot handle long frames.
"""
# Header and text opcode
if isinstance(value, str):
frame = b"\x81"
value = value.encode("utf8")
else:
frame = b"\x82"
# Length plus masking signal bit
frame += struct.pack("!B", len(value) | 0b10000000)
# Mask badly
frame += b"\0\0\0\0"
# Payload
frame += value
sock.sendall(frame)
def receive_from_socket(self, sock, length, timeout=1):
"""
Receives the given amount of bytes from the socket, or times out.
"""
buf = b""
started = time.time()
while len(buf) < length:
buf += sock.recv(length - len(buf))
time.sleep(0.001)
if time.time() - started > timeout:
raise ValueError("Timed out reading from socket")
return buf
def websocket_receive_frame(self, sock):
"""
Receives a WebSocket frame. Cannot handle long frames.
"""
# Read header byte
# TODO: Proper receive buffer handling
opcode = self.receive_from_socket(sock, 1)
if opcode in [b"\x81", b"\x82"]:
# Read length
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
# Read payload
payload = self.receive_from_socket(sock, length)
if opcode == b"\x81":
payload = payload.decode("utf8")
return payload
else:
raise ValueError("Unknown websocket opcode: %r" % opcode)
### Assertions and test management
def tearDown(self):
"""
Ensures any storage files are cleared.
"""
TestApplication.delete_setup()
TestApplication.delete_result()
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except OSError:
self.fail("'%s' is not a valid IP address." % address)
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
"""
Asserts that all required_keys are in actual_keys, and that there
are no keys in actual_keys that aren't required or optional.
"""
present_keys = set(actual_keys)
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_path(self, path):
"""
Checks the path is valid and already url-decoded.
"""
self.assertIsInstance(path, str)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
"""
Asserts the value is a valid (host, port) tuple.
"""
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
|