File: sse.py

import asyncio
import logging
import signal
import threading
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import (
    Any,
    AsyncIterable,
    Awaitable,
    Callable,
    Coroutine,
    Iterator,
    Mapping,
    Optional,
    Set,
    Union,
)

import anyio
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import MutableHeaders
from starlette.responses import Response
from starlette.types import Receive, Scope, Send, Message

from sse_starlette.event import ServerSentEvent, ensure_bytes


logger = logging.getLogger(__name__)


@dataclass
class _ShutdownState:
    """Per-thread state for shutdown coordination.

    Issue #152 fix: Uses threading.local() instead of ContextVar to ensure
    one watcher per thread rather than one per async context.
    """

    events: Set[anyio.Event] = field(default_factory=set)
    watcher_started: bool = False


# Each thread gets its own shutdown state (one event loop per thread typically)
_thread_state = threading.local()


def _get_shutdown_state() -> _ShutdownState:
    """Get or create shutdown state for the current thread."""
    state = getattr(_thread_state, "shutdown_state", None)
    if state is None:
        state = _ShutdownState()
        _thread_state.shutdown_state = state
    return state


def _get_uvicorn_server():
    """
    Try to get uvicorn Server instance via signal handler introspection.

    When uvicorn registers signal handlers, they're bound methods on the Server instance.
    We can retrieve the Server from the handler's __self__ attribute.

    Returns None if:
    - Not running under uvicorn
    - Signal handler isn't a bound method
    - Any introspection fails
    """
    try:
        handler = signal.getsignal(signal.SIGTERM)
        if hasattr(handler, "__self__"):
            server = handler.__self__
            if hasattr(server, "should_exit"):
                return server
    except Exception:
        pass
    return None


async def _shutdown_watcher() -> None:
    """
    Poll for shutdown and broadcast to all events in this context.

    One watcher runs per thread (event loop). Checks two shutdown sources:
    1. AppStatus.should_exit - set when our monkey-patch works
    2. uvicorn Server.should_exit - via signal handler introspection (Issue #132 fix)

    When either becomes True, signals all registered events.
    """
    state = _get_shutdown_state()
    uvicorn_server = _get_uvicorn_server()

    try:
        while True:
            # Check our flag (monkey-patch worked or manually set)
            if AppStatus.should_exit:
                break
            # Check uvicorn's flag directly (monkey-patch failed - Issue #132)
            if (
                AppStatus.enable_automatic_graceful_drain
                and uvicorn_server is not None
                and uvicorn_server.should_exit
            ):
                AppStatus.should_exit = True  # Sync state for consistency
                break
            await anyio.sleep(0.5)

        # Shutdown detected - broadcast to all waiting events
        for event in list(state.events):
            event.set()
    finally:
        # Allow watcher to be restarted if loop is reused
        state.watcher_started = False


def _ensure_watcher_started_on_this_loop() -> None:
    """Ensure the shutdown watcher is running for this thread (event loop)."""
    state = _get_shutdown_state()
    if not state.watcher_started:
        state.watcher_started = True
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(_shutdown_watcher())
        except RuntimeError:
            # No running loop - shouldn't happen in normal use
            state.watcher_started = False


class SendTimeoutError(TimeoutError):
    pass


class AppStatus:
    """Helper to capture a shutdown signal from Uvicorn so we can gracefully terminate SSE streams."""

    should_exit = False
    enable_automatic_graceful_drain = True
    original_handler: Optional[Callable] = None

    @staticmethod
    def disable_automatic_graceful_drain():
        """
        Prevent automatic SSE stream termination on server shutdown.

        WARNING: When disabled, you MUST set AppStatus.should_exit = True
        at some point during shutdown, or streams will never close and the
        server will hang indefinitely (or until uvicorn's graceful shutdown
        timeout expires).
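
        Minimal sketch of manual drain control (the shutdown hook name below is
        illustrative, not part of this module):

            AppStatus.disable_automatic_graceful_drain()

            async def on_app_shutdown():
                ...  # finish in-flight work first
                AppStatus.should_exit = True  # release all open SSE streams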
        """
        AppStatus.enable_automatic_graceful_drain = False

    @staticmethod
    def enable_automatic_graceful_drain_mode():
        """
        Re-enable automatic SSE stream termination on server shutdown.

        This restores the default behavior where SIGTERM triggers immediate
        stream draining. Call this to undo a previous call to
        disable_automatic_graceful_drain().
        """
        AppStatus.enable_automatic_graceful_drain = True

    @staticmethod
    def handle_exit(*args, **kwargs):
        if AppStatus.enable_automatic_graceful_drain:
            AppStatus.should_exit = True
        if AppStatus.original_handler is not None:
            AppStatus.original_handler(*args, **kwargs)


try:
    from uvicorn.main import Server

    AppStatus.original_handler = Server.handle_exit
    Server.handle_exit = AppStatus.handle_exit  # type: ignore
except ImportError:
    logger.debug(
        "Uvicorn not installed. Graceful shutdown on server termination disabled."
    )

Content = Union[str, bytes, dict, ServerSentEvent, Any]
SyncContentStream = Iterator[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = Union[AsyncContentStream, SyncContentStream]
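
# Illustrative shapes a content stream may yield (conversion is delegated to
# sse_starlette.event.ensure_bytes; the dict form assumes its keys map onto
# ServerSentEvent fields such as "data", "event" and "id"):
#
#     yield "raw payload"                            # plain str/bytes data
#     yield {"event": "update", "data": "payload"}   # field mapping
#     yield ServerSentEvent(data="payload", id="1")  # explicit event object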


class EventSourceResponse(Response):
    """
    Streaming response that sends data conforming to the SSE (Server-Sent Events) specification.
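
    Minimal usage sketch (illustrative; the generator and endpoint names below
    are placeholders, not part of this module):

        async def numbers():
            for i in range(3):
                yield {"data": str(i)}

        async def endpoint(request):
            return EventSourceResponse(numbers())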
    """

    DEFAULT_PING_INTERVAL = 15
    DEFAULT_SEPARATOR = "\r\n"

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Optional[Mapping[str, str]] = None,
        media_type: str = "text/event-stream",
        background: Optional[BackgroundTask] = None,
        ping: Optional[int] = None,
        sep: Optional[str] = None,
        ping_message_factory: Optional[Callable[[], ServerSentEvent]] = None,
        data_sender_callable: Optional[
            Callable[[], Coroutine[None, None, None]]
        ] = None,
        send_timeout: Optional[float] = None,
        client_close_handler_callable: Optional[
            Callable[[Message], Awaitable[None]]
        ] = None,
    ) -> None:
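        """
        Args:
            content: sync or async iterable yielding the items to stream.
            status_code: HTTP status code of the response.
            headers: additional response headers.
            media_type: response content type, defaults to "text/event-stream".
            background: Starlette background task run after the response finishes.
            ping: ping interval in seconds, defaults to DEFAULT_PING_INTERVAL.
            sep: line separator used on the wire (CRLF, CR or LF; defaults to CRLF).
            ping_message_factory: callable returning the ServerSentEvent used as a ping.
            data_sender_callable: coroutine function started alongside the streaming tasks.
            send_timeout: seconds to wait for a single send before aborting with SendTimeoutError.
            client_close_handler_callable: awaited with the ASGI message on client disconnect.
        """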
        # Validate separator
        if sep not in (None, "\r\n", "\r", "\n"):
            raise ValueError(f"sep must be one of: \\r\\n, \\r, \\n, got: {sep}")
        self.sep = sep or self.DEFAULT_SEPARATOR

        # If content is sync, wrap it for async iteration
        if isinstance(content, AsyncIterable):
            self.body_iterator = content
        else:
            self.body_iterator = iterate_in_threadpool(content)

        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.data_sender_callable = data_sender_callable
        self.send_timeout = send_timeout

        # Build SSE-specific headers.
        _headers = MutableHeaders()
        if headers is not None:  # pragma: no cover
            _headers.update(headers)

        # "The no-store response directive indicates that any caches of any kind (private or shared)
        # should not store this response."
        # -- https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
        # Allow the Cache-Control header to be overridden by the user to support fan-out proxies:
        # https://www.fastly.com/blog/server-sent-events-fastly

        _headers.setdefault("Cache-Control", "no-store")
        # Keep the connection alive and disable response buffering in proxies (e.g. nginx).
        _headers["Connection"] = "keep-alive"
        _headers["X-Accel-Buffering"] = "no"
        self.init_headers(_headers)

        self.ping_interval = self.DEFAULT_PING_INTERVAL if ping is None else ping
        self.ping_message_factory = ping_message_factory

        self.client_close_handler_callable = client_close_handler_callable

        self.active = True
        # https://github.com/sysid/sse-starlette/pull/55#issuecomment-1732374113
        self._send_lock = anyio.Lock()

    @property
    def ping_interval(self) -> Union[int, float]:
        return self._ping_interval

    @ping_interval.setter
    def ping_interval(self, value: Union[int, float]) -> None:
        if not isinstance(value, (int, float)):
            raise TypeError("ping interval must be an int or float")
        if value < 0:
            raise ValueError("ping interval must not be negative")
        self._ping_interval = value

    def enable_compression(self, force: bool = False) -> None:
        raise NotImplementedError("Compression is not supported for SSE streams.")

    async def _stream_response(self, send: Send) -> None:
        """Send out SSE data to the client as it becomes available in the iterator."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )

        async for data in self.body_iterator:
            chunk = ensure_bytes(data, self.sep)
            logger.debug("chunk: %s", chunk)
            with anyio.move_on_after(self.send_timeout) as cancel_scope:
                await send(
                    {"type": "http.response.body", "body": chunk, "more_body": True}
                )

            if cancel_scope and cancel_scope.cancel_called:
                if hasattr(self.body_iterator, "aclose"):
                    await self.body_iterator.aclose()
                raise SendTimeoutError()

        async with self._send_lock:
            self.active = False
            await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def _listen_for_disconnect(self, receive: Receive) -> None:
        """Watch for a disconnect message from the client."""
        while self.active:
            message = await receive()
            if message["type"] == "http.disconnect":
                self.active = False
                logger.debug("Got event: http.disconnect. Stop streaming.")
                if self.client_close_handler_callable:
                    await self.client_close_handler_callable(message)
                break

    @staticmethod
    async def _listen_for_exit_signal() -> None:
        """Wait for shutdown signal via the shared watcher."""
        if AppStatus.should_exit:
            return

        _ensure_watcher_started_on_this_loop()

        state = _get_shutdown_state()
        event = anyio.Event()
        state.events.add(event)

        try:
            # Double-check after registration
            if AppStatus.should_exit:
                return
            await event.wait()
        finally:
            state.events.discard(event)

    async def _ping(self, send: Send) -> None:
        """Periodically send ping messages to keep the connection alive on proxies.
        - frequency: roughly every 15 seconds by default.
        - Alternatively, a comment line (one starting with a ':' character) can be sent periodically.
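
        Custom ping sketch (illustrative; ``generator()`` stands in for the
        caller's content stream):

            EventSourceResponse(
                generator(),
                ping=10,
                ping_message_factory=lambda: ServerSentEvent(comment="keepalive"),
            )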
        """
        while self.active:
            await anyio.sleep(self._ping_interval)
            sse_ping = (
                self.ping_message_factory()
                if self.ping_message_factory
                else ServerSentEvent(
                    comment=f"ping - {datetime.now(timezone.utc)}", sep=self.sep
                )
            )
            ping_bytes = ensure_bytes(sse_ping, self.sep)
            logger.debug("ping: %s", ping_bytes)

            async with self._send_lock:
                if self.active:
                    await send(
                        {
                            "type": "http.response.body",
                            "body": ping_bytes,
                            "more_body": True,
                        }
                    )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Entrypoint for Starlette's ASGI contract. We spin up tasks:
        - _stream_response to push events
        - _ping to keep the connection alive
        - _listen_for_exit_signal to respond to server shutdown
        - _listen_for_disconnect to respond to client disconnect
        """
        async with anyio.create_task_group() as task_group:
            # https://trio.readthedocs.io/en/latest/reference-core.html#custom-supervisors
            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]):
                await coro()
                task_group.cancel_scope.cancel()

            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(
                cancel_on_finish, lambda: self._listen_for_disconnect(receive)
            )

        if self.background is not None:
            await self.background()