File: _symbolic_trace.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (243 lines) | stat: -rw-r--r-- 10,161 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import contextlib
import functools
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch


# Only ``TracingConfig`` is part of this module's public API.
__all__ = ["TracingConfig"]


@dataclass
class TracingConfig:
    """
    Configurations used in ``ParamExecOrderWrapPolicy`` for symbolic tracing of
    a model.

    Args:
        tracer (torch.fx.Tracer): An instance of ``torch.fx.Tracer`` that will
            be used to perform symbolic tracing. ``tracer`` defaults to a new
            ``torch.fx.Tracer()`` created for each ``TracingConfig``, but can
            also be an instance of some child class of ``torch.fx.Tracer``.
            For example, one may want to use ``HFTracer`` for models in
            `Transformers <https://huggingface.co/docs/transformers/index>`_.
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should
            not be treated as ``torch.fx.Proxy`` when tracing the forward
            function. ``concrete_args`` allows one to partially specialize the
            forward function, including removing control flow or data
            structures. ``concrete_args`` is also the argument used in
            :meth:`~torch.fx.Tracer.trace`.
    """

    # ``default_factory`` gives every config its own tracer. A plain
    # ``torch.fx.Tracer()`` default would be evaluated once at import time and
    # the same object would be shared (and patched in place during tracing) by
    # every ``TracingConfig`` that did not pass an explicit tracer.
    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    concrete_args: Optional[Dict[str, Any]] = None


@dataclass
class _ExecutionInfo:
    """
    Contains the execution order information in the model forward pass.

    Attributes:
        current_module: the module that is currently being traced.

        module_forward_order: a list of modules, ordered by when their forward
            function is called. A module is appended every time its forward is
            called, so the list also records how many times each module is
            called; it is intended to be used to check the forward order in
            different iterations.

        module_to_execution_infos: a dict that maps each module to a list of
            tuples, each containing a module and a list of named parameters.
            ``module_to_execution_infos`` is used as the parameter execution
            order info. For a given module, each tuple either 1. contains this
            module and the part of its ``named_parameters`` that is executed
            together, or 2. contains one of its child modules and all of the
            child module's ``named_parameters``. The list of tuples is ordered
            based on the parameter execution order.

        param_exec_order: a list of parameters ordered by their execution
            order.
    """

    # Module whose forward is being traced right now; ``_patched_call_module``
    # swaps it to the child module for the duration of the child's call.
    current_module: torch.nn.Module
    # Modules in forward-call order (repeated entries for repeated calls).
    module_forward_order: List[torch.nn.Module]
    # module -> ordered list of (module, [(prefixed name, parameter), ...]).
    module_to_execution_infos: Dict[
        torch.nn.Module,
        List[Tuple[torch.nn.Module, List[Tuple[str, torch.nn.Parameter]]]],
    ]
    # Each parameter appears at most once, at the position of its first use.
    param_exec_order: List[torch.nn.Parameter] = field(default_factory=list)


def _init_execution_info(root_module: torch.nn.Module) -> _ExecutionInfo:
    """
    Build a fresh ``_ExecutionInfo`` whose tracing state starts at
    ``root_module``.

    The root module becomes the initial ``current_module``, the sole entry of
    ``module_forward_order``, and the only key of
    ``module_to_execution_infos`` (mapped to an empty list).
    ``param_exec_order`` starts empty.

    Args:
        root_module (torch.nn.Module): the module whose execution information
            will be collected via ``tracer.trace()`` inside ``_patch_tracer``.
    """
    forward_order: List[torch.nn.Module] = [root_module]
    infos_by_module = {root_module: []}
    return _ExecutionInfo(
        current_module=root_module,
        module_forward_order=forward_order,
        module_to_execution_infos=infos_by_module,
    )


