# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""

from typing import Any, TypeVar

from typing_extensions import deprecated

from torch import _weight_norm, norm_except_dim
from torch.nn.modules import Module
from torch.nn.parameter import Parameter, UninitializedParameter


__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"]


class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        # w = g * v / ||v||, where the norm is taken over every dim except self.dim
        g = getattr(module, self.name + "_g")
        v = getattr(module, self.name + "_v")
        return _weight_norm(v, g, self.dim)

    @staticmethod
    @deprecated(
        "`torch.nn.utils.weight_norm` is deprecated "
        "in favor of `torch.nn.utils.parametrizations.weight_norm`.",
        category=FutureWarning,
    )
    def apply(module, name: str, dim: int) -> "WeightNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two weight_norm hooks on the same parameter {name}"
                )

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                "The module passed to `WeightNorm` can't have uninitialized parameters. "
                "Make sure to run the dummy forward before applying weight normalization"
            )
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(
            name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
        )
        module.register_parameter(name + "_v", Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        # materialize the current weight, drop g and v, and re-register the
        # result as a plain parameter under the original name
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + "_g"]
        del module._parameters[self.name + "_v"]
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        # invoked as a forward pre-hook: refresh the weight from g and v
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar("T_module", bound=Module)


def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the
    magnitude (e.g. ``'weight_g'``) and one specifying the direction
    (e.g. ``'weight_v'``). Weight normalization is implemented via a hook that
    recomputes the weight tensor from the magnitude and direction before every
    :meth:`~Module.forward` call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API. The new ``weight_norm`` is compatible
        with ``state_dict`` generated from old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively. If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.
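
        A rough migration sketch, assuming the ``parametrizations`` API and the
        ``original0``/``original1`` names described above (shapes shown assume
        ``nn.Linear(20, 40)`` with the default ``dim=0``):

        >>> m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40))
        >>> m.parametrizations.weight.original0.size()  # magnitude, formerly weight_g
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()  # direction, formerly weight_v
        torch.Size([40, 20])
        >>> m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
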

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])
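        >>> # Sketch: with dim=None (see above) a single norm is kept for the
        >>> # whole tensor, so the magnitude collapses to a scalar.
        >>> m = weight_norm(nn.Linear(20, 40), name='weight', dim=None)
        >>> m.weight_g.size()
        torch.Size([])
        >>> m.weight_v.size()
        torch.Size([40, 20])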
"""
    WeightNorm.apply(module, name, dim)
    return module


def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
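        >>> # Sketch of what removal restores: the reparameterized g/v pair is
        >>> # dropped and a plain weight Parameter is registered again.
        >>> hasattr(m, 'weight_g')
        False
        >>> m.weight.size()
        torch.Size([40, 20])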
"""
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(f"weight_norm of '{name}' not found in {module}")