1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
|
import copy
import inspect
import warnings
import re
from typing import Optional, Callable, Tuple
from contextlib import contextmanager
from e3nn import get_optimization_defaults, set_optimization_defaults
import torch
from torch import nn
from torch import fx
from opt_einsum_fx import jitable
# Type alias: a callable that constructs an ``nn.Module``.
ModuleFactory = Callable[..., nn.Module]
# Type alias: a tuple of types, as accepted by ``isinstance``.
TypeTuple = Tuple[type, ...]
# Attribute name under which ``@compile_mode`` records a module's compile mode.
_E3NN_COMPILE_MODE = "__e3nn_compile_mode__"
# The set of legal compile-mode values (None means "not decorated").
_VALID_MODES = ("trace", "script", "unsupported", None)
# Name of the optional hook a module may define to supply its own tracing inputs.
_MAKE_TRACING_INPUTS = "_make_tracing_inputs"
def compile_mode(mode: str):
    """Class decorator that records how a module should be compiled to TorchScript.

    Parameters
    ----------
    mode : str
        'script', 'trace', 'unsupported', or None
    """
    if mode not in _VALID_MODES:
        raise ValueError("Invalid compile mode")

    def wrap(cls):
        # Only classes deriving from torch.nn.Module may carry a compile mode.
        is_module_class = inspect.isclass(cls) and issubclass(cls, torch.nn.Module)
        if not is_module_class:
            raise TypeError("@e3nn.util.jit.compile_mode can only decorate classes derived from torch.nn.Module")
        setattr(cls, _E3NN_COMPILE_MODE, mode)
        return cls

    return wrap
def get_compile_mode(mod: torch.nn.Module) -> str:
    """Look up the compilation mode recorded on a module.

    Parameters
    ----------
    mod : torch.nn.Module

    Returns
    -------
    'script', 'trace', or None if the module was not decorated with @compile_mode
    """
    if not hasattr(mod, _E3NN_COMPILE_MODE):
        # Neither instance nor class carries the marker; fall back to None.
        mode = getattr(type(mod), _E3NN_COMPILE_MODE, None)
    else:
        mode = getattr(mod, _E3NN_COMPILE_MODE)
    # Undecorated fx graph modules default to scripting.
    if mode is None and isinstance(mod, fx.GraphModule):
        mode = "script"
    assert mode in _VALID_MODES, "Invalid compile mode `%r`" % mode
    return mode
def compile(
    mod: torch.nn.Module,
    n_trace_checks: int = 1,
    script_options: dict = None,
    trace_options: dict = None,
    in_place: bool = True,
    recurse: bool = True,
):
    """Recursively compile a module and all submodules according to their decorators.

    (Sub)modules without decorators will be unaffected.

    Parameters
    ----------
    mod : torch.nn.Module
        The module to compile. The module will have its submodules compiled replaced in-place.
    n_trace_checks : int, default = 1
        How many random example inputs to generate when tracing a module. Must be at least one in order to have a tracing
        input. Extra example inputs will be passed to ``torch.jit.trace`` to confirm that the traced compute graph doesn't
        change.
    script_options : dict, default = {}
        Extra kwargs for ``torch.jit.script``.
    trace_options : dict, default = {}
        Extra kwargs for ``torch.jit.trace``.
    in_place : bool, default True
        Whether to insert the recursively compiled submodules in-place, or do a deepcopy first.
    recurse : bool, default True
        Whether to recurse through the module's children before passing the parent to TorchScript

    Returns
    -------
    Returns the compiled module.
    """
    # ``None`` defaults avoid the shared-mutable-default pitfall.
    script_options = script_options or {}
    trace_options = trace_options or {}
    mode = get_compile_mode(mod)
    if mode == "unsupported":
        raise NotImplementedError(f"{type(mod).__name__} does not support TorchScript compilation")
    if not in_place:
        mod = copy.deepcopy(mod)
    # TODO: debug logging
    assert n_trace_checks >= 1
    if recurse:
        # == recurse to children ==
        # This allows us to trace compile submodules of modules we are going to script
        for submod_name, submod in mod.named_children():
            setattr(
                mod,
                submod_name,
                compile(
                    submod,
                    n_trace_checks=n_trace_checks,
                    script_options=script_options,
                    trace_options=trace_options,
                    in_place=True,  # since we deepcopied the module above, we can do inplace
                    recurse=recurse,  # always true in this branch
                ),
            )
    # == Compile this module now ==
    if mode == "script":
        if isinstance(mod, fx.GraphModule):
            # Rewrite the fx graph so that it is TorchScript-compatible.
            mod = jitable(mod)
            # In recent PyTorch versions (probably >1.12, definitely >=2.0), PyTorch's implementation of fx.GraphModule
            # causes a warning to be raised when fx.GraphModules are compiled to TorchScript with `torch.jit.script`:
            #
            #     torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level
            #     annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the
            #     class body, or 2) wrap the type in `torch.jit.Attribute`.
            #
            # Using the debugger traces this back to the following line in PyTorch:
            # https://github.com/pytorch/pytorch/blob/v2.3.1/torch/fx/graph_module.py#L446
            # Because the metadata stored by GraphModule is not relevant to the compiled TorchScript module
            # we need, it should be safe to ignore this warning. The below code suppresses this warning as
            # narrowly as possible to ensure it can still be raised from user code.
            # See also: https://github.com/pytorch/pytorch/issues/89064
            # Note: In PyTorch 2.10.0+, this warning is raised from ast.py instead of torch/jit/_check.py,
            # so we don't filter by module to catch both cases.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    # warnings treats this argument as a regex, but we want to match a string literal exactly, so escape it:
                    message=re.escape(
                        "The TorchScript type system doesn't support instance-level annotations on empty non-base types "
                        "in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type "
                        "in `torch.jit.Attribute`."
                    ),
                    category=UserWarning,
                    # don't filter by module since in PyTorch 2.10.0+ the warning comes from ast.py instead of torch
                )
                mod = torch.jit.script(mod, **script_options)
        else:
            mod = torch.jit.script(mod, **script_options)
    elif mode == "trace":
        # These are always modules, so we're always using trace_module
        # We need tracing inputs:
        check_inputs = get_tracing_inputs(
            mod,
            n_trace_checks,
        )
        assert len(check_inputs) >= 1, "Must have at least one tracing input."
        # Do the actual trace
        mod = torch.jit.trace_module(mod, inputs=check_inputs[0], check_inputs=check_inputs, **trace_options)
    return mod
