File: _remove_effect_tokens_pass.py

# mypy: allow-untyped-defs
import operator
from typing import List

import torch
from torch._higher_order_ops.effects import _get_schema, with_effects

from .exported_program import ExportedProgram
from .graph_signature import (
    CustomObjArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TokenArgument,
)


def _remove_effect_tokens_from_graph_helper(
    ep, num_tokens, input_token_names, output_token_names
):
    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

    output_node = None
    with_effect_nodes: List[torch.fx.Node] = []

    # The output node's args need to be checked against output_token_names
    # (collected from the output specs), so we only need to find the top-level
    # output node.
    output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
    for module in ep.graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue

        for node in module.graph.nodes:
            if not (node.op == "call_function" and node.target is with_effects):
                continue

            with_effect_nodes.append(node)

    # Remove tokens from outputs
    assert output_node is not None
    output_args = output_node.args[0]
    assert len(output_args) >= num_tokens
    out_token_nodes = output_args[:num_tokens]
    output_node.args = (tuple(output_args[num_tokens:]),)
    for out_token in out_token_nodes:
        assert out_token.name in output_token_names
        out_token.users.clear()
        ep.graph.erase_node(out_token)

    # Replace with_effects(token, func, args) with just func(args)
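    # with_effects has the calling convention
    #   with_effects(token, op, *args) -> (new_token, *op_outputs)
    # so dropping the token means calling `op` directly on args[2:] and
    # shifting any getitem projections on the result down by one.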
    for node in reversed(with_effect_nodes):
        func = node.args[1]
        assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))

        if func == torch.ops.higher_order.call_torchbind:
            custom_obj_meta = node.args[2].meta["val"]
            assert isinstance(custom_obj_meta, CustomObjArgument)
            if custom_obj_meta.fake_val:
                custom_obj = custom_obj_meta.fake_val
            elif node.args[2].name in inputs_to_lifted_custom_objs:
                custom_obj = ep.constants[
                    inputs_to_lifted_custom_objs[node.args[2].name]
                ]
            else:
                raise RuntimeError(f"Unable to find custom obj for node {node}")
            schema = _get_schema(func, (custom_obj,) + node.args[3:])
        else:
            schema = _get_schema(func, node.args[2:])

        with ep.graph.inserting_before(node):
            new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
        for k, v in node.meta.items():
            new_node.meta[k] = v

        node.replace_all_uses_with(new_node)

        # Update user getitem nodes
        for user in list(new_node.users.keys()):
            assert user.target == operator.getitem
            # getitem(with_effects, 0) == token
            if user.args[1] == 0:
                ep.graph.erase_node(user)

        if len(schema.returns) == 1:
            # If the function has 1 return then it will just directly return the
            # result -- we don't need a getitem. So we can replace all the
            # getitem(with_effects, 1) calls with the new node itself.
            for user in list(new_node.users.keys()):
                assert user.args[1] == 1
                user.replace_all_uses_with(new_node)

            new_node.meta["val"] = node.meta["val"][1]
        elif len(schema.returns) > 1:
            # If the function has more than 1 return then, since we got rid of
            # the 1st return value (the token), we need to shift every other
            # getitem index down by 1.
            for user in list(new_node.users.keys()):
                assert user.args[1] >= 1
                user.args = (user.args[0], user.args[1] - 1)

            new_node.meta["val"] = node.meta["val"][1:]
        else:
            assert len(schema.returns) == 0
            assert len(new_node.users) == 0
            new_node.meta["val"] = None

        ep.graph.erase_node(node)

    # Remove tokens from inputs
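    # Token placeholders are lifted to the front of the input list, so the
    # first num_tokens placeholders are the ones to drop (the name assert
    # below double-checks this).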
    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
    assert len(placeholders) >= num_tokens
    inp_token_nodes = placeholders[:num_tokens]
    for inp_token in inp_token_nodes:
        assert inp_token.name in input_token_names
        ep.graph.erase_node(inp_token)

    ep.graph.eliminate_dead_code()


def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
    """
    Removes the existance of tokens from the exported program, including:
    - Removes the input and output tokens
    - Replaces with_effects(token, func, args) with just func(args)

    This function does an inplace modification on the given ExportedProgram.
    """
    num_tokens: int = 0
    input_token_names: List[str] = []
    new_input_specs: List[InputSpec] = []
    for inp in ep.graph_signature.input_specs:
        if inp.kind == InputKind.TOKEN:
            num_tokens += 1
            assert isinstance(inp.arg, TokenArgument)
            input_token_names.append(inp.arg.name)
        else:
            new_input_specs.append(inp)

    num_out_tokens: int = 0
    new_output_specs: List[OutputSpec] = []
    output_token_names: List[str] = []
    for out in ep.graph_signature.output_specs:
        if out.kind == OutputKind.TOKEN:
            num_out_tokens += 1
            output_token_names.append(out.arg.name)
        else:
            new_output_specs.append(out)

    # Update graph signature
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_signature.output_specs = new_output_specs

    # Every token threaded into the graph is threaded back out, so the input
    # and output token counts must match.
    assert num_tokens == num_out_tokens

    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _remove_effect_tokens_from_graph_helper(
            ep, num_tokens, input_token_names, output_token_names
        )

    return ep
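

# Minimal usage sketch (illustrative, not part of the pass). It assumes this
# module lives at torch.export._remove_effect_tokens_pass, as its relative
# imports suggest, and that torch.ops.aten._print is registered as an
# effectful op, so torch.export lifts it into a with_effects call:
#
#     import torch
#     from torch.export import export
#     from torch.export._remove_effect_tokens_pass import _remove_effect_tokens
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             torch.ops.aten._print("side effect")
#             return x + 1
#
#     ep = export(M(), (torch.randn(3),))
#     # The exported graph carries a token placeholder, a with_effects node,
#     # and a token output; the call below strips all three in place.
#     ep = _remove_effect_tokens(ep)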