File: test_sse.py

package info (click to toggle)
python-sse-starlette 2.3.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 944 kB
  • sloc: python: 1,487; makefile: 134; sh: 57
file content (360 lines) | stat: -rw-r--r-- 12,510 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import asyncio
import logging
import math
from functools import partial

import anyio
import anyio.lowlevel
import pytest
from starlette.background import BackgroundTask
from starlette.testclient import TestClient

from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from sse_starlette.sse import SendTimeoutError
from tests.anyio_compat import collapse_excgroups

_log = logging.getLogger(__name__)


# Test fixtures and helpers
@pytest.fixture
def mock_generator():
    """Provide a factory for an async generator of consecutive integers.

    The returned callable takes ``minimum`` and ``maximum`` and yields every
    integer of that inclusive range, sleeping 0.1 s before each item.
    """

    async def numbers(minimum, maximum):
        # ``range`` excludes its stop value, so step one past ``maximum``.
        stop = maximum + 1
        for current in range(minimum, stop):
            await asyncio.sleep(0.1)
            yield current

    return numbers


@pytest.fixture
def mock_memory_channels():
    """Provide an async factory that builds an anyio memory object stream.

    Awaiting the returned callable creates a stream with a buffer size of
    ``math.inf`` and returns its ``(send, receive)`` ends as a tuple.
    """

    async def setup():
        sender, receiver = anyio.create_memory_object_stream(math.inf)
        return sender, receiver

    return setup


