File: _signals.py

package info (click to toggle)
dask.distributed 2022.12.1%2Bds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (41 lines) | stat: -rw-r--r-- 1,297 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from __future__ import annotations

import asyncio
import logging
import signal
from typing import Any

# Module-level logger; wait_for_signals() uses it to report the caught signal.
logger = logging.getLogger(__name__)


async def wait_for_signals() -> None:
    """Wait for SIGINT or SIGTERM by temporarily setting global signal handlers.

    Handlers for both signals are installed with ``signal.signal``, the
    coroutine then suspends until one of them fires, logs which signal was
    received and returns. The previously installed handlers are put back when
    the coroutine finishes for any reason (signal received, cancellation or
    error).

    The handler of the signal that fired is restored from inside the signal
    handler itself, so that a second delivery of the same signal is processed
    by the previous handler instead of being swallowed here.

    Raises
    ------
    ValueError
        Propagated from ``signal.signal`` if handlers cannot be installed
        (for example when not running in the main thread of the main
        interpreter). Handlers installed before the failure are restored.
    """
    signals = (signal.SIGINT, signal.SIGTERM)
    loop = asyncio.get_running_loop()
    event = asyncio.Event()

    old_handlers: dict[int, Any] = {}
    caught_signal: int | None = None

    def restore_handler(signum: int) -> None:
        previous = old_handlers[signum]
        if previous is None:
            # ``signal.signal`` reports ``None`` when the previous handler was
            # not installed from Python. ``None`` cannot be passed back to
            # ``signal.signal`` (it raises TypeError), so fall back to the
            # default action rather than leaving our own handler in place.
            previous = signal.SIG_DFL
        signal.signal(signum, previous)

    def handle_signal(signum, frame):
        # *** Do not log or print anything in here
        # https://stackoverflow.com/questions/45680378/how-to-explain-the-reentrant-runtimeerror-caused-by-printing-in-signal-handlers
        nonlocal caught_signal
        caught_signal = signum
        # Restore old signal handler to allow for quicker exit
        # if the user sends the signal again.
        restore_handler(signum)
        # Wake the waiting coroutine through the loop's thread-safe entry
        # point instead of touching the event directly from the handler.
        loop.call_soon_threadsafe(event.set)

    try:
        # Install inside the ``try`` so that a failure part-way through still
        # restores whatever was already replaced.
        for sig in signals:
            old_handlers[sig] = signal.signal(sig, handle_signal)

        await event.wait()
        # The event is only ever set by ``handle_signal``, which records the
        # signal number first.
        assert caught_signal is not None
        logger.info(
            "Received signal %s (%d)", signal.Signals(caught_signal).name, caught_signal
        )
    finally:
        # Only restore the handlers that were actually replaced.
        for sig in old_handlers:
            restore_handler(sig)