1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
|
from __future__ import annotations
from typing import Any, Awaitable, Callable, Optional, Tuple, Union
from .h2 import H2Protocol
from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
class ProtocolWrapper:
    """Route connection events to an HTTP/1.1 or HTTP/2 protocol handler.

    The initial handler is chosen at construction from ``alpn_protocol``:
    ``"h2"`` selects :class:`H2Protocol`; any other value (including ``None``)
    selects :class:`H11Protocol`.

    While handling an event, the current handler may raise
    :class:`H2ProtocolAssumedError` or :class:`H2CProtocolRequiredError`.
    Either one makes the wrapper replace its handler with a fresh
    :class:`H2Protocol`, initiate it, and feed it any bytes carried on the
    error's ``data`` attribute.
    """

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_group: TaskGroup,
        tls: Optional[dict[str, Any]],
        client: Optional[Tuple[str, int]],
        server: Optional[Tuple[str, int]],
        send: Callable[[Event], Awaitable[None]],
        alpn_protocol: Optional[str] = None,
        transport: Any = None,
    ) -> None:
        """Store the connection state and build the initial protocol handler.

        Args:
            app: Application wrapper forwarded to the protocol handler.
            config: Server configuration forwarded to the protocol handler.
            context: Worker context forwarded to the protocol handler.
            task_group: Task group forwarded to the protocol handler.
            tls: TLS information dict, or ``None``.
            client: Client address pair, or ``None`` (presumably
                ``(host, port)``; confirm against callers).
            server: Server address pair, or ``None`` (presumably
                ``(host, port)``; confirm against callers).
            send: Coroutine function the handler uses to emit events.
            alpn_protocol: ALPN-negotiated protocol name. ``"h2"`` selects
                HTTP/2, and anything else selects HTTP/1.1.
            transport: Object forwarded unchanged to every protocol handler
                this wrapper creates. Its type is not visible here
                (TODO: confirm).
        """
        self.app = app
        self.config = config
        self.context = context
        self.task_group = task_group
        self.tls = tls
        self.client = client
        self.server = server
        self.send = send
        # Must be set before any handler is built, because both handler
        # types receive it.
        self.transport = transport
        self.protocol: Union[H11Protocol, H2Protocol]
        if alpn_protocol == "h2":
            self.protocol = self._make_h2_protocol()
        else:
            self.protocol = H11Protocol(
                self.app,
                self.config,
                self.context,
                self.task_group,
                self.tls,
                self.client,
                self.server,
                self.send,
                self.transport,
            )

    def _make_h2_protocol(self) -> H2Protocol:
        """Build an :class:`H2Protocol` from this wrapper's connection state.

        This is the single construction point for HTTP/2 handlers. A handler
        created by an upgrade in :meth:`handle` therefore gets exactly the
        same arguments, including ``transport``, as one created from ALPN.
        """
        return H2Protocol(
            self.app,
            self.config,
            self.context,
            self.task_group,
            self.tls,
            self.client,
            self.server,
            self.send,
            self.transport,
        )

    async def initiate(self) -> None:
        """Initiate the current protocol handler."""
        return await self.protocol.initiate()

    async def handle(self, event: Event) -> None:
        """Pass *event* to the current handler, switching to HTTP/2 if asked.

        Behaviour by case:

        * ``H2ProtocolAssumedError``: an :class:`H2Protocol` replaces the
          current handler and is initiated with no arguments.
        * ``H2CProtocolRequiredError``: the new handler is initiated with the
          error's ``headers`` and ``settings``.
        * In both cases, any non-empty ``error.data`` is then replayed to the
          new handler as a :class:`RawData` event.
        """
        try:
            return await self.protocol.handle(event)
        except H2ProtocolAssumedError as error:
            self.protocol = self._make_h2_protocol()
            await self.protocol.initiate()
            # Bytes already consumed by the previous handler belong to the
            # new protocol, so replay them.
            if error.data != b"":
                return await self.protocol.handle(RawData(data=error.data))
        except H2CProtocolRequiredError as error:
            self.protocol = self._make_h2_protocol()
            await self.protocol.initiate(error.headers, error.settings)
            if error.data != b"":
                return await self.protocol.handle(RawData(data=error.data))
|