# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from torch._functorch.utils import exposed_in


@exposed_in("torch.func")
def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Optional[Union[Any, Tuple]] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
.. note:: If the module has active parametrizations, passing a value in the
:attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
name will completely disable the parametrization.
If you want to apply the parametrization function to the value passed
please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
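
    For example, with a ``Linear`` module carrying a hypothetical ``Symmetric``
    parametrization on its ``weight`` (a minimal sketch; the parametrization and
    shapes below are illustrative assumptions, not part of this API):

    .. code-block:: python

        import torch
        import torch.nn.utils.parametrize as parametrize

        class Symmetric(torch.nn.Module):
            def forward(self, X):
                # illustrative parametrization: build a symmetric matrix
                return X.triu() + X.triu(1).transpose(-1, -2)

        lin = torch.nn.Linear(3, 3)
        parametrize.register_parametrization(lin, "weight", Symmetric())

        x = torch.randn(2, 3)
        # using the plain name bypasses the parametrization entirely
        functional_call(lin, {"weight": torch.eye(3)}, x)
        # the ".original" key is passed through Symmetric before use
        functional_call(lin, {"parametrizations.weight.original": torch.randn(3, 3)}, x)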

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameter_and_buffer_dicts`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    An example of passing multiple dictionaries:

    .. code-block:: python

        # two separate dictionaries
        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})
        mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
        print(mod.weight)  # tensor(...)
        print(mod.buffer)  # tensor(...)
        x = torch.randn((1, 1))
        print(x)
        functional_call(mod, a, x)  # same as x
        print(mod.weight)  # same as before functional_call

    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)

        def compute_loss(params, x, t):
            y = functional_call(model, params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
        parameters for better performance and memory usage.

        Example::

            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
            >>> grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad

        This means that the user cannot call ``grad_weights['weight'].backward()``. However, if they don't need
        autograd tracking outside of the transforms, this will result in less memory usage and faster speeds.

    Args:
        module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
            be used together.
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.
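
    A short sketch of ``strict`` (assuming the ``model = nn.Linear(3, 3)`` and ``x``
    from the grad example above, whose only state keys are ``'weight'`` and ``'bias'``):

    .. code-block:: python

        partial = {'weight': torch.ones(3, 3)}
        functional_call(model, partial, x)  # ok: missing 'bias' falls back to the module's own value
        functional_call(model, partial, x, strict=True)  # errors: 'bias' is missing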
"""
if isinstance(parameter_and_buffer_dicts, dict):
parameters_and_buffers = parameter_and_buffer_dicts
elif isinstance(parameter_and_buffer_dicts, Sequence):
if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
raise ValueError(
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
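        # Count every key across all dictionaries so duplicates, which would make
        # the merged mapping ambiguous, can be reported by name below.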
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
all_keys_counter: Dict[str, int] = {}
for k in all_keys:
v = all_keys_counter.get(k, 0)
all_keys_counter[k] = v + 1
repeated_keys = [key for key, n in all_keys_counter.items() if n > 1]
if len(repeated_keys) > 0:
raise ValueError(
f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
)
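        # Merge the now known-disjoint dictionaries into one flat mapping.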
parameters_and_buffers = {
k: v for d in parameter_and_buffer_dicts for k, v in d.items()
}
else:
raise ValueError(
f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
f"but got {type(parameter_and_buffer_dicts)}"
)
return nn.utils.stateless._functional_call(
module,
parameters_and_buffers,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
    )


@exposed_in("torch.func")
def stack_module_state(
models: Union[Sequence[nn.Module], nn.ModuleList],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""stack_module_state(models) -> params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
that stack all of their parameters and buffers together, indexed by name.
The stacked parameters are optimizable (i.e. they are new leaf nodes in the
autograd history that are unrelated to the original parameters and can be
passed directly to an optimizer).
Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
        data = torch.randn(batch_size, 3)

        def wrapper(params, buffers, data):
            return torch.func.functional_call(models[0], (params, buffers), data)

        params, buffers = stack_module_state(models)
        output = torch.vmap(wrapper, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)
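
    Because the stacked parameters are fresh autograd leaves, they can be handed
    straight to an optimizer. A minimal sketch continuing the example above (the
    loss and learning rate are illustrative assumptions):

    .. code-block:: python

        opt = torch.optim.SGD(params.values(), lr=1e-2)
        loss = output.sum()  # stand-in for a real loss
        loss.backward()      # grads flow to the stacked params, not the originals
        opt.step()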

    When there are submodules, this follows the state dict naming conventions:

    .. code-block:: python

        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                hidden = 4
                self.l1 = nn.Linear(in_features, hidden)
                self.l2 = nn.Linear(hidden, out_features)

            def forward(self, x):
                return self.l2(self.l1(x))

        num_models = 5
        in_features, out_features = 3, 3
        models = [Foo(in_features, out_features) for _ in range(num_models)]
        params, buffers = stack_module_state(models)
        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).
"""
if len(models) == 0:
raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError(
"stack_module_state: Expected all models to have the same training/eval mode."
)
model0_typ = type(models[0])
    if not all(type(m) is model0_typ for m in models):
raise RuntimeError(
"stack_module_state: Expected all models to be of the same class."
)
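    # Collect each named parameter across models and stack the copies along a new
    # leading dim; construct_stacked_leaf re-leafs them so they stay optimizable.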
all_params = [dict(model.named_parameters()) for model in models]
params = {
k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
for k in all_params[0]
}
all_buffers = [dict(model.named_buffers()) for model in models]
buffers = {
k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
for k in all_buffers[0]
}
    return params, buffers


def construct_stacked_leaf(
tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
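    """Stacks the per-model copies of one parameter or buffer into a single tensor.

    All copies must agree on ``requires_grad``. When they require grad, the stacked
    result is detached and re-marked as a new autograd leaf so it can be optimized
    independently of the original models.
    """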
all_requires_grad = all(t.requires_grad for t in tensors)
none_requires_grad = all(not t.requires_grad for t in tensors)
if not all_requires_grad and not none_requires_grad:
raise RuntimeError(
f"Expected {name} from each model to have the same .requires_grad"
)
result = torch.stack(tensors)
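    # detach() cuts the stacked tensor loose from the originals' autograd graphs;
    # requires_grad_() then makes it a fresh leaf that can accumulate .grad.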
if all_requires_grad:
result = result.detach().requires_grad_()
return result