# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
import sympy
import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map
from .virtualized import V


# Check whether the pattern (nn.Module, F.function / torch.Tensor method) is matched,
# i.e. a length-2 pattern consisting of 1 module followed by 1 function/method.
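#
# A minimal usage sketch (hypothetical names; assumes `gm` is a torch.fx.GraphModule
# produced by symbolic tracing):
#
#   modules = dict(gm.named_modules())
#   for node in gm.graph.nodes:
#       if matches_module_function_pattern(
#           (torch.nn.Linear, torch.nn.functional.relu), node, modules
#       ):
#           ...  # `node` is the F.relu call consuming the nn.Linear output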
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure the output of node.args[0] is only used by the current node.
    if len(node.args[0].users) > 1:
        return False
    return True


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.
    To detect which nodes have changed, we first hash each node along with its
    target and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the fake tensor metadata for those nodes. We then
    continue to recursively recompute the fake tensors for all users until the
    fake tensors stop changing.
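
    Example (a minimal sketch; assumes this runs inside an Inductor FX pass
    where ``V.fake_mode`` is already installed, and ``gm`` is a hypothetical
    GraphModule being rewritten):

        updater = FakeTensorUpdater(gm.graph)
        # ... an FX pass mutates gm.graph here ...
        updater.incremental_update()  # refresh node.meta["val"] only where needed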
"""
def __init__(self, graph: torch.fx.Graph) -> None:
self.processed_hashes = set()
self.graph = graph
for node in self.graph.nodes:
self.processed_hashes.add(self.hash_node(node))
def hash_node(self, node: torch.fx.Node):
# todo(chilli): Not a great hash function
return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        processed = set()
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False
            if new.device != old.device:
                return False
            if get_storage(new) == get_storage(old):
                return True

            # This is the case where it returns a completely fresh storage that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # For nodes accepted by this function, node.target is called
            # directly under fake mode, which does not work for inductor
            # lowerings. So we only accept node.target if it is an aten
            # operator or operator.getitem (which is used when an op returns
            # multiple tensors).
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )
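
        # Worklist-style propagation: recompute fake tensors for nodes whose
        # hash is new, queue their users (which appear later in the graph's
        # topological order), and continue until the fake tensors stop changing.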
        to_process = set()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue

            with V.fake_mode, enable_python_dispatcher():
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)
            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1
            to_process.update([id(user) for user in node.users])
            self.processed_hashes.add(self.hash_node(node))


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes do not have a
    fake tensor (i.e. they still appear as fx.Nodes in the returned
    args/kwargs), and True otherwise.
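
    A small illustration (hypothetical node): if every argument of ``node``
    carries fake tensor metadata, the returned fake args/kwargs can be passed
    to ``node.target`` under fake mode, as done in
    FakeTensorUpdater.incremental_update::

        ok, args, kwargs = get_fake_args_kwargs(node)
        if ok:
            with V.fake_mode:
                out = node.target(*args, **kwargs)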
"""
args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
if any(
isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
):
return False, args, kwargs
return True, args, kwargs


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
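
    As a rough, non-exhaustive illustration: a node consumed by an op listed
    in needs_realized_inputs is reported as realized, while a node that only
    feeds pointwise ops typically is not.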
"""
from torch._inductor.lowering import fallbacks, needs_realized_inputs
def is_buffer(node: torch.fx.Node) -> bool:
if node.op == "call_function" and node.target is operator.getitem:
# For nodes with multiple outputs, we get the fx graph:
# foo = torch.ops.aten.foo(...)
# getitem = foo[0]
# getitem_1 = foo[1]
# where we need to check if foo is a fallback kernel
return is_buffer(node.args[0]) # type: ignore[arg-type]
return node.op in ("placeholder", "output") or node.target in fallbacks
if is_buffer(node):
return True
def realizes_inputs(node: torch.fx.Node) -> bool:
return node.op == "output" or node.target in needs_realized_inputs
if any(realizes_inputs(user) for user in node.users):
return True
# Otherwise, assume node isn't realized
return False