1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
from __future__ import annotations
import ssl
from math import inf
from typing import Any, Generator, Optional
import trio
from .task_group import TaskGroup
from .worker_context import WorkerContext
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import AppWrapper
from ..utils import parse_socket_addr
# Upper bound, in bytes, passed to ``stream.receive_some`` on every read (64 KiB).
MAX_RECV = 1 << 16
class TCPServer:
    """Drive a single accepted stream through a ``ProtocolWrapper``.

    An instance is awaitable: awaiting it executes :meth:`run`, which
    optionally completes a TLS handshake, builds the protocol wrapper,
    feeds received bytes into it and finally closes the stream.
    """

    def __init__(
        self, app: AppWrapper, config: Config, context: WorkerContext, stream: trio.abc.Stream
    ) -> None:
        """Remember the collaborators and create the send and idle locks."""
        self.stream = stream
        self.app = app
        self.config = config
        self.context = context
        # Declared here, assigned in run() once the connection details are known.
        self.protocol: ProtocolWrapper
        # Serialises writes to the stream (see protocol_send).
        self.send_lock = trio.Lock()
        # Serialises starting and stopping of the idle task.
        self.idle_lock = trio.Lock()

        # Cancel scope of the currently tracked idle task, or None if there is none.
        self._idle_handle: Optional[trio.CancelScope] = None

    def __await__(self) -> Generator[Any, None, None]:
        """Allow ``await server`` by delegating to the :meth:`run` coroutine."""
        coroutine = self.run()
        return coroutine.__await__()

    async def run(self) -> None:
        """Handle the connection from handshake to close.

        The stream is first treated as a TLS stream. If any of the TLS-only
        attributes is missing (``AttributeError``), it is handled as a plain
        stream speaking ``http/1.1`` with no TLS information. A handshake that
        breaks or exceeds ``config.ssl_handshake_timeout`` makes this method
        return immediately, before the ``finally`` clause that closes the
        stream is entered.
        """
        try:
            try:
                with trio.fail_after(self.config.ssl_handshake_timeout):
                    await self.stream.do_handshake()
            except (trio.BrokenResourceError, trio.TooSlowError):
                return  # Handshake failed
            negotiated_protocol = self.stream.selected_alpn_protocol()
            sock = self.stream.transport_stream.socket
            tls_info = {"alpn_protocol": negotiated_protocol}
            peer_certificate = self.stream.getpeercert(binary_form=False)
            if peer_certificate:
                # Only the first attribute of every subject component is rendered.
                subject_parts = [
                    f"{component[0][0]}={component[0][1]}"
                    for component in peer_certificate["subject"]
                ]
                tls_info["client_cert_name"] = ", ".join(subject_parts)
        except AttributeError:  # Not SSL
            negotiated_protocol = "http/1.1"
            sock = self.stream.socket
            tls_info = None

        try:
            client_address = parse_socket_addr(sock.family, sock.getpeername())
            server_address = parse_socket_addr(sock.family, sock.getsockname())

            async with TaskGroup() as group:
                # Kept on the instance so _start_idle can spawn into it.
                self._task_group = group
                self.protocol = ProtocolWrapper(
                    self.app,
                    self.config,
                    self.context,
                    group,
                    tls_info,
                    client_address,
                    server_address,
                    self.protocol_send,
                    negotiated_protocol,
                    self.stream,
                )
                await self.protocol.initiate()
                await self._start_idle()
                await self._read_data()
        except OSError:
            # Socket-level failures end the connection silently.
            pass
        finally:
            await self._close()

    async def protocol_send(self, event: Event) -> None:
        """React to an event emitted by the protocol.

        ``RawData`` is written to the stream, ``Closed`` closes the stream and
        ``Updated`` starts or stops the idle task depending on ``event.idle``.
        Any other event type is ignored.
        """
        if isinstance(event, RawData):
            async with self.send_lock:
                try:
                    # Shielded so that an in-progress write is not interrupted
                    # by cancellation of an enclosing scope.
                    with trio.CancelScope(shield=True):
                        await self.stream.send_all(event.data)
                except (trio.BrokenResourceError, trio.ClosedResourceError):
                    await self.protocol.handle(Closed())
            return

        if isinstance(event, Closed):
            await self._close()
            await self.protocol.handle(Closed())
            return

        if isinstance(event, Updated):
            if event.idle:
                await self._start_idle()
            else:
                await self._stop_idle()

    async def _read_data(self) -> None:
        """Pump received bytes into the protocol until the stream ends.

        The loop stops when a read fails (closed, broken or slower than
        ``config.read_timeout``) or after an empty chunk has been forwarded.
        The protocol is then told the connection is ``Closed``.
        """
        read_failures = (
            trio.ClosedResourceError,
            trio.BrokenResourceError,
            trio.TooSlowError,
        )
        while True:
            try:
                # A falsy read_timeout means "wait forever".
                with trio.fail_after(self.config.read_timeout or inf):
                    chunk = await self.stream.receive_some(MAX_RECV)
            except read_failures:
                break

            await self.protocol.handle(RawData(chunk))
            if chunk == b"":
                break

        await self.protocol.handle(Closed())

    async def _close(self) -> None:
        """Send EOF on a best-effort basis, then close the stream."""
        tolerated = (
            trio.BrokenResourceError,
            AttributeError,
            trio.BusyResourceError,
            trio.ClosedResourceError,
        )
        try:
            await self.stream.send_eof()
        except tolerated:
            # Either the peer has already gone away or the stream cannot
            # send EOF (for example an SSL stream); nothing more to do here.
            pass
        await self.stream.aclose()

    async def _initiate_server_close(self) -> None:
        """Tell the protocol the connection is closed, then close the stream."""
        await self.protocol.handle(Closed())
        await self.stream.aclose()

    async def _start_idle(self) -> None:
        """Start the idle task unless one is already being tracked."""
        async with self.idle_lock:
            if self._idle_handle is not None:
                return
            nursery = self._task_group._nursery
            # _run_idle reports its cancel scope through task_status.started().
            self._idle_handle = await nursery.start(self._run_idle)

    async def _stop_idle(self) -> None:
        """Cancel the tracked idle task, if any, and forget its handle."""
        async with self.idle_lock:
            handle, self._idle_handle = self._idle_handle, None
            if handle is not None:
                handle.cancel()

    async def _run_idle(
        self,
        task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
    ) -> None:
        """Close the connection once it has idled long enough.

        The cancel scope wrapping the whole task is handed back through
        ``task_status.started`` so the caller can cancel it. The task waits
        for ``context.terminated`` for at most ``config.keep_alive_timeout``,
        then initiates the server-side close.
        """
        idle_scope = trio.CancelScope()
        task_status.started(idle_scope)
        with idle_scope:
            with trio.move_on_after(self.config.keep_alive_timeout):
                await self.context.terminated.wait()

            # From here on, shield the close from cancellation of enclosing scopes.
            idle_scope.shield = True
            await self._initiate_server_close()
|