File: __init__.py

package info (click to toggle)
python-thinc 8.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 5,804 kB
  • sloc: python: 15,818; javascript: 1,554; ansic: 342; makefile: 20; sh: 13
file content (180 lines) | stat: -rw-r--r-- 5,404 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import contextlib
from typing import Type, Dict, Any, Callable, Optional, cast

from contextvars import ContextVar
import threading

from .ops import Ops
from .cupy_ops import CupyOps
from .numpy_ops import NumpyOps
from .mps_ops import MPSOps
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import get_torch_default_device, is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy


# Backend selected for the current context. Defaults to None, meaning no
# backend has been set yet in this context.
context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None)
# Memory pools keyed by allocator name ("pytorch" / "tensorflow").
# NOTE: the default dict is a single object created at import time, so every
# context that has not set its own value reads and mutates the same dict.
context_pools: ContextVar[dict] = ContextVar("context_pools", default={})

# Internal use of thread-local storage only for detecting cases where a Jupyter
# notebook might not have preserved contextvars across cells.
# This dict is the template of initial attribute values copied onto each
# thread's state object the first time `_get_thread_state` runs in that thread.
_GLOBAL_STATE = {"ops": None}


def set_gpu_allocator(allocator: str) -> None:  # pragma: no cover
    """Select the library through which GPU memory allocation is routed.

    `allocator` must be either "pytorch" or "tensorflow"; any other value
    raises a ValueError naming the offending argument.
    """
    # Reject unknown names up front so the dispatch below only sees valid ones.
    if allocator not in ("pytorch", "tensorflow"):
        raise ValueError(
            f"Invalid 'gpu_allocator' argument: '{allocator}'. Available allocators are: 'pytorch', 'tensorflow'"
        )

    if allocator == "pytorch":
        use_pytorch_for_gpu_memory()
    else:
        use_tensorflow_for_gpu_memory()


def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
    """Make cupy allocate its GPU memory through PyTorch.

    Recommended when PyTorch and cupy are used side by side: without it, an
    allocation can fail with OOM while free memory sits unused in the other
    library's pool.

    Routing Tensorflow memory allocation via PyTorch (or the reverse) would be
    desirable too, but there is currently no implementation for it.

    Does nothing when the default torch device is not of type "cuda".
    """
    assert_pytorch_installed()

    if get_torch_default_device().type == "cuda":
        pools = context_pools.get()
        try:
            pool = pools["pytorch"]
        except KeyError:
            # First use in this context: create the pool once and cache it.
            pool = pools["pytorch"] = cupy.cuda.MemoryPool(
                allocator=cupy_pytorch_allocator
            )
        cupy.cuda.set_allocator(pool.malloc)


def use_tensorflow_for_gpu_memory() -> None:  # pragma: no cover
    """Make cupy allocate its GPU memory through TensorFlow.

    Recommended when TensorFlow and cupy are used side by side: without it, an
    allocation can fail with OOM while free memory sits unused in the other
    library's pool.

    Routing PyTorch memory allocation via Tensorflow (or the reverse) would be
    desirable too, but there is currently no implementation for it.
    """
    assert_tensorflow_installed()

    pools = context_pools.get()
    try:
        pool = pools["tensorflow"]
    except KeyError:
        # First use in this context: create the pool once and cache it.
        pool = pools["tensorflow"] = cupy.cuda.MemoryPool(
            allocator=cupy_tensorflow_allocator
        )
    cupy.cuda.set_allocator(pool.malloc)


def _import_extra_cpu_backends():
    """Attempt to import the optional CPU backend packages.

    Each import is tried independently; a package that is not installed is
    silently skipped. The imported names are not used here.
    """
    with contextlib.suppress(ImportError):
        from thinc_apple_ops import AppleOps  # noqa: F401
    with contextlib.suppress(ImportError):
        from thinc_bigendian_ops import BigEndianOps  # noqa: F401


