1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
from __future__ import annotations
from functools import partial
from io import BytesIO
from typing import Callable, List, Optional, Tuple
from .typing import (
ASGIFramework,
ASGIReceiveCallable,
ASGISendCallable,
HTTPScope,
Scope,
WSGIFramework,
)
class InvalidPathError(Exception):
    """Raised by ``_build_environ`` when the request path does not start with the root path."""
class ASGIWrapper:
    """Pass-through wrapper exposing an ASGI app via the five-argument wrapper call."""

    def __init__(self, app: ASGIFramework) -> None:
        """Store the wrapped ASGI application on ``self.app``."""
        self.app = app

    async def __call__(
        self,
        scope: Scope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
        sync_spawn: Callable,
        call_soon: Callable,
    ) -> None:
        """Await the wrapped app with ``scope``, ``receive`` and ``send``.

        ``sync_spawn`` and ``call_soon`` are accepted so the signature matches
        ``WSGIWrapper.__call__``; they are not used here.
        """
        application = self.app
        await application(scope, receive, send)
class WSGIWrapper:
    """Adapt a WSGI application to the five-argument ASGI-style wrapper call.

    HTTP requests are buffered in full, converted to a WSGI environ and handed
    to the WSGI app via ``sync_spawn``; websocket connections are closed and
    lifespan scopes are ignored.
    """

    def __init__(self, app: WSGIFramework, max_body_size: int) -> None:
        """Store the WSGI app and the largest request body (in bytes) accepted."""
        self.app = app
        self.max_body_size = max_body_size

    async def __call__(
        self,
        scope: Scope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
        sync_spawn: Callable,
        call_soon: Callable,
    ) -> None:
        """Dispatch on ``scope["type"]``.

        Raises:
            Exception: if the scope type is not http, websocket or lifespan.
        """
        if scope["type"] == "http":
            await self.handle_http(scope, receive, send, sync_spawn, call_soon)
        elif scope["type"] == "websocket":
            # WSGI has no websocket support, so the connection is closed straight away.
            await send({"type": "websocket.close"})  # type: ignore
        elif scope["type"] == "lifespan":
            return
        else:
            raise Exception(f"Unknown scope type, {scope['type']}")

    async def handle_http(
        self,
        scope: HTTPScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
        sync_spawn: Callable,
        call_soon: Callable,
    ) -> None:
        """Buffer the request body, run the WSGI app and finish the response.

        A body larger than ``max_body_size`` is answered with an empty 400
        response; a path rejected by ``_build_environ`` with an empty 404.
        Otherwise ``run_app`` is awaited through ``sync_spawn`` and a final
        empty body message terminates the response.
        """
        body = bytearray()
        while True:
            message = await receive()
            body.extend(message.get("body", b""))  # type: ignore
            if len(body) > self.max_body_size:
                await send({"type": "http.response.start", "status": 400, "headers": []})
                await send({"type": "http.response.body", "body": b"", "more_body": False})
                return
            if not message.get("more_body"):
                break

        try:
            environ = _build_environ(scope, body)
        except InvalidPathError:
            await send({"type": "http.response.start", "status": 404, "headers": []})
        else:
            # run_app is synchronous; it gets a plain callable that forwards each
            # message to ``send`` through ``call_soon``.
            # NOTE(review): presumably sync_spawn runs run_app off the event loop and
            # call_soon schedules the coroutine back onto it - confirm against callers.
            await sync_spawn(self.run_app, environ, partial(call_soon, send))
        # run_app always sends more_body=True, so this closes the response in every case.
        await send({"type": "http.response.body", "body": b"", "more_body": False})

    def run_app(self, environ: dict, send: Callable) -> None:
        """Call the WSGI app synchronously and forward its output through ``send``.

        ``send`` is called with an ``http.response.start`` message when the app
        invokes ``start_response`` and with one ``http.response.body`` message
        (``more_body`` True) per chunk the app yields. The app's iterable is
        closed afterwards, even if iteration raises.
        """

        def start_response(
            status: str,
            response_headers: List[Tuple[str, str]],
            exc_info: Optional[tuple] = None,
        ) -> None:
            # Only the numeric code is forwarded; partition tolerates a status
            # string that has no reason phrase (e.g. "200").
            raw, _, _ = status.partition(" ")
            status_code = int(raw)
            # WSGI strings may hold any latin-1 code point, so encode as latin-1
            # rather than ASCII to avoid UnicodeEncodeError on legal headers.
            headers = [
                (name.lower().encode("latin-1"), value.encode("latin-1"))
                for name, value in response_headers
            ]
            send({"type": "http.response.start", "status": status_code, "headers": headers})

        response_body = self.app(environ, start_response)
        try:
            for output in response_body:
                send({"type": "http.response.body", "body": output, "more_body": True})
        finally:
            # PEP 3333: the server must call close() on the iterable if it has one.
            close = getattr(response_body, "close", None)
            if close is not None:
                close()
def _build_environ(scope: HTTPScope, body: bytes) -> dict:
    """Translate an ASGI HTTP scope and a buffered request body into a WSGI environ.

    Args:
        scope: The HTTP connection scope; ``path``, ``method``, ``query_string``
            and ``http_version`` are read directly, the remaining keys via ``get``.
        body: The complete request body, exposed as ``wsgi.input``.

    Returns:
        The environ dict, including one entry per request header
        (``CONTENT_LENGTH``, ``CONTENT_TYPE`` or ``HTTP_<NAME>``); repeated
        headers are joined with ``","``.

    Raises:
        InvalidPathError: if ``path`` does not start with ``root_path``.
    """
    # Imported locally because only BytesIO is imported at module level.
    from io import StringIO

    server = scope.get("server") or ("localhost", 80)
    path = scope["path"]
    script_name = scope.get("root_path", "")
    if not path.startswith(script_name):
        raise InvalidPathError()
    # Strip the mount prefix; an empty remainder means the mount root.
    path = path[len(script_name) :] or "/"

    environ = {
        "REQUEST_METHOD": scope["method"],
        # WSGI expects bytes-as-latin1 strings, so round-trip the UTF-8 bytes.
        "SCRIPT_NAME": script_name.encode("utf8").decode("latin1"),
        "PATH_INFO": path.encode("utf8").decode("latin1"),
        # latin1 never fails, unlike ascii, on raw non-ASCII bytes in the query.
        "QUERY_STRING": scope["query_string"].decode("latin1"),
        "SERVER_NAME": server[0],
        # CGI variables must be str in WSGI; the scope carries the port as a number.
        "SERVER_PORT": str(server[1]),
        "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": BytesIO(body),
        # wsgi.errors must be a text-mode stream; output written to it is discarded.
        "wsgi.errors": StringIO(),
        "wsgi.multithread": True,
        "wsgi.multiprocess": True,
        "wsgi.run_once": False,
    }

    if scope.get("client") is not None:
        environ["REMOTE_ADDR"] = scope["client"][0]

    for raw_name, raw_value in scope.get("headers", []):
        name = raw_name.decode("latin1")
        if name == "content-length":
            corrected_name = "CONTENT_LENGTH"
        elif name == "content-type":
            corrected_name = "CONTENT_TYPE"
        else:
            corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
        # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
        value = raw_value.decode("latin1")
        if corrected_name in environ:
            value = environ[corrected_name] + "," + value  # type: ignore
        environ[corrected_name] = value
    return environ
|