File: test_websocket.py

package info (click to toggle)
python-daphne 4.1.2-2
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 396 kB
  • sloc: python: 2,565; makefile: 25
file content (338 lines) | stat: -rw-r--r-- 13,974 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import collections
import time
from urllib import parse

import http_strategies
from http_base import DaphneTestCase, DaphneTestingInstance
from hypothesis import given, settings

from daphne.testing import BaseDaphneTestingInstance


class TestWebsocket(DaphneTestCase):
    """
    Tests WebSocket handshake, send and receive.
    """

    @staticmethod
    def _split_header_value(value):
        """
        Splits a (possibly comma-collapsed) header value given as bytes into
        its individual parts, stripped of surrounding whitespace, dropping
        empty parts and keeping the original order.
        """
        return [bit.strip() for bit in value.split(b",") if bit.strip()]

    def assert_valid_websocket_scope(
        self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
    ):
        """
        Checks that the passed scope is a valid ASGI WebSocket scope regarding
        types and some urlencoding things.

        scope: the scope dict received by the application.
        path: accepted for call-site symmetry; NOTE it is currently not
            compared against scope["path"], only the scope's own path is
            validated.
        params: if truthy, the urlencoded form must equal the query string.
        headers: if given, iterable of (name, value) byte pairs that must all
            be present in the scope headers with the same value order.
        scheme: if truthy, must equal scope["scheme"].
        subprotocols: if truthy, must match the scope's subprotocols
            (order-insensitive).
        """
        # Check overall keys
        self.assert_key_sets(
            required_keys={
                "asgi",
                "type",
                "path",
                "raw_path",
                "query_string",
                "headers",
            },
            optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
            actual_keys=scope.keys(),
        )
        self.assertEqual(scope["asgi"]["version"], "3.0")
        # Check that it is the right type
        self.assertEqual(scope["type"], "websocket")
        # Path
        self.assert_valid_path(scope["path"])
        # Scheme
        self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
        if scheme:
            self.assertEqual(scheme, scope["scheme"])
        # Query string (byte string and still url encoded)
        query_string = scope["query_string"]
        self.assertIsInstance(query_string, bytes)
        if params:
            self.assertEqual(query_string, parse.urlencode(params).encode("ascii"))
        # Ordering of header names is not important, but the order of values for a header
        # name is. To assert whether that order is kept, we transform both the request
        # headers and the channel message headers into a dictionary
        # {name: [value1, value2, ...]} and check if they're equal.
        transformed_scope_headers = collections.defaultdict(list)
        for name, value in scope["headers"]:
            # Indexing the defaultdict registers the name even when the value
            # yields no parts; collapsed values are split out on commas.
            transformed_scope_headers[name].extend(self._split_header_value(value))
        transformed_request_headers = collections.defaultdict(list)
        for name, value in headers or []:
            # Request header names are normalised before the lookup below.
            expected_name = name.lower().strip()
            transformed_request_headers[expected_name].extend(
                self._split_header_value(value)
            )
        for name, value in transformed_request_headers.items():
            self.assertIn(name, transformed_scope_headers)
            self.assertEqual(value, transformed_scope_headers[name])
        # Root path
        self.assertIsInstance(scope.get("root_path", ""), str)
        # Client and server addresses
        client = scope.get("client")
        if client is not None:
            self.assert_valid_address_and_port(client)
        server = scope.get("server")
        if server is not None:
            self.assert_valid_address_and_port(server)
        # Subprotocols
        scope_subprotocols = scope.get("subprotocols", [])
        for subprotocol in scope_subprotocols:
            self.assertIsInstance(subprotocol, str)
        if subprotocols:
            self.assertEqual(sorted(scope_subprotocols), sorted(subprotocols))

    def assert_valid_websocket_connect_message(self, message):
        """
        Asserts that a message is a valid websocket.connect message
        """
        # Check overall keys
        self.assert_key_sets(
            required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
        )
        # Check that it is the right type
        self.assertEqual(message["type"], "websocket.connect")

    def test_accept(self):
        """
        Tests we can open and accept a socket.
        """
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages([{"type": "websocket.accept"}])
            self.websocket_handshake(test_app)
            # Validate the scope and messages we got
            scope, messages = test_app.get_received()
            self.assert_valid_websocket_scope(scope)
            self.assert_valid_websocket_connect_message(messages[0])

    def test_reject(self):
        """
        Tests we can reject a socket and it won't complete the handshake.
        """
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages([{"type": "websocket.close"}])
            with self.assertRaises(RuntimeError):
                self.websocket_handshake(test_app)

    def test_subprotocols(self):
        """
        Tests that we can ask for subprotocols and then select one.
        """
        subprotocols = ["proto1", "proto2"]
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages(
                [{"type": "websocket.accept", "subprotocol": "proto2"}]
            )
            _, subprotocol = self.websocket_handshake(
                test_app, subprotocols=subprotocols
            )
            # Validate the scope and messages we got
            self.assertEqual(subprotocol, "proto2")
            scope, messages = test_app.get_received()
            self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
            self.assert_valid_websocket_connect_message(messages[0])

    def test_xff(self):
        """
        Tests that X-Forwarded-For headers get parsed right
        """
        headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
        with DaphneTestingInstance(xff=True) as test_app:
            test_app.add_send_messages([{"type": "websocket.accept"}])
            self.websocket_handshake(test_app, headers=headers)
            # Validate the scope and messages we got
            scope, messages = test_app.get_received()
            self.assert_valid_websocket_scope(scope)
            self.assert_valid_websocket_connect_message(messages[0])
            self.assertEqual(scope["client"], ["10.1.2.3", 80])

    @given(
        request_path=http_strategies.http_path(),
        request_params=http_strategies.query_params(),
        request_headers=http_strategies.headers(),
    )
    @settings(max_examples=5, deadline=2000)
    def test_http_bits(self, request_path, request_params, request_headers):
        """
        Tests that various HTTP-level bits (query string params, path, headers)
        carry over into the scope.
        """
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages([{"type": "websocket.accept"}])
            self.websocket_handshake(
                test_app,
                path=parse.quote(request_path),
                params=request_params,
                headers=request_headers,
            )
            # Validate the scope and messages we got
            scope, messages = test_app.get_received()
            self.assert_valid_websocket_scope(
                scope, path=request_path, params=request_params, headers=request_headers
            )
            self.assert_valid_websocket_connect_message(messages[0])

    def test_raw_path(self):
        """
        Tests that /foo%2Fbar produces raw_path and a decoded path
        """
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages([{"type": "websocket.accept"}])
            self.websocket_handshake(test_app, path="/foo%2Fbar")
            # Validate the scope and messages we got
            scope, _ = test_app.get_received()

        self.assertEqual(scope["path"], "/foo/bar")
        self.assertEqual(scope["raw_path"], b"/foo%2Fbar")

    @given(daphne_path=http_strategies.http_path())
    @settings(max_examples=5, deadline=2000)
    def test_root_path(self, *, daphne_path):
        """
        Tests root_path handling.
        """
        headers = [("Daphne-Root-Path", parse.quote(daphne_path))]
        with DaphneTestingInstance() as test_app:
            test_app.add_send_messages([{"type": "websocket.accept"}])
            self.websocket_handshake(
                test_app,
                path="/",
                headers=headers,
            )
            # Validate the scope and messages we got
            scope, _ = test_app.get_received()

        # Daphne-Root-Path is not included in the returned 'headers' section.
        # Scope header names are byte strings, so the needle must be bytes as
        # well; a str needle would never compare equal and the check would
        # pass vacuously.
        self.assertNotIn(
            b"daphne-root-path", [header[0].lower() for header in scope["headers"]]
        )
        # And what we're looking for, root_path being set.
        self.assertEqual(scope["root_path"], daphne_path)

    def test_text_frames(self):
        """
        Tests we can send and receive text frames.
        """
        with DaphneTestingInstance() as test_app:
            # Connect
            test_app.add_send_messages([{"type": "websocket.accept"}])
            sock, _ = self.websocket_handshake(test_app)
            _, messages = test_app.get_received()
            self.assert_valid_websocket_connect_message(messages[0])
            # Prep frame for it to send
            test_app.add_send_messages(
                [{"type": "websocket.send", "text": "here be dragons 🐉"}]
            )
            # Send it a frame
            self.websocket_send_frame(sock, "what is here? 🌍")
            # Receive a frame and make sure it's correct
            self.assertEqual(self.websocket_receive_frame(sock), "here be dragons 🐉")
            # Make sure it got our frame
            _, messages = test_app.get_received()
            self.assertEqual(
                messages[1],
                {
                    "type": "websocket.receive",
                    "text": "what is here? 🌍",
                },
            )

    def test_binary_frames(self):
        """
        Tests we can send and receive binary frames with things that are very
        much not valid UTF-8.
        """
        with DaphneTestingInstance() as test_app:
            # Connect
            test_app.add_send_messages([{"type": "websocket.accept"}])
            sock, _ = self.websocket_handshake(test_app)
            _, messages = test_app.get_received()
            self.assert_valid_websocket_connect_message(messages[0])
            # Prep frame for it to send
            test_app.add_send_messages(
                [{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
            )
            # Send it a frame
            self.websocket_send_frame(sock, b"what is here? \xe2")
            # Receive a frame and make sure it's correct
            self.assertEqual(self.websocket_receive_frame(sock), b"here be \xe2 bytes")
            # Make sure it got our frame
            _, messages = test_app.get_received()
            self.assertEqual(
                messages[1],
                {
                    "type": "websocket.receive",
                    "bytes": b"what is here? \xe2",
                },
            )

    def test_http_timeout(self):
        """
        Tests that the HTTP timeout doesn't kick in for WebSockets
        """
        with DaphneTestingInstance(http_timeout=1) as test_app:
            # Connect
            test_app.add_send_messages([{"type": "websocket.accept"}])
            sock, _ = self.websocket_handshake(test_app)
            _, messages = test_app.get_received()
            self.assert_valid_websocket_connect_message(messages[0])
            # Wait 2 seconds, i.e. longer than the 1 second http_timeout
            time.sleep(2)
            # Prep frame for it to send
            test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
            # Send it a frame
            self.websocket_send_frame(sock, "still alive?")
            # Receive a frame and make sure it's correct
            self.assertEqual(self.websocket_receive_frame(sock), "cake")

    def test_application_checker_handles_asyncio_cancellederror(self):
        """
        Tests that server connections are cleaned up after the application
        raises asyncio.CancelledError.
        """
        with CancellingTestingInstance() as app:
            # Connect to the websocket app, it will immediately raise
            # asyncio.CancelledError
            sock, _ = self.websocket_handshake(app)
            # Disconnect from the socket
            sock.close()
            # Wait for application_checker to clean up the applications for
            # disconnected clients, and for the server to be stopped.
            time.sleep(3)
            # Make sure we received no error at all. Unexpected errors are
            # re-raised as they are; a ConnectionsNotEmpty means the
            # connections were not cleaned up, which fails the test with an
            # explicit message.
            while not app.process.errors.empty():
                err, _tb = app.process.errors.get()
                if not isinstance(err, ConnectionsNotEmpty):
                    raise err
                self.fail(
                    "Server connections were not cleaned up after an asyncio.CancelledError was raised"
                )


async def cancelling_application(scope, receive, send):
    """
    ASGI application that accepts the WebSocket and then raises
    asyncio.CancelledError.

    Before accepting, it schedules the Twisted reactor to stop two seconds
    later so that the teardown is run.
    """
    from asyncio import CancelledError

    from twisted.internet import reactor as twisted_reactor

    accept_message = {"type": "websocket.accept"}
    # Schedule the server shutdown up front; the short delay leaves time for
    # the rest of this coroutine to run before the teardown happens.
    twisted_reactor.callLater(2, twisted_reactor.stop)
    await send(accept_message)
    raise CancelledError()


class ConnectionsNotEmpty(Exception):
    """
    Raised from CancellingTestingInstance.process_teardown when the server
    still holds connections at teardown time.
    """


class CancellingTestingInstance(BaseDaphneTestingInstance):
    """
    Testing instance that runs cancelling_application and, at process
    teardown, reports connections the server is still holding.
    """

    def __init__(self):
        super().__init__(application=cancelling_application)

    def process_teardown(self):
        """
        Raises ConnectionsNotEmpty if the server still has connections.
        """
        from multiprocessing import current_process

        # The teardown runs in the same process as the application, so the
        # current process is the enclosing DaphneProcess.
        daphne_process = current_process()
        # The (only) socket should have disconnected by now and the
        # application_checker should have run; an empty connection
        # collection is therefore the expected state.
        if not daphne_process.server.connections:
            return
        # Anything left over means the application_checker did not clean up.
        raise ConnectionsNotEmpty()