File: _fake_tensor_utils.py

package info (click to toggle)
pytorch 2.6.0+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 161,672 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (258 lines) | stat: -rw-r--r-- 8,512 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch
from torch import SymInt
from torch.fx.experimental.sym_node import SymNode
from torch.types import py_sym_types, PySymType
from torch.utils._backport_slots import dataclass_slots


if TYPE_CHECKING:
    import sympy

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    from .fake_tensor import _DispatchCacheKey, _MetadataIntLike


@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymNode:
    """
    Represents a SymNode without the associated ShapeEnv
    """

    # n.b. keep the same protocol as SymNode
    _expr: sympy.Expr
    pytype: type
    _hint: Optional[Union[int, float, bool]]
    constant: Optional[Union[int, float, bool]]
    fx_node: torch.fx.Node

    @staticmethod
    def from_node(node: SymNode) -> _DeconstructedSymNode:
        return _DeconstructedSymNode(
            node._expr, node.pytype, node._hint, node.constant, node.fx_node
        )

    def extract(self, shape_env: ShapeEnv) -> SymNode:
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def __str__(self) -> str:
        return str(self._expr)

    def __repr__(self) -> str:
        return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"

    def __eq__(self, other: object) -> bool:
        raise NotImplementedError

    def __hash__(self) -> int:
        raise NotImplementedError

    # _value_eq to match SymNode
    def _value_eq(self, other: object) -> bool:
        if isinstance(other, (SymNode, _DeconstructedSymNode)):
            return (
                self._expr == other._expr
                and self.pytype == other.pytype
                and self._hint == other._hint
                and self.constant == other.constant
                and self.fx_node == other.fx_node
            )
        else:
            return False

    # _value_hash to match SymNode
    def _value_hash(self) -> int:
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))


@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymType:
    """
    Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
    """

    ty: Type[PySymType]
    node: _DeconstructedSymNode

    @staticmethod
    def from_sym_type(value: PySymType) -> _DeconstructedSymType:
        return _DeconstructedSymType(type(value), value.node)

    def extract(self, shape_env: ShapeEnv) -> PySymType:
        return self.ty(self.node.extract(shape_env))

    def __str__(self) -> str:
        return f"{self.ty}({self.node})"

    def __repr__(self) -> str:
        return f"_DeconstructedSymType({self.ty}, {self.node!r})"

    def __eq__(self, other: object) -> bool:
        return NotImplemented

    def __hash__(self) -> int:
        return NotImplemented


@dataclass_slots
@dataclass(frozen=True)
class _InputBackref:
    value: int


@dataclass_slots
@dataclass
class _PySymInputStub:
    """
    Represents a SymInt in the cached key. Needed because SymInt doesn't
    support __eq__ or __hash__ directly.
    """

    # value can be:
    #   PySymType: This is the 'normal' SymInt value, wrapped so we can use
    #              hash/eq as value hash/eq (normally SymInt does object
    #              hash/eq).
    #   _DeconstructedSymType: This is used when storing the _PySymInputStub in
    #                          the cache to avoid cyclic ShapeEnv references.
    #   _InputBackref: This is a back-reference to a previous _PySymInputStub in
    #                  the key.
    value: Union[PySymType, _DeconstructedSymType, _InputBackref]

    def __init__(
        self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
    ) -> None:
        # For inputs (values in the `key`) we need to keep the PySymType intact
        # - this way if we need to reuse it as an output we can properly copy
        # the original value.
        self.value = value

    def strip_shape_env(self) -> None:
        if isinstance(self.value, py_sym_types):
            self.value = _DeconstructedSymType.from_sym_type(self.value)

    def extract(self, shape_env: ShapeEnv) -> PySymType:
        if isinstance(self.value, _DeconstructedSymType):
            return self.value.extract(shape_env)
        else:
            # We should never see an _InputBackref here - anyone extracting a
            # value should be pulling from the original entry (the one this
            # backref points at).
            assert not isinstance(self.value, _InputBackref)
            return self.value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return f"_PySymInputStub({self.value!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _PySymInputStub):
            return False
        elif isinstance(self.value, _InputBackref) or isinstance(
            other.value, _InputBackref
        ):
            return self.value == other.value
        else:
            return self.value.node._value_eq(other.value.node)

    def __hash__(self) -> int:
        if isinstance(self.value, _InputBackref):
            return hash(self.value)
        else:
            return self.value.node._value_hash()


@dataclass_slots
@dataclass
class _SymIntOutputStub:
    """
    Represents a SymInt in the cached output.
    """

    # This is either an `int` which represents the index in the key to copy the
    # SymNode from or it's the deconstructed SymNode itself.
    value: Union[int, _DeconstructedSymNode]

    def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
        if key_path is None:
            self.value = _DeconstructedSymNode.from_node(value.node)
        else:
            self.value = key_path

    def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
        if isinstance(self.value, _DeconstructedSymNode):
            return SymInt(self.value.extract(shape_env))
        else:
            src = key.key[self.value]
            assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
            return src.value

    def __repr__(self) -> str:
        return f"_SymIntOutputStub({self.value!r})"

    def __eq__(self, other: object) -> bool:
        raise NotImplementedError

    def __hash__(self) -> int:
        raise NotImplementedError


@dataclass_slots
@dataclass
class _CacheKeyState:
    """
    State used while building our cache key.
    """

    # We track the SymNodes so when we get the output we can see if it exactly
    # matches one of the inputs so we can uncache it properly.
    sym_node_lookup: Dict[int, int]  # id(SymNode) -> index

    # There are cases where we're asked to perform an op when we have no
    # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
    # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
    # here.
    shape_env: Optional[ShapeEnv]

    def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
        self.sym_node_lookup = {}
        self.shape_env = shape_env

    def cache_on_shape_env(self) -> bool:
        """
        Returns true if the CacheKey needs to be cached on the ShapeEnv
        rather than the global cache.

        If our inputs contain a SymNode then we can't cache this operation on
        the global cache because the cached output will implicitly depend on
        guard values which might not be true on some other ShapeEnv. So unless
        we're also going to cache the guards we need to cache this operation on
        the ShapeEnv instead of globally.
        """
        return bool(self.sym_node_lookup)

    def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
        node_id = id(arg.node)
        if node_id in self.sym_node_lookup:
            result.append(_InputBackref(self.sym_node_lookup[node_id]))
        else:
            self.sym_node_lookup[node_id] = len(result)
            if self.shape_env is None:
                self.shape_env = arg.node.shape_env
            result.append(_PySymInputStub(arg))

    def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
        if isinstance(arg, SymInt):
            return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
        else:
            return arg