1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
|
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
import torch
import torch.nn as nn
@dataclass
class TracingConfig:
    """
    Configuration for symbolically tracing a module with ``torch.fx``.

    Args:
        tracer (torch.fx.Tracer): Tracer instance that performs the symbolic
            trace. Defaults to a fresh :class:`torch.fx.Tracer` built with
            its default constructor arguments (a new instance per config).
            A different tracer can be supplied where needed, for example
            ``HFTracer`` for models from the HuggingFace Transformers_
            library.

            .. _Transformers: https://huggingface.co/docs/transformers/index

        concrete_args (Optional[Dict[str, Any]]): Arguments of the module's
            ``forward()`` that are fixed to concrete values instead of being
            traced as ``torch.fx.Proxy`` objects. Fixing them partially
            specializes ``forward()``, e.g. to eliminate control flow or data
            structures. This is the same ``concrete_args`` accepted by
            :meth:`~torch.fx.Tracer.trace`. Defaults to ``None``.
    """

    # One tracer per config object; a shared default instance would leak
    # state between configs.
    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    # ``None`` means no forward() argument is specialized.
    concrete_args: Optional[Dict[str, Any]] = field(default=None)
class _ParamUsageInfo(NamedTuple):
    """
    One group of parameters observed being used together during a trace.

    Instances populate ``_ExecutionInfo.module_to_param_usage_infos``. That
    ``dict`` maps each module to a list of these records, and each list is
    kept in execution order. For a given module key, a record describes one
    of two events:

    (1) Direct use: the module together with some of its parameters that
        were consumed in execution without calling a submodule's
        ``forward()`` (recorded by ``_patched_create_proxy()``).
    (2) Submodule call: a submodule together with all of
        ``submodule.named_parameters()``, recorded when that submodule's
        ``forward()`` is called (recorded by ``_patched_call_module()``).

    Note on names: records built from ``call_function``/``call_method``
    nodes use FQNs relative to the traced root module. Records built for
    module calls use names from ``module.named_parameters()``, which are
    relative to ``module`` itself.
    """

    # Module the parameter group is attributed to.
    module: nn.Module
    # ``(name, parameter)`` pairs that were used together.
    named_params: List[Tuple[str, nn.Parameter]]
class _ExecutionInfo:
    """
    Mutable record of what a symbolic trace of the forward pass executed.

    Attributes:
        curr_module (nn.Module): Module whose ``forward()`` is currently being
            traced. It starts out as the root module.
        module_forward_order (List[nn.Module]): Modules in (pre-)forward
            order, i.e. the order in which their ``forward()`` methods are
            called. Every call contributes one entry, and the root module is
            the initial entry.
        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
            For each module, the parameter-usage records collected for it, in
            execution order. See :class:`_ParamUsageInfo` for details. The
            root module starts with an empty list.
        param_forward_order (List[nn.Parameter]): Parameters in forward
            execution order. A parameter is listed only at its first use.
        visited_params (Set[nn.Parameter]): The parameters seen so far. This
            is only a fast membership index for use during tracing.
            Invariant: it holds exactly the elements of
            ``param_forward_order``.
    """

    def __init__(self, root_module: nn.Module) -> None:
        root = root_module
        # Tracing begins inside the root module. It is therefore both the
        # current module and the first forward call on record.
        self.curr_module: nn.Module = root
        self.module_forward_order: List[nn.Module] = [root]
        root_usages: List[_ParamUsageInfo] = []
        self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = dict(
            [(root, root_usages)]
        )
        # Nothing has been executed yet, so no parameters are recorded.
        self.param_forward_order: List[nn.Parameter] = list()
        self.visited_params: Set[nn.Parameter] = set()
