File: client_utils.py

package info (click to toggle)
python-elastic-transport 8.17.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 624 kB
  • sloc: python: 6,549; makefile: 31
file content (249 lines) | stat: -rw-r--r-- 8,227 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

import base64
import binascii
import dataclasses
import re
import urllib.parse
from platform import python_version
from typing import Optional, Tuple, TypeVar, Union
from urllib.parse import quote as _quote

from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url

from ._models import DEFAULT, DefaultType, NodeConfig
from ._utils import fixup_module_metadata
from ._version import __version__

# Names exported by a star-import of this module. Besides the helpers defined
# in this file, the list re-exports ``DEFAULT`` and ``DefaultType`` (imported
# from ``._models``) and the ``dataclasses`` module itself.
# NOTE(review): exporting the stdlib ``dataclasses`` module looks unusual --
# presumably kept for backwards compatibility; confirm before removing.
__all__ = [
    "CloudId",
    "DEFAULT",
    "DefaultType",
    "basic_auth_to_header",
    "client_meta_version",
    "create_user_agent",
    "dataclasses",
    "parse_cloud_id",
    "percent_encode",
    "resolve_default",
    "to_bytes",
    "to_str",
    "url_to_node_config",
]

# Generic type variable used by ``resolve_default`` to tie the type of the
# returned value to the type of its ``val``/``default`` arguments.
T = TypeVar("T")


def resolve_default(val: Union[DefaultType, T], default: T) -> T:
    """Return ``default`` when ``val`` is the ``DEFAULT`` sentinel,
    otherwise return ``val`` unchanged.

    :param val: Value that may be the ``DEFAULT`` sentinel.
    :param default: Fallback used only when ``val`` is ``DEFAULT``.
    """
    # Identity comparison: only the sentinel object itself triggers the
    # fallback, falsy values such as ``None``, ``0`` or ``""`` are kept.
    if val is DEFAULT:
        return default
    return val


def create_user_agent(name: str, version: str) -> str:
    """Build the value of the 'User-Agent' header for a client library.

    :param name: Name of the library issuing the requests.
    :param version: Version of that library.
    :returns: ``<name>/<version> (Python/<python version>; elastic-transport/<transport version>)``
    """
    product = f"{name}/{version}"
    runtime = f"Python/{python_version()}; elastic-transport/{__version__}"
    return f"{product} ({runtime})"


def client_meta_version(version: str) -> str:
    """Converts a Python version into a version string
    compatible with the ``X-Elastic-Client-Meta`` HTTP header.

    The leading dotted-number release segment is kept as-is. Any trailing
    suffix other than a ``.postN`` marker is collapsed into a single ``p``
    appended to the release segment.

    :param version: Version string that starts with a digit, e.g.
        ``"8.17.1"``, ``"8.0.0a1"`` or ``"8.0.0.post1"``.
    :returns: The release segment, with ``p`` appended when a non-post
        suffix was present.
    :raises ValueError: If ``version`` does not start with a numeric
        release segment.
    """
    match = re.match(r"^([0-9][0-9.]*[0-9]|[0-9])(.*)$", version)
    if match is None:
        # The f-prefix is required so the offending value is interpolated
        # into the message instead of the literal text '{version!r}'.
        raise ValueError(
            f"Version {version!r} not formatted like a Python version string"
        )
    release, suffix = match.groups()

    # Don't treat post-releases as pre-releases.
    if re.search(r"^\.post[0-9]*$", suffix):
        return release
    # Any other suffix (alpha, beta, dev, local, ...) is flagged with 'p'.
    if suffix:
        release += "p"
    return release


@dataclasses.dataclass(frozen=True)
class CloudId:
    """Immutable record holding the components of an Elastic Cloud ID."""

    #: Cluster name in Elastic Cloud
    cluster_name: str
    #: ``(host, port)`` of the Elasticsearch instance, or ``None``
    es_address: Optional[Tuple[str, int]]
    #: ``(host, port)`` of the Kibana instance, or ``None``
    kibana_address: Optional[Tuple[str, int]]


def parse_cloud_id(cloud_id: str) -> CloudId:
    """Split an Elastic Cloud ID into its cluster name and service addresses.

    The text before the first ``:`` is the cluster name. The remainder is
    base64 data that decodes to ``<parent domain>[:<port>]$<es id>$<kibana id>``
    where the two trailing ``$``-separated segments are optional. The port
    falls back to 443 when the parent domain carries none.

    :param cloud_id: Cloud ID to parse.
    :returns: A :class:`CloudId` whose address fields are ``None`` when the
        matching segment is missing or empty.
    :raises ValueError: If the Cloud ID cannot be decoded, has an empty
        parent domain, or has a non-integer port.
    """
    try:
        cluster_name, _, encoded = to_str(cloud_id).partition(":")
        decoded = to_str(binascii.a2b_base64(to_bytes(encoded, "ascii")), "ascii")
        parts = decoded.split("$")

        parent_dn = parts[0]
        if not parent_dn:
            raise ValueError()  # Converted to the generic error below

        # Segments past the parent domain are optional.
        es_uuid: Optional[str] = parts[1] if len(parts) > 1 else None
        kibana_uuid: Optional[str] = (parts[2] or None) if len(parts) > 2 else None

        port = 443
        if ":" in parent_dn:
            parent_dn, _, parent_port = parent_dn.rpartition(":")
            port = int(parent_port)
    except (ValueError, IndexError, UnicodeError):
        raise ValueError("Cloud ID is not properly formatted") from None

    # An empty identifier yields no address for that service.
    es_address = (f"{es_uuid}.{parent_dn}", port) if es_uuid else None
    kibana_address = (f"{kibana_uuid}.{parent_dn}", port) if kibana_uuid else None

    return CloudId(
        cluster_name=cluster_name,
        es_address=es_address,
        kibana_address=kibana_address,
    )


