File: _trace_utils.py

package: pytorch-cuda 2.6.0+dfsg-7 (area: contrib; suites: forky, sid, trixie)

# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.nn as nn


@dataclass
class TracingConfig:
    """
    This represents a symbolic tracing configuration.

    Args:
        tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
            use for symbolic tracing. The default value is the native
            :class:`torch.fx.Tracer` constructed with default arguments.
            However, the user may want to pass a different value such as the
            ``HFTracer`` for models in the HuggingFace Transformers_ library.
            .. _Transformers: https://huggingface.co/docs/transformers/index
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
            should not be treated as ``torch.fx.Proxy`` when tracing the
            module ``forward()``. Passing ``concrete_args`` allows partially
            specializing the forward, e.g. to remove control flow or data
            structures. The ``concrete_args`` here is the same argument used
            in :meth:`~torch.fx.Tracer.trace`.
    """

    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    concrete_args: Optional[Dict[str, Any]] = None
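
# A minimal usage sketch of ``TracingConfig`` (not part of the original module).
# The ``use_cache`` forward argument below is hypothetical; passing it via
# ``concrete_args`` keeps it a concrete constant rather than a
# ``torch.fx.Proxy`` during tracing:
#
#   config = TracingConfig(
#       tracer=torch.fx.Tracer(),
#       concrete_args={"use_cache": False},
#   )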


class _ParamUsageInfo(NamedTuple):
    """
    This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
    execution information. The ``dict`` maps modules to a list of these
    ``_ParamUsageInfo`` instances, where each instance represents a group of
    parameters used together.

    Specifically, for each module key in the ``dict``, each instance of this
    class represents either:
    (1) the module and some sublist of its ``named_parameters()`` used
    together in execution (see ``_patched_create_proxy()``), or
    (2) a submodule and all of ``submodule.named_parameters()`` (see
    ``_patched_call_module()``).

    Type (1) corresponds to directly using parameters in ops without calling
    ``forward()``, and type (2) corresponds to calling ``forward()``. The
    mapped-to lists in the ``dict`` follow the execution order.
    """

    module: nn.Module
    named_params: List[Tuple[str, nn.Parameter]]
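
# Illustrative sketch (not part of the original module) of the two record
# kinds, for a hypothetical root module with a ``linear`` submodule and a
# directly-used ``scale`` parameter:
#
#   # type (2): submodule ``forward()`` call, recorded by ``_patched_call_module()``
#   _ParamUsageInfo(module=root.linear, named_params=[("weight", ...), ("bias", ...)])
#   # type (1): direct use of a parameter in an op, recorded by ``_patched_create_proxy()``
#   _ParamUsageInfo(module=root, named_params=[("scale", root.scale)])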


class _ExecutionInfo:
    """
    This represents the execution order information from the forward pass.

    Attributes:
        curr_module (nn.Module): Current module being traced.
        module_forward_order (List[nn.Module]): The modules in (pre-)forward
            order, i.e. the order in which their ``forward()`` methods are
            called. Each call to a module's ``forward()`` corresponds to one
            element in the list.
        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
            Maps a module to a list of module execution infos. See
            :class:`_ParamUsageInfo` for details.
        param_forward_order (List[nn.Parameter]): The parameters in forward
            execution order, where only a parameter's first participation is
            included.
        visited_params (Set[nn.Parameter]): The parameters visited so far
            during the trace. This is only used during tracing for fast
            membership check. Invariant: The parameters in
            ``param_forward_order`` are exactly those in ``visited_params``.
    """

    def __init__(self, root_module: nn.Module) -> None:
        self.curr_module: nn.Module = root_module
        self.module_forward_order: List[nn.Module] = [root_module]
        self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
            root_module: []
        }
        self.param_forward_order: List[nn.Parameter] = []
        self.visited_params: Set[nn.Parameter] = set()


class _ExecOrderTracer:
    def __init__(self) -> None:
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        self.exec_info = _ExecutionInfo(root_module)
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        tracer.call_module = functools.partial(  # type: ignore[method-assign]
            self._patched_call_module, orig_call_module, self.exec_info
        )
        fqn_to_param = dict(root_module.named_parameters())
        tracer.create_proxy = functools.partial(  # type: ignore[method-assign]
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
            fqn_to_param,
        )
        try:
            yield
        finally:
            tracer.call_module = orig_call_module  # type: ignore[method-assign]
            tracer.create_proxy = orig_create_proxy  # type: ignore[method-assign]
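
    # A minimal usage sketch of ``patch_tracer`` above (not part of the
    # original module); ``model`` is a hypothetical root ``nn.Module``:
    #
    #   tracer = torch.fx.Tracer()
    #   exec_order_tracer = _ExecOrderTracer()
    #   with exec_order_tracer.patch_tracer(tracer, model):
    #       tracer.trace(model)
    #   exec_info = exec_order_tracer.exec_info  # populated during the trace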

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        curr_module = exec_info.curr_module
        if named_params:
            assert (
                curr_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
                _ParamUsageInfo(module, named_params)
            )
        prev_curr_module = curr_module
        exec_info.curr_module = module
        exec_info.module_to_param_usage_infos[module] = []
        output = call_module(module, forward, args, kwargs)
        exec_info.curr_module = prev_curr_module
        return output

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Optional[Callable[[torch.fx.Node], torch.fx.Proxy]]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        if kind in ("call_function", "call_method"):
            if args is not None:
                named_params: List[Tuple[str, nn.Parameter]] = []
                for arg in args:
                    if (
                        isinstance(arg, torch.fx.Proxy)
                        and arg.node.target in fqn_to_param
                    ):
                        param = fqn_to_param[arg.node.target]  # type: ignore[index]
                        named_params.append((arg.node.target, param))  # type: ignore[arg-type]
                        if param not in exec_info.visited_params:
                            exec_info.visited_params.add(param)
                            exec_info.param_forward_order.append(param)
                if named_params:
                    exec_info.module_to_param_usage_infos[curr_module].append(
                        _ParamUsageInfo(curr_module, named_params)
                    )
        elif kind == "call_module":
            named_params = list(curr_module.named_parameters())
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
            for _, param in named_params:
                if param not in exec_info.visited_params:
                    exec_info.visited_params.add(param)
                    exec_info.param_forward_order.append(param)
        return proxy
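

if __name__ == "__main__":
    # A minimal, self-contained sketch (not part of the original module) of how
    # the execution-order tracer records parameter usage. ``_DemoModule`` and
    # its members are hypothetical names used only for this illustration.
    import torch.nn.functional as F

    class _DemoModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Used via ``call_module`` -- recorded as a type (2) usage.
            self.linear = nn.Linear(4, 4)
            # Used directly in an op -- recorded as a type (1) usage.
            self.weight = nn.Parameter(torch.randn(4, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(self.linear(x), self.weight)

    root = _DemoModule()
    tracer = torch.fx.Tracer()
    exec_order_tracer = _ExecOrderTracer()
    with exec_order_tracer.patch_tracer(tracer, root):
        tracer.trace(root)
    exec_info = exec_order_tracer.exec_info
    assert exec_info is not None
    # Map each parameter back to its FQN to show the recorded forward order;
    # expected under these assumptions: ['linear.weight', 'linear.bias', 'weight'].
    name_of = {param: name for name, param in root.named_parameters()}
    print([name_of[param] for param in exec_info.param_forward_order])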