File: sharing.py

Package: python-opt-einsum 3.4.0-2
"""A module for sharing intermediates between contractions.

Copyright (c) 2018 Uber Technologies
"""

import contextlib
import functools
import numbers
import threading
from collections import Counter, defaultdict
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Counter as CounterType

from opt_einsum.parser import alpha_canonicalize, parse_einsum_input
from opt_einsum.typing import ArrayType

CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]]
CacheType = Dict[CacheKeyType, ArrayType]
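# Illustrative key layouts (not exhaustive): a memoized transpose is keyed as
# ("transpose", "numpy", id(a), (1, 0)) and a saved operand as ("tensor", id(a));
# the cached value is the corresponding array.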

__all__ = [
    "currently_sharing",
    "get_sharing_cache",
    "shared_intermediates",
    "count_cached_ops",
    "transpose_cache_wrap",
    "einsum_cache_wrap",
    "to_backend_cache_wrap",
]

_SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list)


def currently_sharing() -> bool:
    """Check if we are currently sharing a cache -- thread specific."""
    return threading.get_ident() in _SHARING_STACK


def get_sharing_cache() -> CacheType:
    """Return the most recent sharing cache -- thread specific."""
    return _SHARING_STACK[threading.get_ident()][-1]


def _add_sharing_cache(cache: CacheType) -> None:
    _SHARING_STACK[threading.get_ident()].append(cache)


def _remove_sharing_cache() -> None:
    tid = threading.get_ident()
    _SHARING_STACK[tid].pop()
    if not _SHARING_STACK[tid]:
        del _SHARING_STACK[tid]


@contextlib.contextmanager
def shared_intermediates(
    cache: Optional[CacheType] = None,
) -> Generator[CacheType, None, None]:
    """Context in which contract intermediate results are shared.

    Note that intermediate computations will not be garbage collected until
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    **Parameters:**

    - **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts.

    **Returns:**

    - **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored,
        sharing results will be garbage collected when this context is
        exited. This dict can be passed to another context to resume
        sharing.
    """
    if cache is None:
        cache = {}
    _add_sharing_cache(cache)
    try:
        yield cache
    finally:
        _remove_sharing_cache()
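
# Illustrative usage (a sketch, not part of this module; assumes numpy arrays
# and the top-level ``opt_einsum.contract`` function):
#
#     import numpy as np
#     from opt_einsum import contract, shared_intermediates
#
#     x, y, z = (np.random.rand(8, 8) for _ in range(3))
#     with shared_intermediates() as cache:
#         contract("ab,bc,cd->ad", x, y, z)
#         contract("ab,bc,cd->ad", x, y, z)  # re-uses the cached intermediates
#
#     # the captured dict can be handed to a later context to resume sharing
#     with shared_intermediates(cache):
#         contract("ab,bc,cd->ad", x, y, z)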


def count_cached_ops(cache: CacheType) -> CounterType[str]:
    """Returns a counter of the types of each op in the cache.
    This is useful for profiling to increase sharing.
    """
    return Counter(key[0] for key in cache.keys())
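
# For example (a sketch), after the contractions shown above this might return
# something like Counter({'tensor': 4, 'tensordot': 2}): one entry per saved
# operand/intermediate plus one per memoized backend call; the exact counts
# depend on the chosen contraction path.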


def _save_tensors(*tensors: ArrayType) -> None:
    """Save tensors in the cache to prevent their ids from being recycled.
    This is needed to prevent false cache lookups.
    """
    cache = get_sharing_cache()
    for tensor in tensors:
        cache["tensor", id(tensor)] = tensor


def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType:
    """Memoize ``fn(*args, **kwargs)`` using the given ``key``.
    Results will be stored in the innermost ``cache`` yielded by
    :func:`shared_intermediates`.
    """
    cache = get_sharing_cache()
    if key in cache:
        return cache[key]
    result = fn(*args, **kwargs)
    cache[key] = result
    return result


def transpose_cache_wrap(transpose: Any) -> Any:
    """Decorates a ``transpose()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend="numpy"):
        if not currently_sharing():
            return transpose(a, axes, backend=backend)

        # hash by axes
        _save_tensors(a)
        axes = tuple(axes)
        key = "transpose", backend, id(a), axes
        return _memoize(key, transpose, a, axes, backend=backend)

    return cached_transpose


def tensordot_cache_wrap(tensordot: Any) -> Any:
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend="numpy"):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # hash based on the (axes_x,axes_y) form of axes
        _save_tensors(x, y)
        if isinstance(axes, numbers.Number):
            axes = (
                list(range(len(x.shape)))[len(x.shape) - axes :],
                list(range(len(y.shape)))[:axes],
            )
        axes = tuple(axes[0]), tuple(axes[1])
        key = "tensordot", backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
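
# Worked example of the axes normalization above (illustrative): for 3-d ``x``
# and ``y``, ``axes=2`` becomes ((1, 2), (0, 1)), so the integer and explicit
# tuple spellings of the same contraction map to the same cache key.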


def einsum_cache_wrap(einsum: Any) -> Any:
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop("backend", "numpy")
        equation = args[0]
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(",")

        _save_tensors(*operands)

        # Build canonical key
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ",".join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = "einsum", backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
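
# Illustrative consequence of the canonicalization above: under sharing,
# einsum("ij,jk->ik", x, y) and einsum("jk,ij->ik", y, x) sort their inputs by
# operand id and alpha-canonicalize to the same equation ("ab,bc->ac"), so both
# calls hit the same cache entry.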


def to_backend_cache_wrap(to_backend: Any = None, constants: Any = False) -> Any:
    """Decorates an ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
    """
    # handle the case where the decorator itself is called with arguments
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend
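
# Illustrative decorator usage (a sketch; ``mybackend`` is a hypothetical
# stand-in for a real conversion helper such as ``to_cupy`` or ``to_torch``):
#
#     @to_backend_cache_wrap
#     def to_mybackend(array):
#         return mybackend.asarray(array)
#
#     @to_backend_cache_wrap(constants=True)
#     def to_mybackend_constants(array, constant=False):
#         return mybackend.asarray(array)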