1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
|
from __future__ import annotations
from enum import auto, Enum
from time import time
from typing import Any, Awaitable, Callable, Optional, Tuple
from urllib.parse import unquote
from .events import Body, EndBody, Event, InformationalResponse, Request, Response, StreamClosed
from ..config import Config
from ..typing import (
AppWrapper,
ASGISendEvent,
HTTPResponseStartEvent,
HTTPScope,
TaskGroup,
WorkerContext,
)
from ..utils import (
build_and_validate_headers,
suppress_body,
UnexpectedMessageError,
valid_server_name,
)
# HTTP versions (the request's ``http_version`` string) on which the
# ``http.response.push`` ASGI extension is offered to the application and
# push messages from it are honoured.
PUSH_VERSIONS = {"3", "2"}
# HTTP versions on which the ``http.response.early_hint`` ASGI extension is
# offered to the application and early-hint messages from it are honoured.
EARLY_HINTS_VERSIONS = {"3", "2"}
class ASGIHTTPState(Enum):
    """How far the ASGI response on a stream has progressed.

    The ASGI spec is clear that a response should not start until the
    framework has sent at least one body message, which is why this state
    tracking is required.
    """

    # No body message has been sent yet; the response head is not emitted.
    REQUEST = auto()
    # The response head has been emitted and body data is flowing.
    RESPONSE = auto()
    # The response is complete (or an error response was sent).
    CLOSED = auto()