class TestEventSourceResponse:
    """Tests for ``EventSourceResponse`` and ``ServerSentEvent``.

    Covers event encoding, pings, client disconnects, send timeouts, header
    merging, ping-interval validation, compression and background tasks.
    """

    @pytest.mark.parametrize(
        "input_type,separator,expected_output",
        [
            ("integer", "\r\n", b"data: 1\r\n\r\n"),
            ("dict_simple", "\r\n", b"data: 1\r\n\r\n"),
            ("dict_with_event", "\r\n", b"event: message\r\ndata: 1\r\n\r\n"),
            ("dict_with_event", "\r", b"event: message\rdata: 1\r\r"),
        ],
    )
    def test_response_send_whenValidInput_thenGeneratesExpectedOutput(
        self,
        reset_appstatus_event,
        mock_generator,
        input_type,
        separator,
        expected_output,
    ):
        """Stream 1..5 through an ASGI app and check the encoded first event.

        This is deliberately a plain (non-async) test: it only drives the
        blocking ``TestClient`` and awaits nothing itself, so it must not
        depend on an async pytest plugin in order to be collected and run.
        """

        # Arrange
        def format_output(value):
            # Shape the payload according to the parametrized input type.
            if input_type == "integer":
                return value
            if input_type == "dict_simple":
                return dict(data=value)
            return dict(data=value, event="message")

        async def generate():
            async for value in mock_generator(1, 5):
                yield format_output(value)

        async def app(scope, receive, send):
            # A fresh generator is created for every request.
            response = EventSourceResponse(generate(), ping=0.2, sep=separator)
            await response(scope, receive, send)

        # Act
        client = TestClient(app)
        response = client.get("/")

        # Assert
        assert expected_output in response.content
        assert response.content.decode().count("ping") == 2

    @pytest.mark.parametrize(
        "producer_output,expected_sse_response",
        [
            # Test raw integers being converted to SSE format
            ("raw_integer", b"data: 1\r\n\r\n"),
            # Test dict with just data field
            ("simple_dict", b"data: 1\r\n\r\n"),
            # Test dict with both event and data fields
            ("event_dict", b"event: message\r\ndata: 1\r\n\r\n"),
        ],
    )
    def test_eventSourceResponse_whenUsingMemoryChannel_thenHandlesAsyncQueueCorrectly(
        self, reset_appstatus_event, producer_output, expected_sse_response
    ):
        """Tests that EventSourceResponse can properly consume data from an async memory channel.

        This test verifies the producer-consumer pattern where:
        1. Producer (stream_numbers) puts data into a memory channel
        2. Consumer (EventSourceResponse) reads from that channel and formats as SSE

        This differs from direct generator tests by:
        - Using separate producer/consumer components
        - Testing async queue-based communication
        - Verifying SSE works with buffered async data sources
        """

        # Arrange
        async def app(scope, receive, send):
            # Create an unbounded (max_buffer_size=math.inf) memory channel for
            # producer-consumer communication, so the producer never blocks
            # on a full buffer.
            send_chan, recv_chan = anyio.create_memory_object_stream(
                max_buffer_size=math.inf
            )

            # Producer function that writes to the channel
            async def stream_numbers(producer_channel, start, end):
                # Leaving the ``async with`` block closes the send side.
                async with producer_channel:
                    for i in range(start, end + 1):
                        await anyio.sleep(0.1)  # Simulate async data production

                        # Format data based on test case
                        if producer_output == "raw_integer":
                            data = i
                        elif producer_output == "simple_dict":
                            data = dict(data=i)
                        else:  # event_dict
                            data = dict(data=i, event="message")

                        # Send to channel for consumption
                        await producer_channel.send(data)

            # Create SSE response that consumes from channel
            response = EventSourceResponse(
                recv_chan,  # Consumer reads from receive channel
                data_sender_callable=partial(
                    stream_numbers, send_chan, 1, 5
                ),  # Producer writes to send channel
                ping=0.2,
            )
            await response(scope, receive, send)

        # Act
        client = TestClient(app)
        response = client.get("/")

        # Assert
        assert response.content.decode().count("ping") == 2
        assert expected_sse_response in response.content

    @pytest.mark.anyio
    async def test_disconnect_whenClientDisconnects_thenHandlesGracefully(
        self, httpx_client, caplog
    ):
        """Cancel a request to ``/endless`` after 1 s and inspect the logs.

        ``anyio.fail_after(1)`` is expected to cancel the pending request and
        raise ``TimeoutError``; the captured log must then show both a sent
        data chunk and the disconnect message.
        """
        # Arrange
        caplog.set_level(logging.DEBUG)

        # Act & Assert
        with pytest.raises(TimeoutError):
            with anyio.fail_after(1) as scope:
                try:
                    async with anyio.create_task_group() as tg:
                        # https://www.python-httpx.org/async/#streaming-responses
                        tg.start_soon(httpx_client.get, "/endless")
                finally:
                    # Runs while the cancellation propagates, before
                    # fail_after converts it into TimeoutError.
                    assert scope.cancel_called is True
                    assert "chunk: b'data: 4\\r\\n\\r\\n'" in caplog.text
                    assert "Disconnected from client" in caplog.text

    @pytest.mark.anyio
    async def test_send_whenTimeoutOccurs_thenRaisesSendTimeoutError(
        self, reset_appstatus_event
    ):
        """A ``send`` slower than ``send_timeout`` raises ``SendTimeoutError``.

        Also checks that the event generator's ``finally`` block has run by
        the time the error reaches the caller.
        """
        # Arrange
        # Send timeout is set to 0.5s, but `send` will take 1s. Expect SendTimeoutError.
        cleanup_executed = False

        async def event_publisher():
            nonlocal cleanup_executed
            try:
                yield {"event": "test", "data": "data"}
                pytest.fail("Should not reach this point")
            finally:
                cleanup_executed = True

        async def mock_send(*args, **kwargs):
            await anyio.sleep(1.0)

        async def mock_receive():
            await anyio.lowlevel.checkpoint()
            return {"type": "message"}

        response = EventSourceResponse(event_publisher(), send_timeout=0.5)

        # Act & Assert
        with pytest.raises(SendTimeoutError):
            with collapse_excgroups():
                await response({}, mock_receive, mock_send)

        assert cleanup_executed, "Cleanup should be executed on timeout"

    def test_headers_whenCustomHeadersProvided_thenMergesCorrectly(self):
        """Custom headers are merged, but some defaults must not be overridden."""
        # Arrange
        custom_headers = {
            "cache-control": "no-cache",
            "x-accel-buffering": "yes",  # Should not override
            "connection": "close",  # Should not override
            "x-custom-header": "custom-value",
        }

        # Act
        response = EventSourceResponse(range(1, 5), headers=custom_headers, ping=0.2)
        headers = {
            name.decode(): value.decode() for name, value in response.raw_headers
        }

        # Assert
        assert headers["cache-control"] == "no-cache"
        assert headers["x-accel-buffering"] == "no"  # Should keep default
        assert headers["connection"] == "keep-alive"  # Should keep default
        assert headers["x-custom-header"] == "custom-value"
        assert headers["content-type"] == "text/event-stream; charset=utf-8"

    def test_headers_whenCreated_thenHasCorrectCharset(
        self, reset_appstatus_event, mock_generator
    ):
        """Exactly one ``content-type`` header is set, and it names UTF-8."""
        # Arrange
        generator = mock_generator(1, 5)

        # Act
        response = EventSourceResponse(generator, ping=0.2)
        content_type_values = [
            value.decode()
            for name, value in response.raw_headers
            if name.decode() == "content-type"
        ]

        # Assert
        # Comparing against a one-element list checks count and value at once.
        assert content_type_values == ["text/event-stream; charset=utf-8"]

    @pytest.mark.anyio
    async def test_ping_whenConcurrentWithEvents_thenRespectsLocking(
        self, reset_appstatus_event
    ):
        """Detect overlapping ``send`` calls between the ping and event tasks.

        The fake ``send`` takes a lock without waiting, so a second call made
        while the first is still sleeping raises ``anyio.WouldBlock``, which
        this test expects to surface.
        """
        # Sequencing here is as follows to reproduce race condition:
        # t=0.5s - event_publisher sends the first response item,
        #          claiming the lock and going to sleep for 1 second so until t=1.5s.
        # t=1.0s - ping task wakes up and tries to call send while we know
        #          that event_publisher is still blocked inside it and holding the lock
        # Arrange
        lock = anyio.Lock()

        async def event_publisher():
            for i in range(2):
                await anyio.sleep(0.5)
                yield i

        async def send(*args, **kwargs):
            # Raises WouldBlock if called while someone else already holds the lock
            lock.acquire_nowait()
            await anyio.sleep(1.0)
            lock.release()

        async def receive():
            await anyio.lowlevel.checkpoint()
            return {"type": "message"}

        response = EventSourceResponse(event_publisher(), ping=1)

        # Act & Assert
        with pytest.raises(anyio.WouldBlock):
            with collapse_excgroups():
                await response({}, receive, send)

    def test_pingInterval_whenCreated_thenUsesDefaultValue(self):
        """A new response reports ``DEFAULT_PING_INTERVAL`` as its interval."""
        # Arrange & Act
        response = EventSourceResponse(0)

        # Assert
        assert response.ping_interval == response.DEFAULT_PING_INTERVAL

    def test_pingInterval_whenValidValueSet_thenUpdatesInterval(self):
        """Assigning a positive integer to ``ping_interval`` is stored as-is."""
        # Arrange
        response = EventSourceResponse(0)
        new_interval = 25

        # Act
        response.ping_interval = new_interval

        # Assert
        assert response.ping_interval == new_interval

    def test_pingInterval_whenStringProvided_thenRaisesTypeError(self):
        """Assigning a string to ``ping_interval`` raises ``TypeError``."""
        # Arrange
        response = EventSourceResponse(0)
        invalid_interval = "ten"

        # Act & Assert
        with pytest.raises(TypeError, match="ping interval must be int"):
            response.ping_interval = invalid_interval

    def test_pingInterval_whenNegativeValue_thenRaisesValueError(self):
        """Assigning a negative number to ``ping_interval`` raises ``ValueError``."""
        # Arrange
        response = EventSourceResponse(0)
        negative_interval = -42

        # Act & Assert
        with pytest.raises(ValueError, match="ping interval must be greater than 0"):
            response.ping_interval = negative_interval

    def test_compression_whenEnabled_thenRaisesNotImplemented(self):
        """``enable_compression()`` raises ``NotImplementedError``."""
        # Arrange
        response = EventSourceResponse(range(1, 5))

        # Act & Assert
        with pytest.raises(NotImplementedError):
            response.enable_compression()

    @pytest.mark.parametrize("separator", ["\n", "\r", "\r\n"])
    def test_customSeparator_whenProvided_thenUsesCorrectSeparator(self, separator):
        """``ServerSentEvent.encode()`` uses ``sep`` between and after fields."""
        # Arrange
        test_data = "test_data"
        test_event = "test_event"

        # Act
        response = ServerSentEvent(test_data, event=test_event, sep=separator)
        result = response.encode()

        # Assert
        # One separator after each field plus one more to end the event.
        expected = (
            f"event: {test_event}{separator}data: {test_data}{separator}{separator}"
        )
        assert result == expected.encode()

    @pytest.mark.anyio
    async def test_backgroundTask_whenProvided_thenExecutesAfterResponse(
        self, reset_appstatus_event
    ):
        """The ``background`` task has run once the response call returns."""
        # Arrange
        task_executed = False

        async def background_task():
            nonlocal task_executed
            task_executed = True

        async def mock_send(*args, **kwargs):
            pass

        async def mock_receive():
            await anyio.lowlevel.checkpoint()
            return {"type": "http.disconnect"}

        response = EventSourceResponse([], background=BackgroundTask(background_task))

        # Act
        await response({}, mock_receive, mock_send)

        # Assert
        assert task_executed, "Background task should be executed"