def get_tracing_inputs(
    mod: torch.nn.Module, n: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
):
    """Build random tracing inputs for ``mod``.

    If ``mod`` defines a ``_make_tracing_inputs`` method, it is called with ``n`` and its
    (validated) result is used. Otherwise the input irreps are inferred via
    ``e3nn.util._argtools._get_io_irreps`` and random arguments are generated.

    Parameters
    ----------
    mod : torch.nn.Module
    n : int, default = 1
        A hint for how many inputs are wanted; modules may return a different number.
    device : torch.device
        The device to do tracing on. If `None` (default), will be guessed.
    dtype : torch.dtype
        The dtype to trace with. If `None` (default), will be guessed.

    Returns
    -------
    list of dict
        Tracing inputs in ``torch.jit.trace_module`` format: dicts mapping method names
        like ``'forward'`` to tuples of arguments.
    """
    # Avoid circular imports
    from ._argtools import _get_device, _get_floating_dtype, _get_io_irreps, _rand_args, _to_device_dtype

    if not hasattr(mod, _MAKE_TRACING_INPUTS):
        # No hook provided: infer the input irreps and generate random arguments.
        irreps_in, _ = _get_io_irreps(mod, irreps_out=[None])  # we're only trying to infer inputs
        inputs = [{"forward": _rand_args(irreps_in)} for _ in range(n)]
    else:
        # The hook returns a trace_module-style list of dict[method name -> args tuple].
        inputs = mod._make_tracing_inputs(n)
        assert isinstance(inputs, list)
        for entry in inputs:
            assert isinstance(entry, dict), "_make_tracing_inputs must return a list of dict[str, tuple]"
            assert all(
                isinstance(k, str) and isinstance(v, tuple) for k, v in entry.items()
            ), "_make_tracing_inputs must return a list of dict[str, tuple]"
    # Resolve the target device/dtype from the module when not explicitly given.
    if device is None:
        device = _get_device(mod)
    if dtype is None:
        dtype = _get_floating_dtype(mod)
    # Move all tensors in the inputs onto the chosen device/dtype.
    return _to_device_dtype(inputs, device, dtype)
