1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
|
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
import torch
from torch import SymInt
from torch.fx.experimental.sym_node import SymNode
from torch.types import py_sym_types, PySymType
from torch.utils._backport_slots import dataclass_slots
if TYPE_CHECKING:
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from .fake_tensor import _DispatchCacheKey, _MetadataIntLike
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymNode:
"""
Represents a SymNode without the associated ShapeEnv
"""
# n.b. keep the same protocol as SymNode
_expr: sympy.Expr
pytype: type
_hint: Optional[Union[int, float, bool]]
constant: Optional[Union[int, float, bool]]
fx_node: torch.fx.Node
@staticmethod
def from_node(node: SymNode) -> _DeconstructedSymNode:
return _DeconstructedSymNode(
node._expr, node.pytype, node._hint, node.constant, node.fx_node
)
def extract(self, shape_env: ShapeEnv) -> SymNode:
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
def __str__(self) -> str:
return str(self._expr)
def __repr__(self) -> str:
return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
# _value_eq to match SymNode
def _value_eq(self, other: object) -> bool:
if isinstance(other, (SymNode, _DeconstructedSymNode)):
return (
self._expr == other._expr
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
)
else:
return False
# _value_hash to match SymNode
def _value_hash(self) -> int:
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymType:
"""
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
"""
ty: Type[PySymType]
node: _DeconstructedSymNode
@staticmethod
def from_sym_type(value: PySymType) -> _DeconstructedSymType:
return _DeconstructedSymType(type(value), value.node)
def extract(self, shape_env: ShapeEnv) -> PySymType:
return self.ty(self.node.extract(shape_env))
def __str__(self) -> str:
return f"{self.ty}({self.node})"
def __repr__(self) -> str:
return f"_DeconstructedSymType({self.ty}, {self.node!r})"
def __eq__(self, other: object) -> bool:
return NotImplemented
def __hash__(self) -> int:
return NotImplemented
@dataclass_slots
@dataclass(frozen=True)
class _InputBackref:
    """
    Back-reference to an earlier entry in the cache key.
    """

    # Index in the key of the entry this refers back to.
    value: int
@dataclass_slots
@dataclass
class _PySymInputStub:
"""
Represents a SymInt in the cached key. Needed because SymInt doesn't
support __eq__ or __hash__ directly.
"""
# value can be:
# PySymType: This is the 'normal' SymInt value, wrapped so we can use
# hash/eq as value hash/eq (normally SymInt does object
# hash/eq).
# _DeconstructedSymType: This is used when storing the _PySymInputStub in
# the cache to avoid cyclic ShapeEnv references.
# _InputBackref: This is a back-reference to a previous _PySymInputStub in
# the key.
value: Union[PySymType, _DeconstructedSymType, _InputBackref]
def __init__(
self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
) -> None:
# For inputs (values in the `key`) we need to keep the PySymType intact
# - this way if we need to reuse it as an output we can properly copy
# the original value.
self.value = value
def strip_shape_env(self) -> None:
if isinstance(self.value, py_sym_types):
self.value = _DeconstructedSymType.from_sym_type(self.value)
def extract(self, shape_env: ShapeEnv) -> PySymType:
if isinstance(self.value, _DeconstructedSymType):
return self.value.extract(shape_env)
else:
# We should never see an _InputBackref here - anyone extracting a
# value should be pulling from the original entry (the one this
# backref points at).
assert not isinstance(self.value, _InputBackref)
return self.value
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return f"_PySymInputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, _PySymInputStub):
return False
elif isinstance(self.value, _InputBackref) or isinstance(
other.value, _InputBackref
):
return self.value == other.value
else:
return self.value.node._value_eq(other.value.node)
def __hash__(self) -> int:
if isinstance(self.value, _InputBackref):
return hash(self.value)
else:
return self.value.node._value_hash()
@dataclass_slots
@dataclass
class _SymIntOutputStub:
"""
Represents a SymInt in the cached output.
"""
# This is either an `int` which represents the index in the key to copy the
# SymNode from or it's the deconstructed SymNode itself.
value: Union[int, _DeconstructedSymNode]
def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
if key_path is None:
self.value = _DeconstructedSymNode.from_node(value.node)
else:
self.value = key_path
def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
if isinstance(self.value, _DeconstructedSymNode):
return SymInt(self.value.extract(shape_env))
else:
src = key.key[self.value]
assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
return src.value
def __repr__(self) -> str:
return f"_SymIntOutputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
@dataclass_slots
@dataclass
class _CacheKeyState:
"""
State used while building our cache key.
"""
# We track the SymNodes so when we get the output we can see if it exactly
# matches one of the inputs so we can uncache it properly.
sym_node_lookup: Dict[int, int] # id(SymNode) -> index
# There are cases where we're asked to perform an op when we have no
# ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
# ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
# here.
shape_env: Optional[ShapeEnv]
def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
self.sym_node_lookup = {}
self.shape_env = shape_env
def cache_on_shape_env(self) -> bool:
"""
Returns true if the CacheKey needs to be cached on the ShapeEnv
rather than the global cache.
If our inputs contain a SymNode then we can't cache this operation on
the global cache because the cached output will implicitly depend on
guard values which might not be true on some other ShapeEnv. So unless
we're also going to cache the guards we need to cache this operation on
the ShapeEnv instead of globally.
"""
return bool(self.sym_node_lookup)
def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
node_id = id(arg.node)
if node_id in self.sym_node_lookup:
result.append(_InputBackref(self.sym_node_lookup[node_id]))
else:
self.sym_node_lookup[node_id] = len(result)
if self.shape_env is None:
self.shape_env = arg.node.shape_env
result.append(_PySymInputStub(arg))
def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
if isinstance(arg, SymInt):
return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
else:
return arg
|