1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
|
# mypy: allow-untyped-defs
import collections
from collections import defaultdict
from typing import Any, Callable, Dict, Optional
import torch
import torch.utils._pytree as pytree
aten = torch.ops.aten
# We would like to split modules into two subgraphs so that runtime weight updates work correctly.
# The use case and more information can be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"
def replace_node_with_constant(gm, node, constant, name=None):
    """Swap ``node`` for a ``get_attr`` node that reads ``constant`` from ``gm``.

    Args:
        gm: Graph module that owns ``node``; it receives the new attribute.
        node: Node to replace. All of its uses are redirected to the new
            ``get_attr`` node and it is then erased from the graph.
        constant: Value stored on ``gm`` under the chosen attribute name.
        name: Attribute name to use. When falsy, the first unused
            ``_frozen_param{i}`` name is picked, starting the search at
            ``gm._frozen_param_count`` (treated as 0 if absent).
    """
    graph = gm.graph

    if name:
        attr_name = name
    else:
        # Resume from the last counter value so repeated calls do not rescan
        # names that are already taken.
        index = getattr(gm, "_frozen_param_count", 0)
        while hasattr(gm, f"_frozen_param{index}"):
            index += 1
        attr_name = f"_frozen_param{index}"
        gm._frozen_param_count = index + 1

    with graph.inserting_before(node):
        replacement = graph.create_node("get_attr", attr_name, (), {})
        node.replace_all_uses_with(replacement)
        # Carry the metadata of the replaced node over to its stand-in.
        replacement.meta.update(node.meta)
        graph.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(attr_name, constant)
    setattr(gm, attr_name, constant)
