1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
|
# mypy: allow-untyped-defs
import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.distributed.utils import _get_root_modules
def generate_state_key(string="__composable_api_state_key"):
    """Return ``string`` followed by an underscore and a freshly generated UUID4.

    The random suffix makes every generated key distinct, so two calls with
    the same prefix never produce the same key.
    """
    suffix = uuid.uuid4()
    return f"{string}_{suffix}"
# Keys generated once at import time. ``contract`` stores its bookkeeping in
# ``module.__dict__`` under them: STATE_KEY maps to the per-API state objects,
# REGISTRY_KEY maps to the registry of APIs applied to the module.
STATE_KEY, REGISTRY_KEY = generate_state_key(), generate_state_key()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    """Marker object recorded in a module's registry for each applied composable API.

    It currently carries no data of its own.
    """
def _check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str) -> None:
    """Raise ``RuntimeError`` unless ``new_fqns`` matches ``orig_fqns`` exactly.

    Args:
        orig_fqns: FQNs captured before the decorated function ran.
        new_fqns: FQNs captured after the decorated function ran.
        check_key: Prefix for the error message naming what is being checked.

    Raises:
        RuntimeError: If names were added/removed, or if the same names appear
            in a different order.
    """
    if orig_fqns == new_fqns:
        return

    orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
    orig_only = orig_fqn_set - new_fqn_set
    new_only = new_fqn_set - orig_fqn_set
    if orig_only or new_only:
        raise RuntimeError(
            f"{check_key}"
            "Composable distributed API implementations cannot modify FQNs.\n"
            f"FQNs only in original: {orig_only}\n"
            f"FQNs only in new: {new_only}"
        )

    # Same set of names but a different sequence: report the full ordered
    # lists, since the set differences are necessarily empty in this branch.
    raise RuntimeError(
        f"{check_key}"
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance or sequence
    of :class:`nn.Module` instances.

    The decorator verifies that the decorated function does not modify
    fully-qualified names (FQNs) for parameters, buffers, or modules. The
    decorated function can return different module instances than the input
    modules; the FQN invariant will be enforced following the input order.

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: Class instantiated (with no arguments) once per call of the
            decorated function; the instance is shared by all root modules
            passed in that call.

    Returns:
        A decorator. The decorated function returns whatever ``func`` returns,
        or the input module(s) if ``func`` returns ``None``.

    Raises (from the decorated function):
        AssertionError: If the per-module state or registry is not a ``dict``,
            if the same API was already applied to a module, or if the number
            of returned root modules differs from the number of input ones.
        RuntimeError: If parameter, buffer, or module FQNs (or their order)
            differ before and after calling ``func``.

    Example::

        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(
            module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs
        ) -> Optional[nn.Module]:
            inp_module = module
            if isinstance(module, nn.Module):
                modules = [module]
            else:
                # If the user passes a sequence of modules, then we assume that
                # we only need to insert the state object on the root modules
                # (i.e. those without a parent) among the passed-in modules.
                modules = _get_root_modules(list(module))

            state = state_cls()  # shared across all modules
            registry_item = RegistryItem()  # shared across all modules

            # `func` is allowed to return different module instances than the
            # input modules as long as FQNs are preserved following the input
            # module order
            all_orig_named_params: List[Dict[str, nn.Parameter]] = []
            all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_orig_named_modules: List[Dict[str, nn.Module]] = []

            for mod in modules:
                # The bookkeeping dicts live directly in the instance
                # ``__dict__`` under the uuid-suffixed keys.
                all_state: Dict[Callable, _State] = mod.__dict__.setdefault(  # type: ignore[call-overload]
                    STATE_KEY, OrderedDict()
                )
                if not isinstance(all_state, dict):
                    raise AssertionError(
                        f"Distributed composable API states corrupted: {all_state}"
                    )
                registry: Dict[str, RegistryItem] = mod.__dict__.setdefault(  # type: ignore[call-overload]
                    REGISTRY_KEY, OrderedDict()
                )
                if not isinstance(registry, dict):
                    raise AssertionError(
                        f"Distributed composable API registry corrupted: {registry}"
                    )
                if func in all_state or func.__name__ in registry:
                    raise AssertionError(
                        "Each distinct composable distributed API can only be applied to a "
                        f"module once. {func.__name__} has already been applied to the "
                        f"following module:\n{mod}"
                    )
                all_state.setdefault(func, state)
                registry.setdefault(func.__name__, registry_item)

                # Snapshot the FQNs before `func` runs.
                all_orig_named_params.append(OrderedDict(mod.named_parameters()))
                all_orig_named_buffers.append(OrderedDict(mod.named_buffers()))
                all_orig_named_modules.append(OrderedDict(mod.named_modules()))

            updated = func(inp_module, *args, **kwargs)
            if updated is None:
                updated = inp_module

            if isinstance(updated, nn.Module):
                updated_modules = [updated]
            else:
                # Inspect what `func` actually returned (not the input), so
                # that replaced module instances are the ones being validated.
                updated_modules = _get_root_modules(list(updated))  # type: ignore[arg-type]

            all_new_named_params: List[Dict[str, nn.Parameter]] = []
            all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_new_named_modules: List[Dict[str, nn.Module]] = []
            for mod in updated_modules:
                all_new_named_params.append(OrderedDict(mod.named_parameters()))
                all_new_named_buffers.append(OrderedDict(mod.named_buffers()))
                all_new_named_modules.append(OrderedDict(mod.named_modules()))

            num_orig_modules = len(all_orig_named_modules)
            num_new_modules = len(all_new_named_modules)
            if num_orig_modules != num_new_modules:
                raise AssertionError(
                    f"{func.__name__} should return the same number of modules as input modules\n"
                    f"Inputs: {num_orig_modules} modules\n"
                    f"Outputs: {num_new_modules} modules"
                )

            # Check order is preserved: all parameters first, then all
            # buffers, then all modules, each pairing inputs with outputs
            # positionally.
            fqn_checks = (
                ("Checking parameters: ", all_orig_named_params, all_new_named_params),
                ("Checking buffers: ", all_orig_named_buffers, all_new_named_buffers),
                ("Checking modules: ", all_orig_named_modules, all_new_named_modules),
            )
            for check_key, all_orig_named, all_new_named in fqn_checks:
                for orig_named, new_named in zip(all_orig_named, all_new_named):
                    _check_fqn(list(orig_named), list(new_named), check_key)

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state ``func`` installed on ``module``, or ``None``.

            Note that this installs an empty state dict on ``module`` if none
            exists yet.
            """
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Look up the registry of composable APIs applied to ``module``.

    The registry is an ``OrderedDict`` keyed by API name. ``None`` is returned
    when ``module`` has no registry attribute, i.e. no API has been applied.
    """
    try:
        return getattr(module, REGISTRY_KEY)
    except AttributeError:
        return None
|