def to_str(
    value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict"
) -> str:
    """Return ``value`` as text, decoding it first when it is ``bytes``.

    :param value: Text or bytes to convert.
    :param encoding: Codec used to decode ``bytes`` input.
    :param errors: Error handling scheme passed to ``bytes.decode``.
    """
    return value.decode(encoding, errors) if isinstance(value, bytes) else value


def to_bytes(
    value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict"
) -> bytes:
    """Return ``value`` as bytes, encoding it first when it is ``str``.

    :param value: Text or bytes to convert.
    :param encoding: Codec used to encode ``str`` input.
    :param errors: Error handling scheme passed to ``str.encode``.
    """
    return value.encode(encoding, errors) if isinstance(value, str) else value


# Characters that are never percent-encoded. '~' is listed explicitly because
# urllib.parse.quote() only started treating it as safe in Python 3.7.
_QUOTE_ALWAYS_SAFE = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~"
)


def percent_encode(
    string: Union[bytes, str],
    safe: str = "/",
    encoding: Optional[str] = None,
    errors: Optional[str] = None,
) -> str:
    """Percent-encode ``string`` for use in an HTTP request target.

    Wraps ``urllib.parse.quote()`` with a safe list that always includes
    ``_QUOTE_ALWAYS_SAFE`` (and therefore '~') in addition to ``safe``.

    :param string: Text or bytes to encode.
    :param safe: Extra characters that must not be encoded.
    :param encoding: Forwarded to ``urllib.parse.quote()``.
    :param errors: Forwarded to ``urllib.parse.quote()``.
    """
    never_encoded = _QUOTE_ALWAYS_SAFE | set(safe)
    return _quote(string, "".join(never_encoded), encoding=encoding, errors=errors)  # type: ignore[arg-type]


def basic_auth_to_header(basic_auth: Tuple[str, str]) -> str:
    """Build a 'Basic' HTTP Authorization header value from credentials.

    :param basic_auth: ``(username, password)`` tuple whose items are
        ``str`` or ``bytes``.
    :returns: ``"Basic "`` followed by the base64 of ``username:password``.
    :raises ValueError: If ``basic_auth`` is not a 2-tuple of str/bytes.
    """
    # Check the container shape first so the items are only inspected
    # when there is a tuple of exactly two of them.
    is_pair = isinstance(basic_auth, tuple) and len(basic_auth) == 2
    if not is_pair or not all(
        isinstance(item, (str, bytes)) for item in basic_auth
    ):
        raise ValueError(
            "'basic_auth' must be a 2-tuple of str/bytes (username, password)"
        )

    username, password = basic_auth
    credentials = to_bytes(username) + b":" + to_bytes(password)
    token = base64.b64encode(credentials).decode()
    return f"Basic {token}"


def url_to_node_config(
    url: str, use_default_ports_for_scheme: bool = False
) -> NodeConfig:
    """Build a :class:`elastic_transport.NodeConfig` from a URL.

    Credentials embedded in the URL are percent-decoded and turned into an
    'Authorization' header. A missing port is filled in with 443 for HTTPS
    always, and with 80 for HTTP only on request.

    :param url: URL to transform into a NodeConfig.
    :param use_default_ports_for_scheme: If 'True' will resolve default ports for HTTP.
    :raises ValueError: If the URL cannot be parsed or lacks a scheme,
        host, or port once defaults have been applied.
    """
    try:
        parsed = parse_url(url)
    except LocationParseError:
        raise ValueError(f"Could not parse URL {url!r}") from None

    scheme = parsed.scheme
    raw_host = parsed.host
    port: Optional[int] = parsed.port

    if port is None:
        if scheme == "https":
            # HTTPS always gets its default port.
            port = 443
        elif scheme == "http" and use_default_ports_for_scheme:
            # HTTP gets its default port only when explicitly asked for.
            port = 80

    for component in (scheme, raw_host, port):
        if component in (None, ""):
            raise ValueError(
                "URL must include a 'scheme', 'host', and 'port' component (ie 'https://localhost:9200')"
            )
    # Narrow the Optional types for the type checker; the loop above
    # already rejected missing components.
    assert scheme is not None
    assert raw_host is not None
    assert port is not None

    headers = {}
    userinfo = parsed.auth
    if userinfo:
        # `urllib3.util.url_parse` ensures the URL is correctly
        # percent-encoded but leaves the userinfo encoded, so it is decoded
        # here before building the basic auth header.
        raw_user, _, raw_password = userinfo.partition(":")
        credentials = (
            urllib.parse.unquote(raw_user),
            urllib.parse.unquote(raw_password),
        )
        headers["authorization"] = basic_auth_to_header(credentials)

    # A missing path and a bare "/" both mean "no prefix".
    path = parsed.path
    path_prefix = path if path and path != "/" else ""

    return NodeConfig(
        scheme=scheme,
        # Remove any '[' / ']' characters surrounding the host.
        host=raw_host.strip("[]"),
        port=port,
        path_prefix=path_prefix,
        headers=headers,
    )


# NOTE(review): presumably adjusts module metadata (e.g. ``__module__``) on
# the names defined above -- confirm against ``._utils.fixup_module_metadata``.
fixup_module_metadata(__name__, globals())
# Drop the helper so it is not left behind as an attribute of this module.
del fixup_module_metadata