1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
|
# mypy: allow-untyped-defs
import torch
# Public names exported by a star-import of this module.
__all__ = ["Dropout"]
class Dropout(torch.nn.Dropout):
    r"""Quantized counterpart of :class:`~torch.nn.Dropout`.

    This module is a placeholder: its ``forward`` returns the input
    unchanged, regardless of whether the module is in train or eval mode.
    It exists so that models whose fp32 version contained a dropout layer
    keep working once their tensors are quantized.

    Args:
        p: probability of an element to be zeroed
        inplace: can optionally do the operation in-place. Default: ``False``
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` as-is; no elements are dropped."""
        return input

    def _get_name(self) -> str:
        """Return the display name used for this module."""
        return "QuantizedDropout"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build an instance from a module exposing ``p`` and ``inplace``.

        Args:
            mod: source module; only ``mod.p`` and ``mod.inplace`` are read.
            use_precomputed_fake_quant: accepted for signature compatibility
                but not used by this implementation.

        Returns:
            A new ``cls`` instance configured with the same ``p`` and
            ``inplace`` as ``mod``.
        """
        return cls(mod.p, mod.inplace)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build an instance from a reference module.

        Args:
            mod: source module; only ``mod.p`` and ``mod.inplace`` are read.
            scale: accepted for signature compatibility but not used here.
            zero_point: accepted for signature compatibility but not used
                here.

        Returns:
            A new ``cls`` instance configured with the same ``p`` and
            ``inplace`` as ``mod``.
        """
        return cls(mod.p, mod.inplace)
|