1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
|
from .batchnorm import _NormBase
from .. import functional as F
from torch import Tensor
class _InstanceNorm(_NormBase):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = False,
track_running_stats: bool = False
) -> None:
super(_InstanceNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
def _check_input_dim(self, input):
raise NotImplementedError
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
# at version 1: removed running_mean and running_var when
# track_running_stats=False (default)
if version is None and not self.track_running_stats:
running_stats_keys = []
for name in ('running_mean', 'running_var'):
key = prefix + name
if key in state_dict:
running_stats_keys.append(key)
if len(running_stats_keys) > 0:
error_msgs.append(
'Unexpected running stats buffer(s) {names} for {klass} '
'with track_running_stats=False. If state_dict is a '
'checkpoint saved before 0.4.0, this may be expected '
'because {klass} does not track running stats by default '
'since 0.4.0. Please remove these keys from state_dict. If '
'the running stats are actually needed, instead set '
'track_running_stats=True in {klass} to enable them. See '
'the documentation of {klass} for details.'
.format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
klass=self.__class__.__name__))
for key in running_stats_keys:
state_dict.pop(key)
super(_InstanceNorm, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
return F.instance_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats, self.momentum, self.eps)
class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
    inputs with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied
        on each channel of channeled data like multidimensional time series, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm1d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)`. 2D inputs of size :math:`(N, L)` are rejected
            with a ``ValueError``.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics computed from the input in both training and eval modes.
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)`
        - Output: :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Require a 3D ``(N, C, L)`` input.

        Raises:
            ValueError: if ``input`` is 2D (given a dedicated explanation) or
                has any other rank than 3.
        """
        if input.dim() == 2:
            # Implicitly concatenated literals: each piece must carry its own
            # separating space or the words run together in the message.
            raise ValueError(
                'InstanceNorm1d returns 0-filled tensor to 2D tensor. '
                'This is because InstanceNorm1d reshapes inputs to '
                '(1, N * C, ...) from (N, C,...) and this makes '
                'variances 0.'
            )
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))
class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied
        on each channel of channeled data like RGB images, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm2d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics computed from the input in both training and eval modes.
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Accept only 4D ``(N, C, H, W)`` input; raise ``ValueError`` otherwise."""
        rank = input.dim()
        if rank == 4:
            return
        raise ValueError('expected 4D input (got {}D input)'.format(rank))
class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied
        on each channel of channeled data like 3D models with RGB color, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm3d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics computed from the input in both training and eval modes.
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Accept only 5D ``(N, C, D, H, W)`` input; raise ``ValueError`` otherwise."""
        rank = input.dim()
        if rank == 5:
            return
        raise ValueError('expected 5D input (got {}D input)'.format(rank))
|