File: headers.py

package info (click to toggle)
python-scrapy 2.13.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,664 kB
  • sloc: python: 52,028; xml: 199; makefile: 25; sh: 7
file content (130 lines) | stat: -rw-r--r-- 4,252 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, AnyStr, Union, cast

from w3lib.http import headers_dict_to_raw

from scrapy.utils.datatypes import CaseInsensitiveDict, CaselessDict
from scrapy.utils.python import to_unicode

if TYPE_CHECKING:
    from collections.abc import Iterable

    # typing.Self requires Python 3.11
    from typing_extensions import Self


_RawValueT = Union[bytes, str, int]


# Headers isn't fully compatible, typing-wise, with either dict or CaselessDict,
# but it needs refactoring anyway; see also https://github.com/scrapy/scrapy/pull/5146
class Headers(CaselessDict):
    """Case-insensitive HTTP headers dictionary.

    Keys are normalized by :meth:`normkey` (title-cased, then converted to
    ``bytes``). Values are normalized by :meth:`normvalue` into a ``list`` of
    ``bytes``, so one header name can carry several values. ``str`` (and
    ``int``) keys/values are encoded with :attr:`encoding`.
    """

    def __init__(
        self,
        seq: Mapping[AnyStr, Any] | Iterable[tuple[AnyStr, Any]] | None = None,
        encoding: str = "utf-8",
    ):
        """Create a headers dictionary.

        :param seq: initial headers, either a mapping or an iterable of
            ``(name, value)`` pairs; ``None`` for no initial headers.
        :param encoding: codec used to encode ``str``/``int`` keys and values
            to bytes, and to decode them again in :meth:`to_unicode_dict`.
        """
        # Set before the base initializer runs: it presumably populates the
        # dict from ``seq`` via update(), and normalization needs the encoding.
        self.encoding: str = encoding
        super().__init__(seq)

    def update(  # type: ignore[override]
        self, seq: Mapping[AnyStr, Any] | Iterable[tuple[AnyStr, Any]]
    ) -> None:
        """Set headers from ``seq`` (a mapping or ``(name, value)`` pairs).

        Values supplied for names that normalize to the same key *within*
        ``seq`` are concatenated in order. The merged lists are then handed
        to the base class update(), which presumably replaces (does not
        extend) values already stored under those keys; use
        :meth:`appendlist` to add to an existing header instead.
        """
        seq = seq.items() if isinstance(seq, Mapping) else seq
        iseq: dict[bytes, list[bytes]] = {}
        for k, v in seq:
            iseq.setdefault(self.normkey(k), []).extend(self.normvalue(v))
        super().update(iseq)

    def normkey(self, key: AnyStr) -> bytes:  # type: ignore[override]
        """Normalize key to bytes.

        The key is title-cased first (e.g. ``content-type`` ->
        ``Content-Type``), so differently-cased names map to the same key.
        """
        return self._tobytes(key.title())

    def normvalue(self, value: _RawValueT | Iterable[_RawValueT]) -> list[bytes]:
        """Normalize values to bytes.

        - ``None`` becomes an empty list.
        - A single ``str``/``bytes`` becomes a one-element list; it is not
          iterated character by character.
        - Any other iterable contributes one element per item.
        - Anything else (e.g. an ``int``) becomes a one-element list.

        Every element is converted with :meth:`_tobytes`.
        """
        _value: Iterable[_RawValueT]
        if value is None:
            _value = []
        elif isinstance(value, (str, bytes)):
            _value = [value]
        elif hasattr(value, "__iter__"):
            _value = value
        else:
            _value = [value]

        return [self._tobytes(x) for x in _value]

    def _tobytes(self, x: _RawValueT) -> bytes:
        """Convert a single raw key/value to bytes.

        ``bytes`` are returned unchanged. ``str`` is encoded with
        :attr:`encoding`. ``int`` (including ``bool``, an ``int`` subclass)
        is rendered with ``str()`` and then encoded.

        :raises TypeError: for any other type.
        """
        if isinstance(x, bytes):
            return x
        if isinstance(x, str):
            return x.encode(self.encoding)
        if isinstance(x, int):
            return str(x).encode(self.encoding)
        raise TypeError(f"Unsupported value type: {type(x)}")

    def __getitem__(self, key: AnyStr) -> bytes | None:
        """Return the last value stored for ``key``.

        Returns ``None`` when the stored list is empty. Only ``IndexError``
        is caught here, so a missing key presumably surfaces as the base
        class's ``KeyError``.
        """
        try:
            return cast(list[bytes], super().__getitem__(key))[-1]
        except IndexError:
            return None

    def get(self, key: AnyStr, def_val: Any = None) -> bytes | None:
        """Return the last value for ``key``, or the last item of ``def_val``.

        Returns ``None`` when the resulting list is empty.
        """
        # NOTE(review): indexing with [-1] assumes the base get() passes
        # def_val through normvalue() (so a None default becomes []) - confirm.
        try:
            return cast(list[bytes], super().get(key, def_val))[-1]
        except IndexError:
            return None

    def getlist(self, key: AnyStr, def_val: Any = None) -> list[bytes]:
        """Return every value stored for ``key``.

        For a missing key, returns ``normvalue(def_val)`` when ``def_val`` is
        not ``None``, otherwise an empty list. For a present key, no copy is
        made here, so the result is presumably the stored list object itself.
        """
        try:
            return cast(list[bytes], super().__getitem__(key))
        except KeyError:
            if def_val is not None:
                return self.normvalue(def_val)
            return []

    def setlist(self, key: AnyStr, list_: Iterable[_RawValueT]) -> None:
        """Replace all values of ``key`` with ``list_``."""
        self[key] = list_

    def setlistdefault(
        self, key: AnyStr, default_list: Iterable[_RawValueT] = ()
    ) -> Any:
        """Set ``key`` to ``default_list`` if absent; return its value(s)."""
        return self.setdefault(key, default_list)

    def appendlist(self, key: AnyStr, value: Iterable[_RawValueT]) -> None:
        """Append ``value`` (normalized) to the values already under ``key``."""
        lst = self.getlist(key)
        lst.extend(self.normvalue(value))
        # Re-assign so the key is created if it did not exist yet
        # (getlist() returned a fresh list in that case).
        self[key] = lst

    def items(self) -> Iterable[tuple[bytes, list[bytes]]]:  # type: ignore[override]
        """Return a lazy generator of ``(key, list_of_values)`` pairs."""
        return ((k, self.getlist(k)) for k in self.keys())

    def values(self) -> list[bytes | None]:  # type: ignore[override]
        """Return the last value of each header (``None`` for empty lists)."""
        return [
            self[k]
            for k in self.keys()  # pylint: disable=consider-using-dict-items
        ]

    def to_string(self) -> bytes:
        """Serialize the headers to raw bytes via w3lib's headers_dict_to_raw()."""
        return headers_dict_to_raw(self)

    def to_unicode_dict(self) -> CaseInsensitiveDict:
        """Return headers as a CaseInsensitiveDict with str keys
        and str values. Multiple values are joined with ','.
        """
        return CaseInsensitiveDict(
            (
                to_unicode(key, encoding=self.encoding),
                to_unicode(b",".join(value), encoding=self.encoding),
            )
            for key, value in self.items()
        )

    def __copy__(self) -> Self:
        """Return a copy of these headers that keeps the same encoding.

        The value lists are rebuilt by update() during construction, so
        mutating the copy's lists does not affect this instance.
        """
        new = self.__class__(self)
        # The constructor starts from its default encoding. Carry ours over so
        # the copy encodes later str keys/values, and decodes in
        # to_unicode_dict(), the same way as the original. Setting the
        # attribute (rather than passing encoding=) keeps subclasses with a
        # narrower __init__ signature working.
        new.encoding = self.encoding
        return new

    copy = __copy__