File: compat.py

package info (click to toggle)
python-elasticsearch 9.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 23,568 kB
  • sloc: python: 108,748; makefile: 151; javascript: 97
file content (113 lines) | stat: -rw-r--r-- 3,615 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

import inspect
import os
import sys
from contextlib import contextmanager
from pathlib import Path
from threading import Thread
from typing import Any, Callable, Iterator, Tuple, Type, Union

# Tuple of the text and binary string types, ``(str, bytes)``.
# NOTE(review): presumably meant for ``isinstance(value, string_types)``
# checks by callers -- confirm against the call sites.
string_types: Tuple[Type[str], Type[bytes]] = (str, bytes)

# Name of the environment variable that, when set to "1", "true" or "True",
# makes warn_stacklevel() return 0 without inspecting the call stack.
DISABLE_WARN_STACKLEVEL_ENV_VAR = "DISABLE_WARN_STACKLEVEL"


def to_str(x: Union[str, bytes], encoding: str = "ascii") -> str:
    """Return ``x`` as text.

    A ``str`` is returned unchanged; any other value is decoded by calling
    its ``decode`` method with ``encoding`` (``"ascii"`` by default).
    """
    return x if isinstance(x, str) else x.decode(encoding)


def to_bytes(x: Union[str, bytes], encoding: str = "ascii") -> bytes:
    """Return ``x`` as bytes.

    A ``bytes`` object is returned unchanged; any other value is encoded by
    calling its ``encode`` method with ``encoding`` (``"ascii"`` by default).
    """
    return x if isinstance(x, bytes) else x.encode(encoding)


def warn_stacklevel() -> int:
    """Dynamically determine warning stacklevel for warnings based on the call stack.

    The call stack is walked outward, starting at this function's own frame
    (index 0), and the index of the first frame whose source file does not
    belong to this module's root package is returned. The result is intended
    to be passed as the ``stacklevel`` argument when emitting a warning.

    This is a best-effort helper and never raises for lookup problems.

    Returns:
        The index of the first stack frame located outside the root package.
        ``0`` is returned instead when:

        * the environment variable named by ``DISABLE_WARN_STACKLEVEL_ENV_VAR``
          is set to ``"1"``, ``"true"`` or ``"True"``;
        * the root module is not in ``sys.modules`` or has no usable
          ``__file__`` (missing or ``None``);
        * every frame on the stack belongs to the package.
    """
    if os.environ.get(DISABLE_WARN_STACKLEVEL_ENV_VAR) in ("1", "true", "True"):
        return 0
    try:
        # Grab the root module from the current module '__name__'
        module_name = __name__.partition(".")[0]
        module_path = Path(sys.modules[module_name].__file__)  # type: ignore[arg-type]

        # If the module is a folder we're looking at
        # subdirectories, otherwise we're looking for
        # an exact match.
        module_is_folder = module_path.name == "__init__.py"
        if module_is_folder:
            module_path = module_path.parent

        # Look through frames until we find a file that
        # isn't a part of our module, then return that stacklevel.
        # Only the filename of each frame is needed, so skip loading
        # source context lines (context=0).
        for level, frame in enumerate(inspect.stack(context=0)):
            frame_filename = Path(frame.filename)
            # Drop the frame reference right away so it can be
            # garbage collected.
            del frame

            if module_is_folder:
                # Package: any file below the package directory is ours.
                inside_package = module_path in frame_filename.parents
            else:
                # Single-file module: only that exact file is ours.
                inside_package = module_path == frame_filename

            if not inside_package:
                return level
    except (KeyError, AttributeError, TypeError):
        # KeyError: the root module is not present in sys.modules.
        # AttributeError / TypeError: the module has no '__file__', or it is
        # None, so no filesystem location can be derived from it.
        # In every case fall back to the default below rather than
        # crashing the code path that only wanted to emit a warning.
        pass
    return 0


@contextmanager
def safe_thread(
    target: Callable[..., Any], *args: Any, **kwargs: Any
) -> Iterator[Thread]:
    """Run a thread within a context manager block.

    ``target(*args, **kwargs)`` is started in a new thread before the block
    is entered, and the ``Thread`` object is yielded to the block. When the
    block finishes normally the thread is joined; if the thread raised an
    exception (any ``BaseException``), it is then re-raised in the caller's
    context.

    If the block itself raises, that exception propagates to the caller
    immediately: the thread is not joined and an exception raised inside the
    thread is not reported.

    Args:
        target: Callable to run in the thread.
        *args: Positional arguments passed to ``target``.
        **kwargs: Keyword arguments passed to ``target``.

    Yields:
        The started ``Thread``.

    Raises:
        BaseException: Whatever ``target`` raised, re-raised after the join.
    """
    captured_exception = None

    def run() -> None:
        nonlocal captured_exception
        try:
            target(*args, **kwargs)
        except BaseException as exc:
            # Hold on to the exception so it can be re-raised in the
            # thread that owns the ``with`` block.
            captured_exception = exc

    thread = Thread(target=run)
    thread.start()
    yield thread
    thread.join()
    # Compare against None instead of relying on truthiness: an exception
    # class may define __bool__/__len__ and evaluate as falsy, which would
    # otherwise cause the failure to be silently dropped.
    if captured_exception is not None:
        raise captured_exception


# Names exported by ``from <this module> import *``.
# DISABLE_WARN_STACKLEVEL_ENV_VAR is not listed, so a star-import does not
# bring it in.
__all__ = [
    "string_types",
    "to_str",
    "to_bytes",
    "warn_stacklevel",
    "safe_thread",
]