File: proxy3.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (309 lines) | stat: -rw-r--r-- 11,235 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import re
import select
import socket
import ssl
from http.server import BaseHTTPRequestHandler
from subprocess import CalledProcessError, check_output
from threading import Lock
from typing import Any
from urllib.parse import urlparse

from botocore.awsrequest import AWSPreparedRequest

from moto.backend_index import backend_url_patterns
from moto.backends import get_backend
from moto.core import DEFAULT_ACCOUNT_ID
from moto.core.base_backend import BackendDict
from moto.core.exceptions import RESTError
from moto.core.utils import get_equivalent_url_in_aws_domain
from moto.moto_api._internal.models import moto_api_backend

from . import debug, error, info, with_color
from .certificate_creator import CertificateCreator
from .utils import get_body_from_form_data

# Adapted from https://github.com/xxlv/proxy3


class MotoRequestHandler:
    """Resolve a proxied request to the moto handler that should answer it."""

    def __init__(self, port: int):
        # Port the proxy itself listens on; requests to http://localhost:<port>
        # are routed to the moto API backend.
        self.port = port
        self.lock = Lock()

    def get_backend_for_host(self, host: str) -> Any:
        """Return the name of the backend that serves ``host``.

        ``host`` is a scheme-qualified origin such as ``https://s3.amazonaws.com``.
        The proxy's own address maps to ``"moto_api"``. Anything else is first
        translated with ``get_equivalent_url_in_aws_domain`` and then matched
        against ``backend_url_patterns``. Returns the first matching backend
        name, or ``None`` if no pattern matches.
        """
        own_address = f"http://localhost:{self.port}"
        if host == own_address:
            return "moto_api"

        # Non-standard AWS endpoint hostnames (ISO regions, custom S3 endpoints)
        # are mapped onto their equivalent in the regular AWS domain first.
        aws_url, _ = get_equivalent_url_in_aws_domain(host)
        # Only scheme, netloc and path take part in matching; the query string never does.
        candidate = f"{aws_url.scheme}://{aws_url.netloc}{aws_url.path}"

        return next(
            (name for name, regex in backend_url_patterns if regex.match(candidate)),
            None,
        )

    def get_handler_for_host(self, host: str, path: str) -> Any:
        """Return the handler registered for ``path`` on the backend serving ``host``.

        The query string of ``path`` is ignored. Returns ``None`` when no backend
        recognises ``host``, or when none of the backend's ``url_paths`` regexes
        match the path.
        """
        route = path.partition("?")[0]
        backend_name = self.get_backend_for_host(host)
        if not backend_name:
            return None

        # The backend instance is only used to look up URL routes, so which
        # account/region it belongs to does not matter.
        backend_dict = get_backend(backend_name)
        if not isinstance(backend_dict, BackendDict):
            backend = backend_dict["global"]
        else:
            account_backends = backend_dict[DEFAULT_ACCOUNT_ID]
            region = "us-east-1" if "us-east-1" in account_backends else "aws"
            backend = account_backends[region]

        matching_handlers = (
            handler
            for url_pattern, handler in backend.url_paths.items()
            if re.match(url_pattern, route)
        )
        return next(matching_handlers, None)

    def parse_request(
        self,
        method: str,
        host: str,
        path: str,
        headers: Any,
        body: bytes,
        form_data: dict[str, Any],
    ) -> Any:
        """Dispatch one request to its moto handler and return the handler's result.

        When no handler matches, returns
        ``(404, {}, b"AWS Service not recognized or supported")``. Otherwise builds
        an ``AWSPreparedRequest`` for ``host + path``, attaches ``form_data`` to it
        as an attribute, and returns ``handler(request, full_url, headers)``.
        """
        handler = self.get_handler_for_host(host=host, path=path)
        if handler is None:
            return 404, {}, b"AWS Service not recognized or supported"

        full_url = host + path
        prepared = AWSPreparedRequest(
            method, full_url, headers, body, stream_output=False
        )
        prepared.form_data = form_data  # type: ignore[attr-defined]
        return handler(prepared, full_url, headers)