def trace_module(mod: torch.nn.Module, inputs: dict = None, check_inputs: list = None, in_place: bool = True):
    """Trace a module.

    Identical signature to ``torch.jit.trace_module``, but first recursively compiles ``mod`` using ``compile``.

    Parameters
    ----------
    mod : torch.nn.Module
    inputs : dict
        Inputs in ``torch.jit.trace_module`` format (method name -> args tuple).
    check_inputs : list of dict
        Extra inputs used to check the consistency of the trace.
    in_place : bool, default True
        Whether ``compile`` may modify ``mod`` in place.

    Returns
    -------
    Traced module.
    """
    check_inputs = check_inputs or []

    # Sentinel distinguishing "attribute absent" from "attribute set to None",
    # so the temporary state set below can be undone completely afterwards.
    _absent = object()

    # Set the compile mode for mod, temporarily
    old_mode = getattr(mod, _E3NN_COMPILE_MODE, _absent)
    if old_mode is not _absent and old_mode is not None and old_mode != "trace":
        warnings.warn(
            f"Trying to trace a module of type {type(mod).__name__} marked with @compile_mode != 'trace', expect errors!"
        )
    setattr(mod, _E3NN_COMPILE_MODE, "trace")

    # If inputs are provided, set make_tracing_input temporarily
    old_make_tracing_input = _absent
    if inputs is not None:
        old_make_tracing_input = getattr(mod, _MAKE_TRACING_INPUTS, _absent)
        setattr(mod, _MAKE_TRACING_INPUTS, lambda num: ([inputs] + check_inputs))

    try:
        # Compile
        out = compile(mod, in_place=in_place)
    finally:
        # Restore old values. Previously, attributes that had no prior value were
        # leaked onto mod, and an exception during compile skipped restoration
        # entirely; the try/finally + delattr fixes both.
        if old_mode is _absent:
            try:
                delattr(mod, _E3NN_COMPILE_MODE)
            except AttributeError:
                pass  # attribute may live on the class rather than the instance
        else:
            setattr(mod, _E3NN_COMPILE_MODE, old_mode)
        if inputs is not None:
            if old_make_tracing_input is _absent:
                try:
                    delattr(mod, _MAKE_TRACING_INPUTS)
                except AttributeError:
                    pass
            else:
                setattr(mod, _MAKE_TRACING_INPUTS, old_make_tracing_input)
    return out
def trace(mod: torch.nn.Module, example_inputs: tuple = None, check_inputs: list = None, in_place: bool = True):
    """Trace a module.

    Identical signature to ``torch.jit.trace``, but first recursively compiles ``mod`` using :func:``compile``.

    Parameters
    ----------
    mod : torch.nn.Module
    example_inputs : tuple
    check_inputs : list of tuple

    Returns
    -------
    Traced module.
    """
    # Wrap the plain tuples into trace_module's {method name: args} format.
    forward_inputs = None
    if example_inputs is not None:
        forward_inputs = {"forward": example_inputs}
    wrapped_checks = [{"forward": args} for args in (check_inputs or [])]
    return trace_module(mod=mod, inputs=forward_inputs, check_inputs=wrapped_checks, in_place=in_place)
def script(mod: torch.nn.Module, in_place: bool = True):
    """Script a module.

    Like ``torch.jit.script``, but first recursively compiles ``mod`` using :func:``compile``.

    Parameters
    ----------
    mod : torch.nn.Module
    in_place : bool, default True
        Whether ``compile`` may modify ``mod`` in place.

    Returns
    -------
    Scripted module.
    """
    # Sentinel distinguishing "attribute absent" from "attribute set to None".
    _absent = object()

    # Set the compile mode for mod, temporarily
    old_mode = getattr(mod, _E3NN_COMPILE_MODE, _absent)
    if old_mode is not _absent and old_mode is not None and old_mode != "script":
        warnings.warn(
            f"Trying to script a module of type {type(mod).__name__} marked with @compile_mode != 'script', expect errors!"
        )
    setattr(mod, _E3NN_COMPILE_MODE, "script")

    try:
        # Compile
        out = compile(mod, in_place=in_place)
    finally:
        # Restore the old value. Previously, when there was no prior value the
        # temporary attribute was leaked onto mod, and an exception during
        # compile skipped restoration; the try/finally + delattr fixes both.
        if old_mode is _absent:
            try:
                delattr(mod, _E3NN_COMPILE_MODE)
            except AttributeError:
                pass  # attribute may live on the class rather than the instance
        else:
            setattr(mod, _E3NN_COMPILE_MODE, old_mode)
    return out
@contextmanager
def disable_e3nn_codegen():
    """Context manager that disables the legacy PyTorch code generation used in e3nn.

    Saves the current ``jit_script_fx`` optimization default, disables it for the
    duration of the ``with`` block, and restores the saved value on exit.
    """
    init_val = get_optimization_defaults()["jit_script_fx"]
    set_optimization_defaults(jit_script_fx=False)
    try:
        yield
    finally:
        # Restore the caller's setting even when the body raises; without the
        # try/finally an exception inside the block left codegen disabled.
        set_optimization_defaults(jit_script_fx=init_val)
|