1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
|
from __future__ import annotations
import asyncio
from ssl import SSLError
from typing import Any, Generator, Optional
from .task_group import TaskGroup
from .worker_context import WorkerContext
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import AppWrapper
from ..utils import parse_socket_addr
# Largest number of bytes requested by a single StreamReader.read() call in
# TCPServer._read_data (64 KiB).
MAX_RECV = 1 << 16
class TCPServer:
    """Serve one accepted TCP (optionally TLS) connection over asyncio streams.

    Bytes read from ``reader`` are fed to a :class:`ProtocolWrapper` as
    :class:`RawData` events. Events the protocol emits through
    :meth:`protocol_send` are written to ``writer`` or turned into close and
    idle-timer actions. Awaiting an instance runs :meth:`run`.
    """

    def __init__(
        self,
        app: AppWrapper,
        loop: asyncio.AbstractEventLoop,
        config: Config,
        context: WorkerContext,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        self.app = app
        self.config = config
        self.context = context
        self.loop = loop
        # Assigned in run(); only declared here for type checkers.
        self.protocol: ProtocolWrapper
        self.reader = reader
        self.writer = writer
        # Held around each write()+drain() pair in protocol_send().
        self.send_lock = asyncio.Lock()
        # Guards creation/cancellation of the idle-timeout task.
        self.idle_lock = asyncio.Lock()

        # Task running _run_idle(), or None when no idle timer is armed.
        self._idle_handle: Optional[asyncio.Task] = None

    def __await__(self) -> Generator[Any, None, None]:
        """Allow ``await server`` as shorthand for ``await server.run()``."""
        return self.run().__await__()

    async def run(self) -> None:
        """Drive the connection until it is closed.

        Resolves the peer and local addresses and detects TLS and the
        negotiated ALPN protocol. It then builds the protocol wrapper, arms
        the idle timer and reads from the socket until EOF or a read failure.
        ``OSError`` raised along the way is suppressed, and the connection is
        always closed via :meth:`_close` on exit.
        """
        socket = self.writer.get_extra_info("socket")
        try:
            client = parse_socket_addr(socket.family, socket.getpeername())
            server = parse_socket_addr(socket.family, socket.getsockname())
            ssl_object = self.writer.get_extra_info("ssl_object")
            if ssl_object is not None:
                tls = {}
                # May be None if no protocol was negotiated via ALPN.
                alpn_protocol = ssl_object.selected_alpn_protocol()
            else:
                tls = None
                # ALPN is a TLS extension; plaintext connections get HTTP/1.1.
                alpn_protocol = "http/1.1"

            async with TaskGroup(self.loop) as task_group:
                self.protocol = ProtocolWrapper(
                    self.app,
                    self.config,
                    self.context,
                    task_group,
                    tls,
                    client,
                    server,
                    self.protocol_send,
                    alpn_protocol,
                    (self.reader, self.writer),
                )
                await self.protocol.initiate()
                await self._start_idle()
                await self._read_data()
        except OSError:
            pass
        finally:
            await self._close()

    async def protocol_send(self, event: Event) -> None:
        """Act on an event emitted by the protocol.

        * ``RawData``: write the bytes and wait for the write buffer to
          drain. If that fails with ``ConnectionError`` or ``RuntimeError``,
          report ``Closed`` back to the protocol.
        * ``Closed``: close the connection.
        * ``Updated``: arm the idle timer if ``event.idle`` is true,
          otherwise disarm it.

        Other event types are ignored.
        """
        if isinstance(event, RawData):
            async with self.send_lock:
                try:
                    self.writer.write(event.data)
                    await self.writer.drain()
                except (ConnectionError, RuntimeError):
                    # NOTE(review): this runs while send_lock is held and
                    # asyncio.Lock is not re-entrant. If handling Closed makes
                    # the protocol send RawData from this same task, it would
                    # block here. Confirm the protocol never does that.
                    await self.protocol.handle(Closed())
        elif isinstance(event, Closed):
            await self._close()
        elif isinstance(event, Updated):
            if event.idle:
                await self._start_idle()
            else:
                await self._stop_idle()

    async def _read_data(self) -> None:
        """Pump socket data into the protocol until EOF, an error or a timeout.

        Each read requests at most ``MAX_RECV`` bytes and is bounded by
        ``config.read_timeout``. When the loop ends, the protocol is sent a
        ``Closed`` event.
        """
        while not self.reader.at_eof():
            try:
                data = await asyncio.wait_for(self.reader.read(MAX_RECV), self.config.read_timeout)
            except (
                ConnectionError,
                OSError,
                asyncio.TimeoutError,
                TimeoutError,
                SSLError,
            ):
                # Any read failure or read timeout ends the connection.
                break
            else:
                await self.protocol.handle(RawData(data))

        await self.protocol.handle(Closed())

    async def _close(self) -> None:
        """Shut the connection down and stop the idle timer.

        This is best-effort: errors showing the transport is already gone are
        ignored. The idle timer is stopped in a ``finally`` block, so any
        other error from ``close()``/``wait_closed()`` cannot leave the idle
        task running; such an error is still propagated to the caller.
        """
        try:
            self.writer.write_eof()
        except (NotImplementedError, OSError, RuntimeError):
            pass  # Likely SSL connection

        try:
            self.writer.close()
            await self.writer.wait_closed()
        except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError, RuntimeError):
            pass  # Already closed
        finally:
            # Previously skipped when an unlisted exception escaped above,
            # leaking the idle task.
            await self._stop_idle()

    async def _initiate_server_close(self) -> None:
        """Close from the server side: notify the protocol, then close the writer."""
        await self.protocol.handle(Closed())
        self.writer.close()

    async def _start_idle(self) -> None:
        """Arm the idle timer unless it is already running."""
        async with self.idle_lock:
            if self._idle_handle is None:
                self._idle_handle = self.loop.create_task(self._run_idle())

    async def _stop_idle(self) -> None:
        """Cancel the idle-timer task, if any, and wait for it to finish."""
        async with self.idle_lock:
            if self._idle_handle is not None:
                self._idle_handle.cancel()
                try:
                    await self._idle_handle
                except asyncio.CancelledError:
                    pass
            self._idle_handle = None

    async def _run_idle(self) -> None:
        """Close the connection once the idle timer fires.

        Fires after ``config.keep_alive_timeout`` seconds, or earlier if
        ``context.terminated`` is signalled.
        """
        try:
            await asyncio.wait_for(self.context.terminated.wait(), self.config.keep_alive_timeout)
        except asyncio.TimeoutError:
            pass
        # shield() lets the close run to completion even if this task is
        # cancelled meanwhile (e.g. by _stop_idle(), which cancels _idle_handle).
        await asyncio.shield(self._initiate_server_close())
|