def _record_param_exec_order(
    execution_info: _ExecutionInfo,
    params: List[torch.nn.Parameter],
) -> None:
    """
    Append each parameter of ``params`` to ``execution_info.param_exec_order``
    unless it is already recorded, preserving first-use order.

    The membership set is built once per call rather than once per parameter,
    so recording ``k`` parameters costs O(n + k) instead of O(n * k), where
    ``n`` is the current length of ``param_exec_order``.
    """
    order = execution_info.param_exec_order
    seen = set(order)
    for param in params:
        if param not in seen:
            seen.add(param)
            order.append(param)


def _patched_create_proxy(
    create_proxy: Callable,
    execution_info: _ExecutionInfo,
    prefixed_param_name_to_param: Dict[str, torch.nn.Parameter],
    kind: str,
    target: torch.fx.node.Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: Optional[str] = None,
    type_expr: Optional[Any] = None,
    proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
) -> torch.fx.Proxy:
    """
    Override of :meth:`~torch.fx.Tracer.create_proxy`. ``Tracer.create_proxy``
    is called in symbolic tracing for each leaf function/method/module. This
    override intercepts the recording of each of these operations to update
    ``execution_info.module_to_execution_infos`` and
    ``execution_info.param_exec_order``.

    Args:
        create_proxy (Callable):
            The ``create_proxy`` function to be patched.
        execution_info (_ExecutionInfo):
            Used to record the execution information.
        prefixed_param_name_to_param (Dict[str, torch.nn.Parameter]):
            A dict that maps each prefixed parameter name to the parameter.
        kind (str):
            The type of the target method. One of 'call_function',
            'call_method', 'get_attr', 'call_module', 'placeholder', or
            'output'. The semantics of these opcodes are described in the
            ``torch.fx.Graph`` docstring. This is the input to ``create_proxy``.
        target (torch.fx.node.Target):
            Contains the string name of the method. This is the input to
            ``create_proxy``.
        args (Tuple[Any, ...]):
            Arguments of the method. This is the input to ``create_proxy``.
        kwargs (Dict[str, Any]):
            Keyword arguments of the method. This is the input to
            ``create_proxy``.
        name (Optional[str]):
            An optional string name for the ``Node`` created in
            ``create_proxy``. This is the input to ``create_proxy``.
        type_expr (Optional[Any]):
            An optional type annotation representing the Python type the output
            of a node will have. This is the input to ``create_proxy``.
        proxy_factory_fn (Optional[Callable[[torch.fx.Node], torch.fx.Proxy]]):
            An alternative proxy constructor used in ``create_proxy``. This is
            the input to ``create_proxy``.

    Returns:
        torch.fx.Proxy: the proxy produced by the original ``create_proxy``,
        returned unchanged.
    """
    proxy = create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

    module = execution_info.current_module
    if kind in ("call_function", "call_method"):
        if args is not None:
            # A parameter is recognized as a ``Proxy`` argument whose node
            # target is one of the prefixed parameter names. Only direct
            # positional arguments are inspected.
            # NOTE(review): parameters passed via ``kwargs`` or nested in
            # containers are not recorded — confirm this is intended.
            named_params: List[Tuple[str, torch.nn.Parameter]] = []
            for arg in args:
                if (
                    isinstance(arg, torch.fx.Proxy)
                    and arg.node.target in prefixed_param_name_to_param
                ):
                    param_name = arg.node.target
                    named_params.append(
                        (param_name, prefixed_param_name_to_param[param_name])
                    )
            if named_params:
                _record_param_exec_order(
                    execution_info, [param for _, param in named_params]
                )
                execution_info.module_to_execution_infos[module].append(
                    (module, named_params)
                )
    elif kind == "call_module":
        # ``current_module`` was switched to the called module by the patched
        # ``call_module`` before it delegated to the original implementation,
        # so ``module`` presumably is the leaf module emitting this proxy.
        named_params = list(module.named_parameters())
        if named_params:
            execution_info.module_to_execution_infos[module].append(
                (module, named_params)
            )
            _record_param_exec_order(
                execution_info, [param for _, param in named_params]
            )
    return proxy


