File: udp_server.py

package info (click to toggle)
python-urllib3 2.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,340 kB
  • sloc: python: 26,167; makefile: 122; javascript: 92; sh: 11
file content (60 lines) | stat: -rw-r--r-- 2,179 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import annotations

import asyncio
from typing import Optional, Tuple, TYPE_CHECKING

from .task_group import TaskGroup
from .worker_context import WorkerContext
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper
from ..utils import parse_socket_addr

if TYPE_CHECKING:
    # h3/Quic is an optional part of Hypercorn
    from ..protocol.quic import QuicProtocol  # noqa: F401


class UDPServer(asyncio.DatagramProtocol):
    """Datagram endpoint that feeds received packets to a ``QuicProtocol``.

    ``datagram_received`` buffers incoming datagrams in a bounded queue and
    ``run`` consumes them one at a time, handing each to the protocol.
    ``RawData`` events the protocol passes to ``protocol_send`` are written
    back out through the transport.
    """

    def __init__(
        self,
        app: AppWrapper,
        loop: asyncio.AbstractEventLoop,
        config: Config,
        context: WorkerContext,
    ) -> None:
        """Store the collaborators and create the empty receive queue.

        ``transport`` stays ``None`` until ``connection_made`` is called and
        ``protocol`` is only assigned once ``run`` has started.
        """
        self.loop = loop
        self.app = app
        self.config = config
        self.context = context
        # Declared here for type checkers only; the instance is built in run().
        self.protocol: "QuicProtocol"
        # At most 10 datagrams are buffered; further ones are dropped.
        self.protocol_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
        self.transport: Optional[asyncio.DatagramTransport] = None

    def connection_made(self, transport: asyncio.DatagramTransport) -> None:  # type: ignore
        """Record the transport so ``run`` and ``protocol_send`` can use it."""
        self.transport = transport

    def datagram_received(self, data: bytes, address: Tuple[bytes, str]) -> None:  # type: ignore
        """Queue a received datagram for ``run``, dropping it if the queue is full."""
        packet = RawData(data=data, address=address)  # type: ignore
        try:
            self.protocol_queue.put_nowait(packet)
        except asyncio.QueueFull:
            # Just throw the data away, is UDP
            pass

    async def run(self) -> None:
        """Build the QUIC protocol and feed it queued datagrams until shutdown.

        The loop ends once ``context.terminated`` is set and the protocol
        reports itself idle.

        NOTE(review): the condition is only re-evaluated after a datagram has
        been handled, because ``protocol_queue.get()`` waits while the queue
        is empty.
        """
        # h3/Quic is an optional part of Hypercorn
        from ..protocol.quic import QuicProtocol  # noqa: F811

        sock = self.transport.get_extra_info("socket")
        local_addr = parse_socket_addr(sock.family, sock.getsockname())
        async with TaskGroup(self.loop) as task_group:
            self.protocol = QuicProtocol(
                self.app,
                self.config,
                self.context,
                task_group,
                local_addr,
                self.protocol_send,
            )

            # Keep serving unless termination was requested AND the protocol is idle.
            while not (self.context.terminated.is_set() and self.protocol.idle):
                datagram = await self.protocol_queue.get()
                await self.protocol.handle(datagram)

    async def protocol_send(self, event: Event) -> None:
        """Send a ``RawData`` event to its address; ignore any other event."""
        if not isinstance(event, RawData):
            return
        self.transport.sendto(event.data, event.address)