class ConstantFolder(torch.fx.Interpreter):
def __init__(
self,
gm: torch.fx.GraphModule,
skip_constructors: bool = False,
):
super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object()
self.skip_constructors: bool = skip_constructors
# overwrite this to deallocate env values if their only remaining use
# is the output
self.user_to_last_uses = self.node_to_last_non_output_use()
def is_impure(self, node: torch.fx.Node) -> bool:
if (
node.target == torch.ops.prims.convert_element_type.default
and node.args[0].op == "get_attr" # type: ignore[union-attr]
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
and node.args[1] == torch.bfloat16
):
# For int8_weight -> dq -> bf16_weight
return True
if node.target in [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For the pattern fp32_weight -> q -> dq
# We only folding fp32_weight -> q
# int8_weight and leave dq in graph to be fused
return True
return False
def node_to_last_non_output_use(self):
last_non_output_use = collections.defaultdict(list)
seen_uses = set()
output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr]
if node.target == "output":
continue
def add_use(inp):
if inp in seen_uses:
return
seen_uses.add(inp)
last_non_output_use[node].append(inp)
# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
# if this node is only used in output, we want to gc it right away
if len(node.users) == 1 and output_node in node.users:
last_non_output_use[node].append(node)
return last_non_output_use
def run_node(self, node):
if node.target == "output":
# because we remove nodes from env on last non output use,
# re-define them now or we'll get error in interpreter
def set_env(arg):
self.env[arg] = self.unknown_value
# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
return super().run_node(node)
args, kwargs = self.fetch_args_kwargs_from_env(node)
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
# We need to do this weird thing because in cases where flattened_inputs
# contains a ScriptObject, equality checking results in a type error if
# the types are different.
if any(
type(self.unknown_value) == type(input_) and self.unknown_value == input_
for input_ in flattened_inputs
):
return self.unknown_value
# TODO - fix errors with this
if (
node.op == "call_function"
and node.target == aten._efficientzerotensor.default
):
return self.unknown_value
# TODO - constant folding triton kernel returns the inputs -- fix this
if (
node.op == "call_function"
and node.name == "triton_kernel_wrapper_functional_proxy"
):
return self.unknown_value
# skip constructors, since inductor generates optimal code for them already
# and turning into tensor would result in an additional global memory read
# TODO - more complicated strategy
if (
self.skip_constructors
and node.op != "get_attr"
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
):
return self.unknown_value
# All mutations should either be removed or on inputs which we did not make constant
if (
isinstance(node.target, torch._ops.OpOverload)
and torch.Tag.nondeterministic_seeded in node.target.tags
):
return self.unknown_value
out = super().run_node(node)
if node.op != "get_attr" and isinstance(out, torch.Tensor):
if out.device.type == "meta":
return out
if not self.insertable_tensor_check(out):
return out
if self.is_impure(node):
return self.unknown_value
self.add_node_replacement(node, out)
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
for n in flattened_node_inps:
if not isinstance(n, torch.fx.Node):
continue
self.replaced_uses[n] += 1
for to_delete in self.user_to_last_uses.get(node, []):
if self.replaced_uses[to_delete] == len(to_delete.users):
self.node_replacements.pop(to_delete, None)
return out
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
return True
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor
def run(self): # type: ignore[override]
env = {}
for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
env[n] = self.unknown_value
return super().run(initial_env=env)
def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
):
    """Fold constant subgraphs of ``gm`` in place and recompile it.

    Args:
        gm: Graph module to rewrite.
        constraint_fn: Optional predicate; when given, only nodes for which
            it returns true are replaced with constants.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        folder = ConstantFolder(gm, skip_constructors=True)
        folder.run()

        for folded_node, value in folder.node_replacements.items():
            if constraint_fn is None or constraint_fn(folded_node):
                replace_node_with_constant(gm, folded_node, value)

        # Users are grouped by attribute target rather than read from a single
        # node's ``users``, because distinct get_attr nodes may refer to the
        # same tensor, e.g. _tensor_constant0 and _tensor_constant0_1 below:
        # opcode         name                 target            args                         kwargs
        # -------------  -------------------  ----------------  ---------------------------  --------
        # placeholder    arg0_1               arg0              ()                           {}
        # get_attr       _tensor_constant0    state             ()                           {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
        # get_attr       _tensor_constant0_1  state             ()                           {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
        # output         output               output            ([add],)                     {}
        users_by_target = defaultdict(list)
        for candidate in gm.graph.nodes:
            if candidate.op != "get_attr":
                continue
            users_by_target[candidate.target].extend(candidate.users.keys())

        # Drop attributes (and their get_attr nodes) that nothing consumes.
        unused_attr_nodes = []
        for attr_node in gm.graph.find_nodes(op="get_attr"):
            if users_by_target[attr_node.target]:
                continue
            if hasattr(gm, attr_node.target):
                delattr(gm, attr_node.target)
            unused_attr_nodes.append(attr_node)

        for attr_node in unused_attr_nodes:
            gm.graph.erase_node(attr_node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
def constant_graph_tag(gm: torch.fx.GraphModule) -> None:
    """Tag every node of ``gm`` as belonging to the constant or the main module.

    A node gets ``CONST_MODULE_TAG`` under ``node.meta[META_TAG]`` when it is a
    ``get_attr`` node, was recorded as foldable, or was consumed by a folded
    node; every other node gets ``MODULE_TAG``.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        folder = ConstantFolder(gm, skip_constructors=True)
        folder.run()

        for graph_node in gm.graph.nodes:
            is_constant = (
                graph_node.op == "get_attr"
                or graph_node in folder.node_replacements
                or graph_node in folder.replaced_uses
            )
            graph_node.meta[META_TAG] = (
                CONST_MODULE_TAG if is_constant else MODULE_TAG
            )
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Build a GraphModule holding the portion of ``gm`` that can be constant
    folded. Its outputs are the constant nodes consumed by the main module.
    """
    constant_graph_tag(gm)

    # A constant that is consumed directly, with no folding opportunity among
    # its users, is re-tagged so that it stays in the main gm.
    for attr_node in gm.graph.find_nodes(op="get_attr"):
        feeds_folding = any(
            consumer.meta[META_TAG] == CONST_MODULE_TAG
            for consumer in attr_node.users
        )
        if not feeds_folding:
            attr_node.meta[META_TAG] = MODULE_TAG

    const_graph = torch.fx.Graph()
    copied: Dict[torch.fx.Node, torch.fx.Node] = {}
    boundary_outputs = []

    def remap(old_node):
        return copied[old_node]

    for original in gm.graph.nodes:
        if original.meta[META_TAG] == MODULE_TAG:
            continue

        clone = const_graph.node_copy(original, remap)
        copied[original] = clone

        # A constant node with at least one main-module user becomes an output.
        if any(user.meta[META_TAG] == MODULE_TAG for user in original.users):
            boundary_outputs.append(clone)

    const_graph.output(tuple(boundary_outputs))
    const_graph.lint()
    return torch.fx.GraphModule(gm, const_graph)
|