def _patched_call_module(
    call_module: Callable,
    execution_info: _ExecutionInfo,
    module: torch.nn.Module,
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Override of :meth:`~torch.fx.Tracer.call_module`. ``Tracer.call_module`` is
    called in symbolic tracing for each non-root module. This override
    intercepts the recording of each operation to update
    ``execution_info.module_forward_order`` and
    ``execution_info.module_to_execution_infos``.

    Args:
        call_module (Callable):
            The ``call_module`` function to be patched.
        execution_info (_ExecutionInfo):
            Used to record the execution information.
        module (torch.nn.Module):
            The module for which a call is being emitted.
        forward (Callable[..., Any]):
            The ``forward()`` method of the ``torch.nn.Module`` to be invoked.
        args (Tuple[Any, ...]):
            ``args`` of the module callsite.
        kwargs (Dict[str, Any]):
            ``kwargs`` of the module callsite.

    Returns:
        Any: whatever the original ``call_module`` returns.
    """
    execution_info.module_forward_order.append(module)
    named_params = list(module.named_parameters())
    if named_params:
        # The enclosing (currently traced) module records this child together
        # with all of the child's parameters.
        execution_info.module_to_execution_infos[
            execution_info.current_module
        ].append((module, named_params))
    # Stores away current_module for restoration later
    prev_current_module = execution_info.current_module
    execution_info.current_module = module
    # Note that if the forward of module is called multiple times, this will
    # record the execution info of the last forward pass, since the entry is
    # reset here on every call.
    execution_info.module_to_execution_infos[module] = []
    try:
        return call_module(module, forward, args, kwargs)
    finally:
        # Restore even when tracing the child raises, so ``execution_info``
        # never stays pointed at a module whose call did not complete.
        execution_info.current_module = prev_current_module


@contextlib.contextmanager
def _patch_tracer(
    tracer: torch.fx.Tracer,
    root_module: torch.nn.Module,
    execution_info: _ExecutionInfo,
) -> Generator[None, None, None]:
    """
    Within the context manager, patches the input tracer so that during
    ``tracer.trace()``, the forward order of all modules and the parameter
    execution information are recorded. The patches of the input tracer will be
    removed after the context manager exits, including on exceptions.

    Args:
        tracer (torch.fx.Tracer): the input ``tracer`` whose member functions
            will be patched within the context manager.
        root_module (torch.nn.Module): the top-level module to be traced
            and should not contain any FSDP modules.
        execution_info (_ExecutionInfo): used to record the execution order
            information when performing ``tracer.trace()`` within the context
            manager.
    """
    # Compute everything that may raise *before* patching: the ``finally``
    # below only protects the tracer once the ``try`` has been entered.
    prefixed_param_name_to_param = dict(root_module.named_parameters())

    patched_attr_names = ("call_module", "create_proxy")
    # Remember attributes already overridden on the instance itself; the rest
    # are currently resolved through the tracer's class.
    instance_attrs = vars(tracer)
    saved_instance_attrs = {
        attr_name: instance_attrs[attr_name]
        for attr_name in patched_attr_names
        if attr_name in instance_attrs
    }

    original_call_module = tracer.call_module
    original_create_proxy = tracer.create_proxy
    tracer.call_module = functools.partial(
        _patched_call_module, original_call_module, execution_info
    )
    tracer.create_proxy = functools.partial(
        _patched_create_proxy,
        original_create_proxy,
        execution_info,
        prefixed_param_name_to_param,
    )
    try:
        yield
    finally:
        for attr_name in patched_attr_names:
            if attr_name in saved_instance_attrs:
                setattr(tracer, attr_name, saved_instance_attrs[attr_name])
            else:
                # Remove the instance attribute so lookups fall back to the
                # class method, instead of pinning a bound method on the
                # instance (which would shadow the class and form a
                # tracer -> bound method -> tracer reference cycle).
                delattr(tracer, attr_name)