class ProxyRequestHandler(BaseHTTPRequestHandler):
    """Proxy request handler that answers AWS API calls from moto's backends.

    The following requests are resolved through :class:`MotoRequestHandler`:

    - plain HTTP requests;
    - HTTPS requests tunnelled via ``CONNECT``, where TLS is terminated here
      using a certificate from ``CertificateCreator``.

    Hosts listed in ``moto_api_backend.proxy_hosts_to_passthrough`` are
    tunnelled unchanged. URLs listed in
    ``moto_api_backend.proxy_urls_to_passthrough`` are forwarded to their real
    destination.
    """

    # Timeout in seconds; also used for upstream sockets and the relay's select().
    timeout = 5

    def __init__(self, *args: Any, **kwargs: Any):
        # The local port of the accepted socket lets MotoRequestHandler recognise
        # requests addressed to the proxy itself.
        sock = [a for a in args if isinstance(a, socket.socket)][0]
        # Index rather than unpack: IPv6 socket names are 4-tuples.
        port = sock.getsockname()[1]
        self.protocol_version = "HTTP/1.1"
        self.moto_request_handler = MotoRequestHandler(port)
        self.cert_creator = CertificateCreator()
        BaseHTTPRequestHandler.__init__(self, *args, **kwargs)

    @staticmethod
    def validate() -> None:
        """Run ``CertificateCreator().validate()`` and check that ``openssl version`` succeeds.

        Raises:
            CalledProcessError: if ``openssl version`` exits non-zero (its output
                is logged before re-raising).
        """
        debug("Starting initial validation...")
        CertificateCreator().validate()
        # Validate the openssl command is available
        try:
            debug("Verifying SSL version...")
            version_output = check_output(["openssl", "version"])
            debug(version_output)
        except CalledProcessError as e:
            info(e.output)
            raise

    def do_CONNECT(self) -> None:
        """Handle ``CONNECT host[:port]`` by tunnelling it or terminating TLS locally.

        Passthrough hosts get a raw byte relay to the real destination (default
        port 443). For any other host the client socket is wrapped in TLS with a
        certificate created for ``self.path``. Subsequent requests on the
        connection are then read and answered through the wrapped socket.
        """
        target_host, port = self._split_host_port(self.path, 443)
        if target_host in moto_api_backend.proxy_hosts_to_passthrough:
            self.connect_relay((target_host, port))
            return

        certpath = self.cert_creator.create(self.path)
        self.wfile.write(
            f"{self.protocol_version} 200 Connection Established\r\n".encode()
        )
        # end_headers() appends to the buffer that send_header() creates, so a
        # placeholder header is sent to be able to terminate the response.
        self.send_header("k", "v")
        self.end_headers()

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(
            keyfile=CertificateCreator.certkey,
            certfile=certpath,
        )
        ssl_context.check_hostname = False
        self.connection = ssl_context.wrap_socket(
            self.connection,
            server_side=True,
        )
        # From now on requests are read from / written to the TLS-wrapped socket.
        self.rfile = self.connection.makefile("rb", self.rbufsize)  # type: ignore
        self.wfile = self.connection.makefile("wb", self.wbufsize)  # type: ignore

        conntype = self.headers.get("Proxy-Connection", "")
        keep_alive = self.protocol_version == "HTTP/1.1" and conntype.lower() != "close"
        self.close_connection = not keep_alive

    def do_GET(self) -> None:
        """Serve the current request (any HTTP method) from moto.

        If the full URL is configured for passthrough, the request is forwarded
        to the real host (default port 80). Otherwise:

        1. The body is read, either chunked or ``Content-Length`` delimited.
        2. Multipart form data is unpacked.
        3. The request is dispatched to :class:`MotoRequestHandler`.

        A ``RESTError`` is turned into its own error response. Any other
        exception is logged and answered with a 502. The connection is closed
        after the response.
        """
        host, path = self._get_host_and_path(self)

        if f"{host}{path}" in moto_api_backend.proxy_urls_to_passthrough:
            self.passthrough_http(self._split_host_port(urlparse(host).netloc, 80))
            return

        req_body = b""
        if "chunked" in self.headers.get("Transfer-Encoding", ""):
            # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding
            req_body = self.read_chunked_body(self.rfile)
        elif "Content-Length" in self.headers:
            content_length = int(self.headers["Content-Length"])
            req_body = self.rfile.read(content_length)
        if self.headers.get("Content-Type", "").startswith("multipart/form-data"):
            boundary = self.headers["Content-Type"].split("boundary=")[-1]
            req_body, form_data = get_body_from_form_data(req_body, boundary)  # type: ignore
            # Form fields are also appended to the request headers.
            # NOTE(review): presumably some moto handlers read them from there - confirm.
            for key, val in form_data.items():
                self.headers[key] = [val]  # type: ignore
        else:
            form_data = {}

        req_body = self.decode_request_body(self.headers, req_body)  # type: ignore

        try:
            info(f"{with_color(33, self.command.upper())} {host}{path}")  # noqa
            if req_body is not None:
                debug("\tbody\t" + with_color(31, text=req_body))
            debug(f"\theaders\t{with_color(31, text=dict(self.headers))}")
            response = self.moto_request_handler.parse_request(
                method=self.command,
                host=host,
                path=path,
                headers=self.headers,
                body=req_body,
                form_data=form_data,
            )
            debug("\t=====RESPONSE========")
            debug("\t" + with_color(color=33, text=response))
            debug("\n")

            if isinstance(response, tuple):
                res_status, res_headers, res_body = response
            else:
                res_status, res_headers, res_body = (200, {}, response)

        except RESTError as e:
            res_status = e.code
            # May be a list of (name, value) pairs; _write_response normalises it.
            res_headers = e.get_headers()
            res_body = e.get_body()

        except Exception as e:
            error(e)
            self.send_error(502)
            return

        self._write_response(res_status, res_headers, res_body)
        self.close_connection = True

    def _write_response(self, status: Any, headers: Any, body: Any) -> None:
        """Write a complete HTTP response (status line, headers, body) to the client.

        Args:
            status: HTTP status code; its standard reason phrase is used, or
                "OK" when the code is unknown.
            headers: a mapping or a list of ``(name, value)`` pairs. It is
                copied, never mutated, and ``bytes`` values are decoded as UTF-8.
            body: ``str`` (encoded as UTF-8), ``bytes``, or empty/``None``.

        A ``Content-Length`` header is added for a non-empty body unless one is
        already present, compared case-insensitively.
        """
        if isinstance(body, str):
            body = body.encode("utf-8")
        headers = dict(headers or {})
        if body and not any(str(name).lower() == "content-length" for name in headers):
            headers["Content-Length"] = str(len(body))

        reason = self.responses.get(status, ("OK",))[0]
        self.wfile.write(f"{self.protocol_version} {status} {reason}\r\n".encode())
        for name, value in headers.items():
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            self.send_header(name, value)
        if headers:
            self.end_headers()
        else:
            # Nothing was buffered by send_header(), so end_headers() cannot be
            # used; terminate the (empty) header section directly.
            self.wfile.write(b"\r\n")
        if body:
            self.wfile.write(body)

    def _get_host_and_path(self, req: Any) -> tuple[str, str]:
        """Return ``(scheme://authority, path)`` for ``req``.

        The scheme is ``https`` when the connection has been TLS-wrapped. The
        authority comes from the Host header, falling back to the absolute-form
        request target. A leading ``scheme://authority`` is stripped from the
        returned path.
        """
        scheme = "https" if isinstance(self.connection, ssl.SSLSocket) else "http"
        authority = req.headers.get("Host") or urlparse(req.path).netloc
        host = f"{scheme}://{authority}"
        return host, req.path.removeprefix(host)

    @staticmethod
    def _split_host_port(authority: str, default_port: int) -> tuple[str, int]:
        """Split ``host[:port]`` into ``(host, port)``.

        A missing or zero port yields ``default_port``. Brackets around an IPv6
        literal (``[::1]:443``) are removed. A non-numeric port raises
        ``ValueError``.
        """
        host, sep, port_text = authority.rpartition(":")
        if not sep or "]" in port_text:
            # No port present, e.g. "example.com" or a bare "[::1]".
            host, port_text = authority, ""
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        port = int(port_text) if port_text else 0
        return host, port or default_port

    def passthrough_http(self, address: tuple[str, int]) -> None:
        """Forward the current plain-HTTP request to ``address`` and relay the reply.

        The steps are:

        1. Send the original request line and headers.
        2. Send any ``Content-Length`` delimited body. This also consumes it, so
           it is not left in ``rfile``.
        3. Copy upstream bytes back to the client until the upstream closes or
           stays silent for ``self.timeout`` seconds.

        The upstream socket is always closed.
        """
        with socket.create_connection(address, timeout=self.timeout) as upstream:
            upstream.sendall(self.raw_requestline)  # type: ignore[attr-defined]
            for key, val in self.headers.items():
                upstream.sendall(f"{key}: {val}\r\n".encode())
            upstream.sendall(b"\r\n")

            body_length = int(self.headers.get("Content-Length") or 0)
            if body_length > 0:
                upstream.sendall(self.rfile.read(body_length))

            while True:
                try:
                    data = upstream.recv(8192)
                except socket.timeout:
                    # A keep-alive upstream never closes; treat silence as end of reply.
                    break
                if not data:
                    break
                self.wfile.write(data)

    def connect_relay(self, address: tuple[str, int]) -> None:
        """Tunnel raw bytes between the client and ``address`` (CONNECT passthrough).

        Replies 502 if the upstream connection cannot be opened. The tunnel ends
        when:

        - either side closes;
        - select() reports an exceptional condition; or
        - nothing arrives within ``self.timeout`` seconds.

        The upstream socket is then closed, and the client connection is marked
        for closing because it carries tunnelled bytes, not HTTP.
        """
        try:
            upstream = socket.create_connection(address, timeout=self.timeout)
        except Exception:
            self.send_error(502)
            return
        self.send_response(200, "Connection Established")
        self.end_headers()

        conns = [self.connection, upstream]
        self.close_connection = False
        try:
            while not self.close_connection:
                readable, _, errored = select.select(conns, [], conns, self.timeout)
                if errored or not readable:
                    break
                for src in readable:
                    dst = upstream if src is self.connection else self.connection
                    data = src.recv(8192)
                    if not data:
                        self.close_connection = True
                        break
                    dst.sendall(data)
        finally:
            upstream.close()
            self.close_connection = True

    def read_chunked_body(self, reader: Any) -> bytes:
        """Decode a ``Transfer-Encoding: chunked`` body from ``reader``.

        Args:
            reader: must provide bytes-returning ``readline()`` and ``read(n)``.

        Chunk extensions (``size;name=value``) are ignored. The trailer section
        after the final zero-size chunk is consumed and discarded (AWS sends
        checksum headers there), leaving ``reader`` positioned just after the
        body. Reading also stops if the stream ends before the terminating chunk.

        Raises:
            ValueError: if a chunk-size line is not hexadecimal.
        """
        chunks: list[bytes] = []
        while True:
            size_line = reader.readline()
            if not size_line:
                break  # stream ended before the terminating zero-size chunk
            chunk_length = int(size_line.split(b";", 1)[0].strip(), 16)
            if chunk_length == 0:
                # Skip optional trailer fields up to the blank line ending the body.
                while reader.readline().strip():
                    pass
                break
            chunks.append(reader.read(chunk_length))
            # Each chunk's data is followed by a CRLF.
            reader.readline()
        return b"".join(chunks)

    def decode_request_body(self, headers: dict[str, str], body: Any) -> Any:
        """Return ``body`` decoded as UTF-8 text for the two textual AWS content types.

        Only exact ``application/x-amz-json-1.1`` and
        ``application/x-www-form-urlencoded; charset=utf-8`` Content-Types are
        decoded. Any other body, including ``None``, is returned unchanged.
        """
        if body is None:
            return body
        if headers.get("Content-Type", "") in [
            "application/x-amz-json-1.1",
            "application/x-www-form-urlencoded; charset=utf-8",
        ]:
            return body.decode("utf-8")
        return body

    # Every method is served by the same dispatcher.
    do_HEAD = do_GET
    do_POST = do_GET
    do_PUT = do_GET
    do_PATCH = do_GET
    do_DELETE = do_GET
    do_OPTIONS = do_GET