import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['Embedding', 'EmbeddingBag']
class Embedding(nn.Embedding):
r"""
    An embedding module with a FakeQuantize module attached to the weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Embedding`; please see
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
for documentation.
Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
default.
Attributes:
        weight_fake_quant: fake quant module for the weight
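
    Example (a minimal sketch; assumes ``default_embedding_qat_qconfig`` from
    ``torch.ao.quantization`` is available, which configures a per-channel
    float-qparams fake quant for the weight)::

        >>> import torch
        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> qat_emb = Embedding(num_embeddings=10, embedding_dim=4,
        ...                     qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([[1, 2, 4], [4, 3, 9]])
        >>> out = qat_emb(indices)  # weight is fake-quantized before the lookup
        >>> out.shape
        torch.Size([2, 3, 4])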
"""
_FLOAT_MODULE = nn.Embedding
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
sparse=False, _weight=None, device=None, dtype=None, qconfig=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
**factory_kwargs)
assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
self.qconfig = qconfig
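        # Build the weight fake-quant module from the qconfig's weight factory.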
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input) -> Tensor:
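        # Fake-quantize the weight, then perform the usual embedding lookup.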
return F.embedding(input, self.weight_fake_quant(self.weight), self.padding_idx,
self.max_norm, self.norm_type, self.scale_grad_by_freq,
self.sparse)
@classmethod
def from_float(cls, mod):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
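
        Example (a minimal sketch; assumes ``default_embedding_qat_qconfig`` from
        ``torch.ao.quantization`` is available)::

            >>> import torch
            >>> float_emb = torch.nn.Embedding(10, 4)
            >>> float_emb.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
            >>> qat_emb = Embedding.from_float(float_emb)  # reuses the float weight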
"""
assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)
qconfig = mod.qconfig
        qat_embedding = cls(mod.num_embeddings, mod.embedding_dim, mod.padding_idx,
                            mod.max_norm, mod.norm_type, mod.scale_grad_by_freq,
                            mod.sparse, mod.weight, qconfig=qconfig)
        return qat_embedding
def to_float(self):
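        # Recreate a plain nn.Embedding and hand it the trained (still float) weight.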
        embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq,
                                       self.sparse, None)
        embedding.weight = torch.nn.Parameter(self.weight.detach())
        embedding.train(self.training)
        return embedding
class EmbeddingBag(nn.EmbeddingBag):
r"""
    An embedding bag module with a FakeQuantize module attached to the weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
for documentation.
Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
default.
Attributes:
        weight_fake_quant: fake quant module for the weight
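
    Example (a minimal sketch; assumes ``default_embedding_qat_qconfig`` from
    ``torch.ao.quantization`` is available)::

        >>> import torch
        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> qat_bag = EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum',
        ...                        qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])  # two bags: indices[0:4] and indices[4:8]
        >>> out = qat_bag(indices, offsets)  # weight is fake-quantized before the lookup
        >>> out.shape
        torch.Size([2, 4])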
"""
_FLOAT_MODULE = nn.EmbeddingBag
def __init__(self, num_embeddings, embedding_dim, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, mode='mean',
sparse=False, _weight=None, include_last_offset=False,
padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
scale_grad_by_freq, mode, sparse, _weight,
include_last_offset, padding_idx, **factory_kwargs)
assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
self.qconfig = qconfig
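        # Build the weight fake-quant module from the qconfig's weight factory.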
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
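        # Fake-quantize the weight, then perform the usual bagged embedding lookup.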
return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset,
self.padding_idx)
@classmethod
def from_float(cls, mod):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
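
        Example (a minimal sketch; assumes ``default_embedding_qat_qconfig`` from
        ``torch.ao.quantization`` is available)::

            >>> import torch
            >>> float_bag = torch.nn.EmbeddingBag(10, 4, mode='sum')
            >>> float_bag.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
            >>> qat_bag = EmbeddingBag.from_float(float_bag)  # reuses the float weight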
"""
assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)
qconfig = mod.qconfig
qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
return qat_embedding_bag
def to_float(self):
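        # Recreate a plain nn.EmbeddingBag and hand it the trained (still float) weight.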
embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
None, self.include_last_offset, self.padding_idx)
embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
embedding_bag.train(self.training)
return embedding_bag