File: httpcompression.py (from python-scrapy 2.13.3-1)

from __future__ import annotations

from itertools import chain
from logging import getLogger
from typing import TYPE_CHECKING, Any

from scrapy import Request, Spider, signals
from scrapy.exceptions import IgnoreRequest, NotConfigured
from scrapy.http import Response, TextResponse
from scrapy.responsetypes import responsetypes
from scrapy.utils._compression import (
    _DecompressionMaxSizeExceeded,
    _inflate,
    _unbrotli,
    _unzstd,
)
from scrapy.utils.gz import gunzip

if TYPE_CHECKING:
    # typing.Self requires Python 3.11
    from typing_extensions import Self

    from scrapy.crawler import Crawler
    from scrapy.statscollectors import StatsCollector


logger = getLogger(__name__)

ACCEPTED_ENCODINGS: list[bytes] = [b"gzip", b"deflate"]
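# "br" and "zstd" are appended below only when the corresponding optional
# decoder packages are importable, so we never advertise an encoding that
# we cannot actually decode.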

try:
    try:
        import brotli  # noqa: F401
    except ImportError:
        import brotlicffi  # noqa: F401
except ImportError:
    pass
else:
    ACCEPTED_ENCODINGS.append(b"br")

try:
    import zstandard  # noqa: F401
except ImportError:
    pass
else:
    ACCEPTED_ENCODINGS.append(b"zstd")


class HttpCompressionMiddleware:
    """This middleware allows compressed (gzip, deflate, and, when the
    optional decoders are installed, brotli and zstd) traffic to be
    sent/received from websites"""

    def __init__(
        self,
        stats: StatsCollector | None = None,
        *,
        crawler: Crawler | None = None,
    ):
        if not crawler:
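            # Standalone use without a crawler: fall back to Scrapy's default
            # DOWNLOAD_MAXSIZE (1 GiB) and DOWNLOAD_WARNSIZE (32 MiB).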
            self.stats = stats
            self._max_size = 1073741824
            self._warn_size = 33554432
            return
        self.stats = crawler.stats
        self._max_size = crawler.settings.getint("DOWNLOAD_MAXSIZE")
        self._warn_size = crawler.settings.getint("DOWNLOAD_WARNSIZE")
        crawler.signals.connect(self.open_spider, signals.spider_opened)

    @classmethod
    def from_crawler(cls, crawler: Crawler) -> Self:
        if not crawler.settings.getbool("COMPRESSION_ENABLED"):
            raise NotConfigured
        return cls(crawler=crawler)

    def open_spider(self, spider: Spider) -> None:
        if hasattr(spider, "download_maxsize"):
            self._max_size = spider.download_maxsize
        if hasattr(spider, "download_warnsize"):
            self._warn_size = spider.download_warnsize

    def process_request(
        self, request: Request, spider: Spider
    ) -> Request | Response | None:
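        # Advertise the encodings we can decode, without overriding a header
        # already set by the user.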
        request.headers.setdefault("Accept-Encoding", b", ".join(ACCEPTED_ENCODINGS))
        return None

    def process_response(
        self, request: Request, response: Response, spider: Spider
    ) -> Request | Response:
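        # HEAD responses carry no body, so there is nothing to decompress.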
        if request.method == "HEAD":
            return response
        if isinstance(response, Response):
            content_encoding = response.headers.getlist("Content-Encoding")
            if content_encoding:
                max_size = request.meta.get("download_maxsize", self._max_size)
                warn_size = request.meta.get("download_warnsize", self._warn_size)
                try:
                    decoded_body, content_encoding = self._handle_encoding(
                        response.body, content_encoding, max_size
                    )
                except _DecompressionMaxSizeExceeded:
                    raise IgnoreRequest(
                        f"Ignored response {response} because its body "
                        f"({len(response.body)} B compressed) exceeded "
                        f"DOWNLOAD_MAXSIZE ({max_size} B) during "
                        f"decompression."
                    )
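                # Warn when decompression pushed the body size past the
                # configured warning threshold.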
                if len(response.body) < warn_size <= len(decoded_body):
                    logger.warning(
                        f"{response} body size after decompression "
                        f"({len(decoded_body)} B) is larger than the "
                        f"download warning size ({warn_size} B)."
                    )
                if content_encoding:
                    self._warn_unknown_encoding(response, content_encoding)
                response.headers["Content-Encoding"] = content_encoding
                if self.stats:
                    self.stats.inc_value(
                        "httpcompression/response_bytes",
                        len(decoded_body),
                        spider=spider,
                    )
                    self.stats.inc_value(
                        "httpcompression/response_count", spider=spider
                    )
                respcls = responsetypes.from_args(
                    headers=response.headers, url=response.url, body=decoded_body
                )
                kwargs: dict[str, Any] = {"body": decoded_body}
                if issubclass(respcls, TextResponse):
                    # force recalculating the encoding until we make sure the
                    # responsetypes guessing is reliable
                    kwargs["encoding"] = None
                response = response.replace(cls=respcls, **kwargs)
                if not content_encoding:
                    del response.headers["Content-Encoding"]

        return response

    def _handle_encoding(
        self, body: bytes, content_encoding: list[bytes], max_size: int
    ) -> tuple[bytes, list[bytes]]:
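        # Strip the trailing encodings we can decode; anything left over must
        # stay in the Content-Encoding header.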
        to_decode, to_keep = self._split_encodings(content_encoding)
        for encoding in to_decode:
            body = self._decode(body, encoding, max_size)
        return body, to_keep

    @staticmethod
    def _split_encodings(
        content_encoding: list[bytes],
    ) -> tuple[list[bytes], list[bytes]]:
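        # Content-Encoding lists codings in the order they were applied, so
        # decoding proceeds from the last one backwards and stops at the first
        # coding we cannot decode; that one and everything applied before it
        # are kept.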
        supported_encodings = {*ACCEPTED_ENCODINGS, b"x-gzip"}
        to_keep: list[bytes] = [
            encoding.strip().lower()
            for encoding in chain.from_iterable(
                encodings.split(b",") for encodings in content_encoding
            )
        ]
        to_decode: list[bytes] = []
        while to_keep:
            encoding = to_keep.pop()
            if encoding not in supported_encodings:
                to_keep.append(encoding)
                return to_decode, to_keep
            to_decode.append(encoding)
        return to_decode, to_keep

    @staticmethod
    def _decode(body: bytes, encoding: bytes, max_size: int) -> bytes:
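        # Dispatch on the coding; every decoder enforces max_size to guard
        # against decompression bombs.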
        if encoding in {b"gzip", b"x-gzip"}:
            return gunzip(body, max_size=max_size)
        if encoding == b"deflate":
            return _inflate(body, max_size=max_size)
        if encoding == b"br":
            return _unbrotli(body, max_size=max_size)
        if encoding == b"zstd":
            return _unzstd(body, max_size=max_size)
        # shouldn't be reached
        return body  # pragma: no cover

    def _warn_unknown_encoding(
        self, response: Response, encodings: list[bytes]
    ) -> None:
        encodings_str = b",".join(encodings).decode()
        msg = (
            f"{self.__class__.__name__} cannot decode the response for {response.url} "
            f"from unsupported encoding(s) '{encodings_str}'."
        )
        if b"br" in encodings:
            msg += " You need to install brotli or brotlicffi to decode 'br'."
        if b"zstd" in encodings:
            msg += " You need to install zstandard to decode 'zstd'."
        logger.warning(msg)
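
# A minimal usage sketch (illustrative, not part of the upstream module):
# this middleware is configured entirely through settings and request meta,
# e.g. in a project's settings.py:
#
#     COMPRESSION_ENABLED = True      # required for the middleware to load
#     DOWNLOAD_MAXSIZE = 1073741824   # abort decompression past 1 GiB
#     DOWNLOAD_WARNSIZE = 33554432    # log a warning past 32 MiB
#
# Both size limits can also be overridden per request through the
# "download_maxsize" and "download_warnsize" meta keys, as read in
# process_response() above.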