class HTTPStream:
    """Bridge a single HTTP request/response stream to an ASGI application.

    Protocol events delivered to :meth:`handle` (``Request``, ``Body``,
    ``EndBody``, ``StreamClosed``) are turned into ASGI receive messages for
    the application. Messages the application passes to :meth:`app_send`
    are turned back into protocol events and emitted through ``send``.
    """

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_group: TaskGroup,
        tls: Optional[dict[str, Any]],
        client: Optional[Tuple[str, int]],
        server: Optional[Tuple[str, int]],
        send: Callable[[Event], Awaitable[None]],
        stream_id: int,
        transport: Optional[Any] = None,
    ) -> None:
        """Create the stream; the application is not spawned until a Request.

        :param app: ASGI application wrapper passed to ``task_group.spawn_app``.
        :param config: Server configuration (``root_path``, access logger).
        :param context: Worker context (stored, not used in this class).
        :param task_group: Task group used to spawn the application.
        :param tls: TLS information; when not ``None`` the scheme is
            ``"https"`` and it is exposed as the ``tls`` scope extension.
        :param client: Client address placed in the scope.
        :param server: Server address placed in the scope.
        :param send: Coroutine used to emit protocol events for this stream.
        :param stream_id: Identifier attached to every emitted event.
        :param transport: Optional object exposed to the application as the
            ``_transport`` scope extension when not ``None``.
        """
        self.app = app
        # Coroutine that delivers ASGI receive messages to the spawned app.
        # It stays None until a valid Request spawns the application, which
        # lets the StreamClosed handler know whether an app is running.
        self.app_put: Optional[Callable[..., Awaitable[None]]] = None
        self.client = client
        self.closed = False
        self.config = config
        self.context = context
        self.response: HTTPResponseStartEvent
        self.scope: HTTPScope
        self.send = send
        self.scheme = "https" if tls is not None else "http"
        self.tls = tls
        self.server = server
        self.start_time: float
        self.state = ASGIHTTPState.REQUEST
        self.stream_id = stream_id
        self.task_group = task_group
        self.transport = transport

    @property
    def idle(self) -> bool:
        """Always ``False``; this stream never reports itself as idle."""
        return False

    async def handle(self, event: Event) -> None:
        """Process one protocol event for this stream.

        * ``Request``: build the ASGI scope and spawn the app. If
          ``valid_server_name`` rejects the request, send a 404 and close
          the stream instead.
        * ``Body`` / ``EndBody``: forward to the app as ``http.request``.
        * ``StreamClosed``: mark the stream closed and send the app
          ``http.disconnect``. If the response never completed, write the
          (single) access-log entry for the aborted request.

        Events received after the stream is closed are ignored.
        """
        if self.closed:
            return
        elif isinstance(event, Request):
            self.start_time = time()
            path, _, query_string = event.raw_path.partition(b"?")
            self.scope = {
                "type": "http",
                "http_version": event.http_version,
                "asgi": {"spec_version": "2.1", "version": "3.0"},
                "method": event.method,
                "scheme": self.scheme,
                "path": unquote(path.decode("ascii")),
                "raw_path": path,
                "query_string": query_string,
                "root_path": self.config.root_path,
                "headers": event.headers,
                "client": self.client,
                "server": self.server,
                "extensions": {},
            }
            # Only advertise the extensions this HTTP version can honour.
            if event.http_version in PUSH_VERSIONS:
                self.scope["extensions"]["http.response.push"] = {}
            if event.http_version in EARLY_HINTS_VERSIONS:
                self.scope["extensions"]["http.response.early_hint"] = {}
            if self.tls is not None:
                self.scope["extensions"]["tls"] = self.tls
            if self.transport is not None:
                self.scope["extensions"]["_transport"] = self.transport

            if valid_server_name(self.config, event):
                self.app_put = await self.task_group.spawn_app(
                    self.app, self.config, self.scope, self.app_send
                )
            else:
                await self._send_error_response(404)
                # Ignore any further events for this rejected stream.
                self.closed = True
        elif isinstance(event, Body):
            await self.app_put(
                {"type": "http.request", "body": bytes(event.data), "more_body": True}
            )
        elif isinstance(event, EndBody):
            await self.app_put({"type": "http.request", "body": b"", "more_body": False})
        elif isinstance(event, StreamClosed):
            self.closed = True
            # app_put is only set once a Request has built the scope and
            # spawned the app; before that there is nothing to log or notify.
            if self.app_put is not None:
                if self.state != ASGIHTTPState.CLOSED:
                    # The response did not finish, so nothing has been logged
                    # yet. Completed and error responses are logged when they
                    # end; logging again here would duplicate the entry.
                    await self.config.log.access(
                        self.scope, None, time() - self.start_time
                    )
                await self.app_put({"type": "http.disconnect"})

    async def app_send(self, message: Optional[ASGISendEvent]) -> None:
        """Translate one ASGI send message from the application into events.

        ``None`` signals that the application has finished. If no response
        was started at that point, a 500 response is sent. The stream is
        then closed unless it was already closed.

        :raises TypeError: if an ``http.response.push`` path is not a ``str``.
        :raises UnexpectedMessageError: if the message type is not valid in
            the current state (or for this HTTP version).
        """
        if message is None:  # ASGI App has finished sending messages
            if not self.closed:
                # Cleanup if required
                if self.state == ASGIHTTPState.REQUEST:
                    await self._send_error_response(500)
                await self.send(StreamClosed(stream_id=self.stream_id))
        else:
            if message["type"] == "http.response.start" and self.state == ASGIHTTPState.REQUEST:
                # Held back: the head is only sent with the first body message.
                self.response = message
            elif (
                message["type"] == "http.response.push"
                and self.scope["http_version"] in PUSH_VERSIONS
            ):
                if not isinstance(message["path"], str):
                    raise TypeError(f"{message['path']} should be a str")
                # Pseudo-headers for the pushed request are derived from the
                # original request's scheme and Host header.
                headers = [(b":scheme", self.scope["scheme"].encode())]
                for name, value in self.scope["headers"]:
                    if name == b"host":
                        headers.append((b":authority", value))
                headers.extend(build_and_validate_headers(message["headers"]))
                await self.send(
                    Request(
                        stream_id=self.stream_id,
                        headers=headers,
                        http_version=self.scope["http_version"],
                        method="GET",
                        raw_path=message["path"].encode(),
                    )
                )
            elif (
                message["type"] == "http.response.early_hint"
                and self.scope["http_version"] in EARLY_HINTS_VERSIONS
                and self.state == ASGIHTTPState.REQUEST
            ):
                headers = [(b"link", bytes(link).strip()) for link in message["links"]]
                await self.send(
                    InformationalResponse(
                        stream_id=self.stream_id,
                        headers=headers,
                        status_code=103,
                    )
                )
            elif message["type"] == "http.response.body" and self.state in {
                ASGIHTTPState.REQUEST,
                ASGIHTTPState.RESPONSE,
            }:
                if self.state == ASGIHTTPState.REQUEST:
                    # First body message: emit the held response head now.
                    headers = build_and_validate_headers(self.response.get("headers", []))
                    await self.send(
                        Response(
                            stream_id=self.stream_id,
                            headers=headers,
                            status_code=int(self.response["status"]),
                        )
                    )
                    self.state = ASGIHTTPState.RESPONSE

                if (
                    not suppress_body(self.scope["method"], int(self.response["status"]))
                    and message.get("body", b"") != b""
                ):
                    await self.send(
                        Body(stream_id=self.stream_id, data=bytes(message.get("body", b"")))
                    )

                if not message.get("more_body", False):
                    if self.state != ASGIHTTPState.CLOSED:
                        self.state = ASGIHTTPState.CLOSED
                        await self.config.log.access(
                            self.scope, self.response, time() - self.start_time
                        )
                        await self.send(EndBody(stream_id=self.stream_id))
                        await self.send(StreamClosed(stream_id=self.stream_id))
            else:
                raise UnexpectedMessageError(self.state, message["type"])

    async def _send_error_response(self, status_code: int) -> None:
        """Send an empty response with ``status_code`` and log it.

        The response carries ``content-length: 0`` and ``connection: close``.
        The body is ended, the state moves to ``CLOSED``, and an access-log
        entry is written.
        """
        await self.send(
            Response(
                stream_id=self.stream_id,
                headers=[(b"content-length", b"0"), (b"connection", b"close")],
                status_code=status_code,
            )
        )
        await self.send(EndBody(stream_id=self.stream_id))
        self.state = ASGIHTTPState.CLOSED
        await self.config.log.access(
            self.scope, {"status": status_code, "headers": []}, time() - self.start_time
        )
|