1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
from __future__ import annotations
from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import (
ConnectionIdIssued,
ConnectionIdRetired,
ConnectionTerminated,
ProtocolNegotiated,
)
from aioquic.quic.packet import (
encode_quic_version_negotiation,
PACKET_TYPE_INITIAL,
pull_quic_header,
)
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
class QuicProtocol:
    """Route UDP datagrams to QUIC connections and their HTTP/3 protocols.

    Incoming ``RawData`` events are parsed as QUIC packets and dispatched to a
    ``QuicConnection`` looked up by destination connection ID (a new connection
    is created for a sufficiently large Initial packet). QUIC events produced by
    each connection are forwarded to an ``H3Protocol`` once ALPN negotiation
    has completed, and outgoing datagrams are emitted through ``send``.
    """

    # RFC 9000 section 14.1: a server must discard an Initial packet carried in
    # a UDP datagram smaller than 1200 bytes, so smaller datagrams never open
    # a new connection.
    _MIN_INITIAL_DATAGRAM_SIZE = 1200

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_group: TaskGroup,
        server: Optional[Tuple[str, int]],
        send: Callable[[Event], Awaitable[None]],
    ) -> None:
        self.app = app
        self.config = config
        self.context = context
        # Every connection ID we route on (the client's original destination
        # CID plus each host CID issued) mapped to its connection; several keys
        # may point at the same connection.
        self.connections: Dict[bytes, QuicConnection] = {}
        # The HTTP/3 protocol for each connection that completed ALPN.
        self.http_connections: Dict[QuicConnection, H3Protocol] = {}
        self.send = send
        self.server = server
        self.task_group = task_group

        self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
        self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)

    @property
    def idle(self) -> bool:
        """Return True when no connection IDs or HTTP/3 protocols are tracked."""
        return not self.connections and not self.http_connections

    async def handle(self, event: Event) -> None:
        """Process one socket-level event.

        For ``RawData``: datagrams whose QUIC header cannot be parsed are
        dropped; packets for an unsupported version get a version-negotiation
        reply; otherwise the datagram is fed to the matching connection (a new
        one is created for a large-enough Initial packet unless the worker is
        terminating). ``Closed`` requires no action here.
        """
        if isinstance(event, RawData):
            try:
                # NOTE(review): assumes 8-byte host connection IDs, matching the
                # QuicConfiguration defaults — confirm if that configuration changes.
                header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
            except ValueError:
                # Not a parseable QUIC packet; drop it.
                return

            if (
                header.version is not None
                and header.version not in self.quic_config.supported_versions
            ):
                data = encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self.quic_config.supported_versions,
                )
                await self.send(RawData(data=data, address=event.address))
                return

            connection = self.connections.get(header.destination_cid)
            if (
                connection is None
                and len(event.data) >= self._MIN_INITIAL_DATAGRAM_SIZE
                and header.packet_type == PACKET_TYPE_INITIAL
                and not self.context.terminated.is_set()
            ):
                connection = QuicConnection(
                    configuration=self.quic_config,
                    original_destination_connection_id=header.destination_cid,
                )
                # Route on both the client-chosen CID and our own host CID.
                self.connections[header.destination_cid] = connection
                self.connections[connection.host_cid] = connection

            if connection is not None:
                connection.receive_datagram(event.data, event.address, now=self.context.time())
                await self._handle_events(connection, event.address)
        elif isinstance(event, Closed):
            pass

    async def send_all(self, connection: QuicConnection) -> None:
        """Send every datagram the connection currently has queued."""
        for data, address in connection.datagrams_to_send(now=self.context.time()):
            await self.send(RawData(data=data, address=address))

    def _remove_connection(self, connection: QuicConnection) -> None:
        """Forget every routing entry and the HTTP/3 protocol for *connection*.

        Safe to call more than once for the same connection.
        """
        stale_cids = [cid for cid, conn in self.connections.items() if conn is connection]
        for cid in stale_cids:
            del self.connections[cid]
        self.http_connections.pop(connection, None)

    async def _handle_events(
        self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
    ) -> None:
        """Drain the connection's event queue, then flush datagrams and rearm its timer.

        Args:
            connection: The QUIC connection whose pending events are processed.
            client: Peer address passed to a newly created ``H3Protocol``.
        """
        event = connection.next_event()
        while event is not None:
            if isinstance(event, ProtocolNegotiated):
                self.http_connections[connection] = H3Protocol(
                    self.app,
                    self.config,
                    self.context,
                    self.task_group,
                    client,
                    self.server,
                    connection,
                    partial(self.send_all, connection),
                )
            elif isinstance(event, ConnectionIdIssued):
                self.connections[event.connection_id] = connection
            elif isinstance(event, ConnectionIdRetired):
                # pop() rather than del: the CID may already have been removed
                # (e.g. by termination cleanup).
                self.connections.pop(event.connection_id, None)

            http_connection = self.http_connections.get(connection)
            if http_connection is not None:
                await http_connection.handle(event)

            if isinstance(event, ConnectionTerminated):
                # Drop all references so the connection can be collected and
                # ``idle`` can become true again; done after forwarding the
                # event so the HTTP/3 layer still observes it.
                self._remove_connection(connection)

            event = connection.next_event()

        # Still flush after termination so any final close frames go out.
        await self.send_all(connection)

        timer = connection.get_timer()
        if timer is not None:
            # NOTE(review): a new timer task is spawned on every call, so several
            # may be pending for one connection; each re-checks state on wakeup.
            self.task_group.spawn(self._handle_timer, timer, connection)

    async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
        """Sleep until *timer*, then let the connection process its timeout."""
        wait = max(0, timer - self.context.time())
        await self.context.sleep(wait)
        # NOTE(review): relies on aioquic's private ``_close_at``; ``None`` is
        # treated here as "connection already closed" — confirm against the
        # aioquic version in use.
        if connection._close_at is not None:
            connection.handle_timer(now=self.context.time())
            await self._handle_events(connection, None)
|