from typing import Any, Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .utils import ReferenceQuantizedModule

__all__ = ['Embedding', 'EmbeddingBag']


class Embedding(nn.Embedding, ReferenceQuantizedModule):
""" A reference quantized Embedding module that fits into the
FX Graph Mode Quantization workflow, activation will be floating point Tensor,
we will store floating point weight as well in the module, but in forward we'll
quantize and dequantize the weight before running the floating point functional
embedding operator.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None,
device=None, dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)
self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbedding(Reference)"

    def forward(self, input: Tensor) -> Tensor:
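        # get_weight() quantizes and immediately dequantizes the fp32 weight
        # according to self.weight_qparams, so F.embedding below runs on an
        # fp32 tensor that already carries the quantization error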
weight_quant_dequant = self.get_weight()
return F.embedding(
input, weight_quant_dequant, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)

    @classmethod
def from_float(cls, mod, weight_qparams):
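        # device and dtype are taken from the float module's weight so the
        # reference module is created to match it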
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.padding_idx,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.sparse,
mod.weight,
mod.weight.device,
mod.weight.dtype,
weight_qparams)
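

# A minimal usage sketch (illustration only, not part of the reference API).
# The weight_qparams keys below ("qscheme", "dtype", "scale", "zero_point")
# are assumed to match what ReferenceQuantizedModule._init_weight_qparams
# reads; verify against .utils before relying on them.
def _example_embedding_usage() -> Tensor:
    import torch

    weight_qparams = {
        "qscheme": torch.per_tensor_affine,  # assumed-supported qscheme
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    emb = Embedding(num_embeddings=10, embedding_dim=4,
                    weight_qparams=weight_qparams)
    indices = torch.tensor([[1, 2, 3], [4, 5, 6]])
    # forward quantize-dequantizes the fp32 weight, then calls F.embedding
    return emb(indices)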


class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
""" A reference quantized EmbeddingBag module that fits into the
FX Graph Mode Quantization workflow, activation will be floating point Tensor,
we will store floating point weight as well in the module, but in forward we'll
quantize and dequantize the weight before running the floating point functional
embedding operator.
"""
def __init__(self, num_embeddings: int, embedding_dim: int,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
include_last_offset: bool = False, padding_idx: Optional[int] = None,
device=None, dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None) -> None:
super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
scale_grad_by_freq, mode, sparse, _weight, include_last_offset,
padding_idx, device, dtype)
self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbeddingBag(Reference)"

    def forward(self, input: Tensor, offsets: Optional[Tensor] = None,
                per_sample_weights: Optional[Tensor] = None) -> Tensor:
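        # as in Embedding, get_weight() returns the quant-dequantized fp32
        # weight so F.embedding_bag runs with quantization error applied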
weight_quant_dequant = self.get_weight()
return F.embedding_bag(input, weight_quant_dequant, offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset,
self.padding_idx)

    @classmethod
def from_float(cls, mod, weight_qparams):
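        # as in Embedding.from_float, device and dtype come from mod.weight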
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.mode,
mod.sparse,
mod.weight,
mod.include_last_offset,
mod.padding_idx,
mod.weight.device,
mod.weight.dtype,
weight_qparams
)
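

# A minimal EmbeddingBag sketch mirroring the Embedding example above; the
# weight_qparams keys are assumptions, as noted there.
def _example_embedding_bag_usage() -> Tensor:
    import torch

    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    bag = EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum',
                       weight_qparams=weight_qparams)
    # two bags: indices[0:2] and indices[2:4]
    indices = torch.tensor([1, 2, 4, 5])
    offsets = torch.tensor([0, 2])
    return bag(indices, offsets)


if __name__ == "__main__":
    print(_example_embedding_usage())
    print(_example_embedding_bag_usage())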