class _ExecOrderTracer:
    """
    Records forward execution order while ``torch.fx`` symbolically traces
    a module.

    Wrap a trace in :meth:`patch_tracer`. Afterwards, ``exec_info`` holds
    the :class:`_ExecutionInfo` recorded for that trace.

    Attributes:
        exec_info (Optional[_ExecutionInfo]): Information recorded by the most
            recent :meth:`patch_tracer` context, or ``None`` before the first
            one.
    """

    def __init__(self) -> None:
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        """
        Context manager that patches ``tracer`` to record execution order.

        On entry, a fresh :class:`_ExecutionInfo` rooted at ``root_module`` is
        stored in ``self.exec_info``. ``tracer.call_module`` and
        ``tracer.create_proxy`` are then replaced by wrappers that record into
        it. On exit, whether normal or via an exception, the original methods
        are restored.

        Args:
            tracer (torch.fx.Tracer): Tracer to patch for the duration of the
                context.
            root_module (nn.Module): Module about to be traced. Its
                ``named_parameters()`` define the FQN-to-parameter mapping
                used to recognize parameter accesses.
        """
        self.exec_info = _ExecutionInfo(root_module)
        # Build the FQN map before touching the tracer. That way a failure
        # here cannot leave the tracer half-patched.
        fqn_to_param = dict(root_module.named_parameters())
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        tracer.call_module = functools.partial(  # type: ignore[method-assign]
            self._patched_call_module, orig_call_module, self.exec_info
        )
        tracer.create_proxy = functools.partial(  # type: ignore[method-assign]
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
            fqn_to_param,
        )
        try:
            yield
        finally:
            tracer.call_module = orig_call_module  # type: ignore[method-assign]
            tracer.create_proxy = orig_create_proxy  # type: ignore[method-assign]

    @staticmethod
    def _record_param_first_use(
        exec_info: _ExecutionInfo, param: nn.Parameter
    ) -> None:
        """Appends ``param`` to ``param_forward_order`` only the first time it is seen."""
        if param not in exec_info.visited_params:
            exec_info.visited_params.add(param)
            exec_info.param_forward_order.append(param)

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        ``module`` is made the current module while the original
        ``call_module`` runs. The previous current module is restored
        afterwards, even if tracing ``module`` raises.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        curr_module = exec_info.curr_module
        if named_params:
            assert (
                curr_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            # The caller records the whole submodule as one usage group.
            exec_info.module_to_param_usage_infos[curr_module].append(
                _ParamUsageInfo(module, named_params)
            )
        exec_info.curr_module = module
        # NOTE(review): this resets ``module``'s usage list on every call. A
        # module whose ``forward()`` runs more than once therefore keeps only
        # the usages from its last call. Confirm that consumers expect this.
        exec_info.module_to_param_usage_infos[module] = []
        try:
            return call_module(module, forward, args, kwargs)
        finally:
            # Restore the caller as the current module even on failure, so
            # ``exec_info`` is never left pointing at ``module``.
            exec_info.curr_module = curr_module

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        For ``call_function``/``call_method`` nodes, any top-level positional
        or keyword argument that is a proxy for a root-module parameter is
        recorded. Those records use the parameter's FQN relative to the root
        module. For ``call_module`` nodes, all of the current module's
        ``named_parameters()`` are recorded.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        if kind in ("call_function", "call_method"):
            # Parameters may be passed positionally or by keyword (e.g.
            # ``F.linear(x, weight=self.weight)``), so inspect both. Only
            # top-level arguments are examined; parameters nested in
            # containers (e.g. ``torch.cat([w1, w2])``) are not detected.
            candidates: List[Any] = list(args) if args is not None else []
            if kwargs:
                candidates.extend(kwargs.values())
            named_params: List[Tuple[str, nn.Parameter]] = []
            for arg in candidates:
                if not isinstance(arg, torch.fx.Proxy):
                    continue
                fqn = arg.node.target
                if fqn not in fqn_to_param:
                    continue
                param = fqn_to_param[fqn]  # type: ignore[index]
                named_params.append((fqn, param))  # type: ignore[arg-type]
                self._record_param_first_use(exec_info, param)
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
        elif kind == "call_module":
            # When reached from within ``_patched_call_module``, which sets
            # ``curr_module`` before delegating, ``curr_module`` is the
            # module being called.
            named_params = list(curr_module.named_parameters())
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
            for _, param in named_params:
                self._record_param_first_use(exec_info, param)
        return proxy
|