# mypy: allow-untyped-defs
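"""Pass that removes effect tokens from an ExportedProgram: it drops the
token inputs/outputs and rewrites ``with_effects(token, func, args)`` calls
back into plain ``func(args)`` calls."""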
import operator
from typing import List

import torch
from torch._higher_order_ops.effects import _get_schema, with_effects

from .exported_program import ExportedProgram
from .graph_signature import (
    CustomObjArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TokenArgument,
)
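

# ``with_effects`` has the calling convention
#   (new_token, *op_results) = with_effects(token, op, *args)
# so a token-free graph is recovered by dropping the first argument and the
# first return value of every such call.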
def _remove_effect_tokens_from_graph_helper(
    ep, num_tokens, input_token_names, output_token_names
):
    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

    output_node = None
    with_effect_nodes: List[torch.fx.Node] = []

    # The output node's args need to be checked against output_token_names
    # (collected from the output specs), so we only need to find the
    # top-level output node.
    output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
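
    # Collect every with_effects call, searching all nested graph modules,
    # not just the top-level graph.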
    for module in ep.graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue

        for node in module.graph.nodes:
            if not (node.op == "call_function" and node.target is with_effects):
                continue

            with_effect_nodes.append(node)
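
    # By convention, the effect tokens occupy the first ``num_tokens`` slots
    # of both the graph's placeholders and its output tuple.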
    # Remove tokens from outputs
    assert output_node is not None
    output_args = output_node.args[0]
    assert len(output_args) >= num_tokens
    out_token_nodes = output_args[:num_tokens]
    output_node.args = (tuple(output_args[num_tokens:]),)
    for out_token in out_token_nodes:
        assert out_token.name in output_token_names
        out_token.users.clear()
        ep.graph.erase_node(out_token)

    # Replace with_effects(token, func, args) with just func(args)
    for node in reversed(with_effect_nodes):
        func = node.args[1]
        assert isinstance(
            func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        )
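
        # The schema of a call_torchbind call lives on the script object
        # itself, so we need the real (or fake) object to look it up; plain
        # ops can be queried from their remaining args directly.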
        if func == torch.ops.higher_order.call_torchbind:
            custom_obj_meta = node.args[2].meta["val"]
            assert isinstance(custom_obj_meta, CustomObjArgument)
            if custom_obj_meta.fake_val:
                custom_obj = custom_obj_meta.fake_val
            elif node.args[2].name in inputs_to_lifted_custom_objs:
                custom_obj = ep.constants[
                    inputs_to_lifted_custom_objs[node.args[2].name]
                ]
            else:
                raise RuntimeError(f"Unable to find custom obj for node {node}")
            schema = _get_schema(func, (custom_obj,) + node.args[3:])
        else:
            schema = _get_schema(func, node.args[2:])
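
        # Re-emit the op without the leading token argument; the new node
        # inherits the old node's metadata.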
        with ep.graph.inserting_before(node):
            new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)

        for k, v in node.meta.items():
            new_node.meta[k] = v

        node.replace_all_uses_with(new_node)

        # Update user getitem nodes
        for user in list(new_node.users.keys()):
            assert user.target == operator.getitem
            # getitem(with_effects, 0) == token
            if user.args[1] == 0:
                ep.graph.erase_node(user)
        if len(schema.returns) == 1:
            # If the function has 1 return, then it will directly return the
            # result rather than a tuple, so we don't need a getitem. Replace
            # every getitem(with_effects, 1) with the new node itself.
            for user in list(new_node.users.keys()):
                assert user.args[1] == 1
                user.replace_all_uses_with(new_node)

            new_node.meta["val"] = node.meta["val"][1]
        elif len(schema.returns) > 1:
            # If the function has more than 1 return, then since we got rid
            # of the 1st return value (the token), we need to shift all the
            # other getitem indices down by 1.
            for user in list(new_node.users.keys()):
                assert user.args[1] >= 1
                user.args = (user.args[0], user.args[1] - 1)

            new_node.meta["val"] = node.meta["val"][1:]
        else:
            assert len(schema.returns) == 0
            assert len(new_node.users) == 0
            new_node.meta["val"] = None

        ep.graph.erase_node(node)

    # Remove tokens from inputs
    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
    assert len(placeholders) >= num_tokens
    inp_token_nodes = placeholders[:num_tokens]
    for inp_token in inp_token_nodes:
        assert inp_token.name in input_token_names
        ep.graph.erase_node(inp_token)

    ep.graph.eliminate_dead_code()


def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
    """
    Removes all effect tokens from the exported program:
    - Removes the input and output tokens
    - Replaces with_effects(token, func, args) with just func(args)

    This function modifies the given ExportedProgram in place.
    """
    num_tokens: int = 0
    input_token_names: List[str] = []
    new_input_specs: List[InputSpec] = []
    for inp in ep.graph_signature.input_specs:
        if inp.kind == InputKind.TOKEN:
            num_tokens += 1
            assert isinstance(inp.arg, TokenArgument)
            input_token_names.append(inp.arg.name)
        else:
            new_input_specs.append(inp)

    num_out_tokens: int = 0
    new_output_specs: List[OutputSpec] = []
    output_token_names: List[str] = []
    for out in ep.graph_signature.output_specs:
        if out.kind == OutputKind.TOKEN:
            num_out_tokens += 1
            output_token_names.append(out.arg.name)
        else:
            new_output_specs.append(out)

    # Update graph signature
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_signature.output_specs = new_output_specs
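
    # Tokens are threaded through the graph, so every input token should
    # come back out as an output token.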
    assert num_tokens == num_out_tokens

    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _remove_effect_tokens_from_graph_helper(
            ep, num_tokens, input_token_names, output_token_names
        )

    return ep