def get_ops(name: str, **kwargs) -> Ops:
    """Get a backend object.

    The special name "cpu" returns the best available CPU backend: "numpy" is
    the baseline, overridden by "apple" and then by "bigendian" when those
    backends are registered.

    name (str): The `name` attribute of a registered ops class, or "cpu".
    **kwargs: Passed through unchanged to the backend's constructor.
    RETURNS (Ops): A newly constructed backend instance.
    RAISES (ValueError): If no registered backend matches `name`.
    """
    if name == "cpu":
        # Import the optional backends *before* reading the registry. The
        # imports are only useful for their side effects (presumably the
        # packages register their ops classes on import), so a snapshot taken
        # earlier could miss a backend that was registered by this very call.
        _import_extra_cpu_backends()

    ops_by_name = {ops_cls.name: ops_cls for ops_cls in registry.ops.get_all().values()}  # type: ignore

    cls: Optional[Callable[..., Ops]] = None
    if name == "cpu":
        # Later candidates take precedence over earlier ones when present.
        for candidate in ("numpy", "apple", "bigendian"):
            cls = ops_by_name.get(candidate, cls)
    else:
        cls = ops_by_name.get(name)

    if cls is None:
        raise ValueError(f"Invalid backend: {name}")

    return cls(**kwargs)


def get_array_ops(arr):
    """Build a fresh ops object matching the array: CupyOps when `arr` is a
    cupy array, NumpyOps for anything else."""
    ops_cls = CupyOps if is_cupy_array(arr) else NumpyOps
    return ops_cls()


@contextlib.contextmanager
def use_ops(name: str, **kwargs):
    """Temporarily switch the current backend for the body of a `with` block.

    The backend is looked up via `get_ops(name, **kwargs)`. The backend that
    was current on entry is reinstated on exit, even if the body raises.
    """
    previous_ops = get_current_ops()
    replacement_ops = get_ops(name, **kwargs)
    set_current_ops(replacement_ops)
    try:
        yield
    finally:
        set_current_ops(previous_ops)


def get_current_ops() -> Ops:
    """Return the backend object of the current context.

    If none has been set yet, `require_cpu()` is called first and the context
    variable is read again.
    """
    ops = context_ops.get()
    if ops is None:
        require_cpu()
        # Re-read the context variable after require_cpu() has run.
        ops = context_ops.get()
    return cast(Ops, ops)


def set_current_ops(ops: Ops) -> None:
    """Make `ops` the current backend object.

    The value is stored in the context variable and mirrored onto the
    per-thread state object.
    """
    context_ops.set(ops)
    thread_state = _get_thread_state()
    thread_state.ops = ops


def contextvars_eq_thread_ops() -> bool:
    """Report whether the ops held by the context variable and the ops held by
    the per-thread state are of the same type."""
    context_type = type(context_ops.get())
    thread_type = type(_get_thread_state().ops)
    return context_type == thread_type


def _get_thread_state():
    """Return the state object attached to the current thread.

    The object is created on first use from `_GLOBAL_STATE`, stored on the
    thread object under the attribute `__local`, and reused afterwards.
    """
    current: threading.Thread = threading.current_thread()
    try:
        return current.__local
    except AttributeError:
        # First call in this thread: build the state and attach it.
        state = _create_thread_local(_GLOBAL_STATE)
        current.__local = state
        return state


def _create_thread_local(
    attrs: Dict[str, Any], local_class: Type[threading.local] = threading.local
):
    obj = local_class()
    for name, value in attrs.items():
        setattr(obj, name, value)
    return obj


# Names exported by `from <this package> import *`. Other public functions
# defined in this module (e.g. `get_ops`, `get_array_ops`, `set_gpu_allocator`)
# are not listed and must be imported by name.
__all__ = [
    "set_current_ops",
    "get_current_ops",
    "use_ops",
    "ParamServer",
    "Ops",
    "CupyOps",
    "MPSOps",
    "NumpyOps",
    "has_cupy",
]