1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
|
import sys
from collections import defaultdict
from tempfile import SpooledTemporaryFile
from asgiref.sync import AsyncToSync, sync_to_async
class WsgiToAsgi:
    """
    ASGI application that adapts a WSGI application.

    Every ASGI call is delegated to a freshly constructed
    ``WsgiToAsgiInstance``, so per-request state is never shared between
    requests.
    """

    def __init__(self, wsgi_application, duplicate_header_limit=100):
        # Both values are handed on unchanged to each per-request instance.
        self.wsgi_application, self.duplicate_header_limit = (
            wsgi_application,
            duplicate_header_limit,
        )

    async def __call__(self, scope, receive, send):
        """
        ASGI entry point.

        Builds a new per-request handler from the stored WSGI application and
        header limit, then awaits it with the ASGI ``scope``/``receive``/``send``.
        """
        handler = WsgiToAsgiInstance(
            self.wsgi_application,
            self.duplicate_header_limit,
        )
        await handler(scope, receive, send)
class WsgiToAsgiInstance:
    """
    Per-socket instance of a wrapped WSGI application.

    One instance handles a single ASGI ``http`` scope. It buffers the request
    body, translates the scope into a WSGI environ and runs the WSGI
    application through ``sync_to_async``. It then relays the application's
    status, headers and body back as ASGI messages.
    """

    def __init__(self, wsgi_application, duplicate_header_limit=100):
        """
        :param wsgi_application: the WSGI callable to run.
        :param duplicate_header_limit: maximum number of times any single
            request header may repeat; a falsy value disables the check.
        """
        self.wsgi_application = wsgi_application
        self.duplicate_header_limit = duplicate_header_limit
        self.response_started = False
        self.response_content_length = None
        # Pending http.response.start message, set by start_response().
        self.response_start = None

    async def __call__(self, scope, receive, send):
        """
        Handle one ASGI HTTP scope.

        :raises ValueError: if the scope type is not ``http``, or a message
            other than ``http.request`` arrives while reading the body.
        """
        if scope["type"] != "http":
            raise ValueError("WSGI wrapper received a non-HTTP scope")
        self.scope = scope
        # Buffer the body in memory, spilling to a temp file past 64 KiB.
        with SpooledTemporaryFile(max_size=65536) as body:
            # Alright, wait for the http.request messages
            while True:
                message = await receive()
                if message["type"] != "http.request":
                    raise ValueError("WSGI wrapper received a non-HTTP-request message")
                body.write(message.get("body", b""))
                if not message.get("more_body"):
                    break
            body.seek(0)
            # Wrap send so it can be called from the subthread
            self.sync_send = AsyncToSync(send)
            # Call the WSGI app (the body file must stay open while it runs)
            await self.run_wsgi_app(body)

    def build_environ(self, scope, body):
        """
        Build a WSGI environ dict from an ASGI HTTP scope and request body.

        :param scope: the ASGI ``http`` scope.
        :param body: file-like object positioned at the start of the request
            body; exposed to the application as ``wsgi.input``.
        :returns: the WSGI environ dict.
        :raises ValueError: if any header repeats more than
            ``duplicate_header_limit`` times. A ``UnicodeDecodeError`` (also a
            ``ValueError``) propagates if ``query_string`` is not ASCII.
        """
        # Re-encode the decoded ASGI paths as UTF-8 bytes and expose those
        # bytes as latin-1 native strings, the WSGI convention for raw bytes.
        script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
        path_info = scope["path"].encode("utf8").decode("latin1")
        if path_info.startswith(script_name):
            path_info = path_info[len(script_name) :]
        environ = {
            "REQUEST_METHOD": scope["method"],
            "SCRIPT_NAME": script_name,
            "PATH_INFO": path_info,
            "QUERY_STRING": scope["query_string"].decode("ascii"),
            "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": scope.get("scheme", "http"),
            "wsgi.input": body,
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": True,
            "wsgi.run_once": False,
        }
        # Get server name and port - required in WSGI, not in ASGI
        if "server" in scope:
            environ["SERVER_NAME"] = scope["server"][0]
            environ["SERVER_PORT"] = str(scope["server"][1])
        else:
            environ["SERVER_NAME"] = "localhost"
            environ["SERVER_PORT"] = "80"
        if scope.get("client") is not None:
            environ["REMOTE_ADDR"] = scope["client"][0]
        # Group header values by their environ key so repeated headers can be
        # counted against the limit and joined afterwards.
        grouped = defaultdict(list)
        for name, value in scope.get("headers", []):
            name = name.decode("latin1")
            if name == "content-length":
                corrected_name = "CONTENT_LENGTH"
            elif name == "content-type":
                corrected_name = "CONTENT_TYPE"
            else:
                corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
            # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
            value = value.decode("latin1")
            if (
                self.duplicate_header_limit
                and len(grouped[corrected_name]) >= self.duplicate_header_limit
            ):
                raise ValueError(
                    f"Too many duplicate headers: {corrected_name} exceeds limit of "
                    f"{self.duplicate_header_limit}"
                )
            grouped[corrected_name].append(value)
        for name, values in grouped.items():
            # Multiple Cookie headers (e.g. split by HTTP/2) must be joined with
            # "; "; a "," would fuse two cookies into one cookie value.
            separator = "; " if name == "HTTP_COOKIE" else ","
            environ[name] = separator.join(values)
        return environ

    def start_response(self, status, response_headers, exc_info=None):
        """
        WSGI start_response callable.

        Records the status and headers as a pending ``http.response.start``
        message. ``run_wsgi_app`` sends it just before the first body chunk, or
        at the end for an empty body.

        :param status: status line such as ``"200 OK"``.
        :param response_headers: list of ``(name, value)`` string pairs.
        :param exc_info: ``sys.exc_info()`` tuple when called from an error
            handler; allows replacing headers that have not been sent yet.
        :raises ValueError: on a repeated call without ``exc_info``.

        NOTE(review): PEP 3333 expects this to return a ``write()`` callable;
        none is returned here, so applications using ``write()`` will fail.
        """
        if exc_info is None:
            # Don't allow re-calling without exc_info
            if self.response_start is not None:
                raise ValueError(
                    "You cannot call start_response a second time without exc_info"
                )
        elif self.response_started:
            # Headers are already sent and cannot change: re-raise the
            # application's original error, as PEP 3333 requires.
            raise exc_info[1].with_traceback(exc_info[2])
        # Extract status code; tolerate a status line with no reason phrase.
        status_code = int(status.split(" ", 1)[0])
        # WSGI header strings are latin-1 encodable native strings.
        headers = [
            (name.lower().encode("latin1"), value.encode("latin1"))
            for name, value in response_headers
        ]
        # Extract content-length (the last occurrence wins)
        self.response_content_length = None
        for name, value in response_headers:
            if name.lower() == "content-length":
                self.response_content_length = int(value)
        # Build the response start message; it is sent later by run_wsgi_app.
        self.response_start = {
            "type": "http.response.start",
            "status": status_code,
            "headers": headers,
        }

    def _send_response_start(self):
        """Mark the response as started and send the pending start message."""
        if self.response_start is None:
            raise RuntimeError("WSGI application did not call start_response")
        self.response_started = True
        self.sync_send(self.response_start)

    @sync_to_async
    def run_wsgi_app(self, body):
        """
        Called in a subthread to run the WSGI app. We encapsulate like
        this so that the start_response callable is called in the same thread.

        If the application's result iterable has a ``close()`` method, it is
        always called once the response is finished. This also happens when
        iteration stops early or raises, as PEP 3333 requires.
        """
        # Translate the scope and incoming request body into a WSGI environ
        try:
            environ = self.build_environ(self.scope, body)
        except ValueError:
            # Return 400 Bad Request if header limit exceeded.
            # NOTE(review): other ValueErrors from build_environ (e.g. a
            # non-ASCII query string failing to decode) also get this reply.
            self.sync_send(
                {
                    "type": "http.response.start",
                    "status": 400,
                    "headers": [(b"content-type", b"text/plain")],
                }
            )
            self.sync_send(
                {
                    "type": "http.response.body",
                    "body": b"Bad Request: Too many duplicate headers",
                }
            )
            return
        # Run the WSGI app
        result = self.wsgi_application(environ, self.start_response)
        try:
            bytes_sent = 0
            for output in result:
                # If this is the first response, include the response headers
                if not self.response_started:
                    self._send_response_start()
                # If the application supplies a Content-Length header
                if self.response_content_length is not None:
                    # The server should not transmit more bytes to the client than the header allows
                    bytes_allowed = self.response_content_length - bytes_sent
                    if len(output) > bytes_allowed:
                        output = output[:bytes_allowed]
                self.sync_send(
                    {"type": "http.response.body", "body": output, "more_body": True}
                )
                bytes_sent += len(output)
                # The server should stop iterating over the response when enough data has been sent
                if bytes_sent == self.response_content_length:
                    break
            # An empty body still needs its start message.
            if not self.response_started:
                self._send_response_start()
            # Close connection
            self.sync_send({"type": "http.response.body"})
        finally:
            close = getattr(result, "close", None)
            if close is not None:
                close()
|