File: http_base.py

package info (click to toggle)
python-scrapy 2.14.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,308 kB
  • sloc: python: 55,321; xml: 199; makefile: 25; sh: 7
file content (132 lines) | stat: -rw-r--r-- 4,361 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Base classes and functions for HTTP mockservers."""

from __future__ import annotations

import argparse
import sys
from abc import ABC, abstractmethod
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from twisted.web.server import Site

from tests.utils import get_script_run_env

from .utils import ssl_context_factory

if TYPE_CHECKING:
    from collections.abc import Callable

    from twisted.web import resource


class BaseMockServer(ABC):
    """Context manager that runs a mock server module in a child process.

    Entering the context starts ``python -u -m <module_name>`` and reads one
    address line from the child's stdout for each enabled protocol (HTTP
    first, then HTTPS). Leaving the context kills the child and reaps it.

    Subclasses must provide ``module_name`` and may override ``listen_http``,
    ``listen_https`` and ``get_additional_args()``.
    """

    # Which protocols the child is expected to announce on stdout.
    listen_http: bool = True
    listen_https: bool = True

    @property
    @abstractmethod
    def module_name(self) -> str:
        """Name of the module passed to ``python -m`` to start the server."""
        raise NotImplementedError

    def __init__(self) -> None:
        """Initialise the server state without starting the child process.

        Raises:
            ValueError: if both ``listen_http`` and ``listen_https`` are false.
        """
        if not self.listen_http and not self.listen_https:
            raise ValueError("At least one of listen_http and listen_https must be set")

        self.proc: Popen | None = None
        self.host: str = "127.0.0.1"
        self.http_port: int | None = None
        self.https_port: int | None = None

    def __enter__(self):
        """Start the child process and record the ports it announces.

        Raises:
            RuntimeError: if the child does not print a usable address for an
                enabled protocol (for example because it exited early).
        """
        self.proc = Popen(
            # -u: unbuffered stdout, so the address lines are not held back
            # in the child's output buffer.
            [sys.executable, "-u", "-m", self.module_name, *self.get_additional_args()],
            stdout=PIPE,
            env=get_script_run_env(),
        )
        try:
            if self.listen_http:
                self.http_port = self._read_port()
            if self.listen_https:
                self.https_port = self._read_port()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so the child
            # must be stopped here or it would be leaked.
            self._stop()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Kill the child process (if any) and wait for it to terminate."""
        self._stop()

    def _read_port(self) -> int:
        """Read one address line from the child's stdout and return its port."""
        proc = self.proc
        if proc is None or proc.stdout is None:
            raise RuntimeError("The mock server process is not running")
        address = proc.stdout.readline().strip().decode("ascii")
        port = urlparse(address).port
        if port is None:
            # An empty line means EOF: the child closed stdout (most likely
            # exited) before announcing the address.
            raise RuntimeError(
                f"Mock server {self.module_name!r} did not report a listening "
                f"address (got {address!r})"
            )
        return port

    def _stop(self) -> None:
        """Kill the child process, if one was started, and reap it."""
        if self.proc is not None:
            self.proc.kill()
            # communicate() waits for the child and closes the stdout pipe.
            self.proc.communicate()

    def get_additional_args(self) -> list[str]:
        """Return extra command-line arguments for the child (none by default)."""
        return []

    def port(self, is_secure: bool = False) -> int:
        """Return the HTTP port, or the HTTPS port when ``is_secure`` is true.

        Raises:
            ValueError: if the requested protocol is disabled for this server.
            RuntimeError: if the port is not known yet, i.e. the server has
                not been started.
        """
        if not is_secure and not self.listen_http:
            raise ValueError("This server doesn't provide HTTP")
        if is_secure and not self.listen_https:
            raise ValueError("This server doesn't provide HTTPS")
        port = self.https_port if is_secure else self.http_port
        # An explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let ``None`` escape as the return value.
        if port is None:
            raise RuntimeError("The mock server has not been started")
        return port

    def url(self, path: str, is_secure: bool = False) -> str:
        """Return the absolute URL for ``path`` on this server.

        ``path`` is appended verbatim after ``host:port``. Raises the same
        exceptions as ``port()``.
        """
        port = self.port(is_secure)
        scheme = "https" if is_secure else "http"
        return f"{scheme}://{self.host}:{port}{path}"


def main_factory(
    resource_class: type[resource.Resource],
    *,
    listen_http: bool = True,
    listen_https: bool = True,
) -> Callable[[], None]:
    """Build a ``main()`` entry point that serves ``resource_class``.

    The returned callable instantiates ``resource_class``, wraps it in a
    ``Site``, listens on port 0 for each enabled protocol, prints one
    ``scheme://host:port`` line per listener once the reactor is running
    (HTTP before HTTPS), and then runs the reactor.

    When HTTPS is enabled, ``--keyfile``, ``--certfile`` and
    ``--cipher-string`` are read from the command line; those given a
    non-empty value are forwarded to ``ssl_context_factory`` as keyword
    arguments.

    Raises:
        ValueError: if both ``listen_http`` and ``listen_https`` are false.
    """
    if not (listen_http or listen_https):
        raise ValueError("At least one of listen_http and listen_https must be set")

    def main() -> None:
        from twisted.internet import reactor

        site = Site(resource_class())
        # (scheme, listening port) pairs, in the order they are announced.
        listeners = []

        if listen_http:
            listeners.append(("http", reactor.listenTCP(0, site)))

        if listen_https:
            cli = argparse.ArgumentParser()
            cli.add_argument("--keyfile", help="SSL key file")
            cli.add_argument("--certfile", help="SSL certificate file")
            cli.add_argument(
                "--cipher-string",
                default=None,
                help="SSL cipher string (optional)",
            )
            options = cli.parse_args()
            # Forward only the options that were given a non-empty value.
            ssl_kwargs = {
                name: value for name, value in vars(options).items() if value
            }
            tls_context = ssl_context_factory(**ssl_kwargs)
            listeners.append(("https", reactor.listenSSL(0, site, tls_context)))

        def announce():
            for scheme, listener in listeners:
                bound = listener.getHost()
                print(f"{scheme}://{bound.host}:{bound.port}")

        reactor.callWhenRunning(announce)
        reactor.run()

    return main