File: embedding_ops.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (295 lines) | stat: -rw-r--r-- 13,677 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Optional, List  # noqa: F401

from .utils import hide_packed_params_repr
from .utils import _quantize_weight

__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']

class EmbeddingPackedParams(torch.nn.Module):
    r"""Container for the packed quantized weight used by quantized embedding modules.

    The weight is kept in the opaque form returned by
    ``torch.ops.quantized.embedding_bag_prepack`` (attribute ``_packed_weight``).
    ``_weight()`` unpacks it back into a quantized tensor.

    Args:
        num_embeddings: number of rows of the placeholder weight built at construction.
        embedding_dim: number of columns of the placeholder weight.
        dtype: quantized dtype of the weight. Only ``torch.quint8`` and
            ``torch.quint4x2`` are accepted; any other value raises
            ``NotImplementedError``.
    """
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super(EmbeddingPackedParams, self).__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            # Placeholder weight: per-row (axis=0) affine quantization with
            # scale 1.0 and float zero point 0.0 for every row. It comes from an
            # ``_empty_*`` constructor, so its values are presumably uninitialized
            # until real data is supplied through set_weight().
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales,
                                                           zero_points=zero_points,
                                                           axis=0, dtype=self.dtype)
            self.set_weight(wq)
        else:
            raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        """Pack ``weight`` with ``quantized.embedding_bag_prepack`` and store it.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8 or quint4x2.
        """
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')


    @torch.jit.export
    def _weight(self):
        """Return the stored weight unpacked back into a quantized tensor.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8 or quint4x2.
        """
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')

    def forward(self, x):
        # Identity: returns its input unchanged. Presumably present only so this
        # parameter container is a well-formed nn.Module; the embedding modules
        # read ``_packed_weight`` directly instead of calling this.
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Serialize the *unpacked* quantized tensor (via _weight()) rather than
        # the opaque packed object, together with the dtype.
        super(EmbeddingPackedParams, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'dtype'] = self.dtype
        destination[prefix + '_packed_weight'] = self._weight()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Restore dtype first, because set_weight() branches on self.dtype.
        # Both entries are popped from ``state_dict`` (the caller's dict is
        # mutated), so the base-class loader below does not see them.
        self.dtype = state_dict[prefix + 'dtype']
        state_dict.pop(prefix + 'dtype')

        weight = state_dict[prefix + '_packed_weight']
        state_dict.pop(prefix + '_packed_weight')
        self.set_weight(weight)

        # The base loader is always called with strict=False, whatever value
        # the caller passed in ``strict``.
        super(EmbeddingPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                                 missing_keys, unexpected_keys, error_msgs)

    def __repr__(self):
        # Show the unpacked quantized weight tensor.
        return self._weight().__repr__()

class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    ``padding_idx``, ``max_norm``, ``norm_type``, ``scale_grad_by_freq`` and
    ``sparse`` are accepted for interface compatibility but are not used.
    ``dtype`` selects the weight storage and must be ``torch.quint8`` (8-bit
    lookup) or ``torch.quint4x2`` (4-bit lookup).

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is not None:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'

        # EmbeddingPackedParams validates ``dtype`` and already packs a
        # placeholder per-channel quantized weight *of that dtype*. So only an
        # explicitly supplied weight needs packing here. Packing a hard-coded
        # quint8 placeholder on top would leave a quint4x2 module holding an
        # 8-bit packed weight that the 4-bit kernel in forward() cannot use.
        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
        if _weight is not None:
            self._packed_params.set_weight(_weight)

    def forward(self, indices: Tensor) -> Tensor:
        # Pick the lookup kernel that matches the bit width of the packed weight.
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
        else:
            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

    def _get_name(self):
        return 'QuantizedEmbedding'

    def __repr__(self):
        return hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        """Replace the packed weight with ``w`` (a quantized tensor)."""
        self._packed_params.set_weight(w)

    def weight(self):
        """Return the current weight unpacked into a quantized tensor."""
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user

        Returns:
            Embedding: a quantized module whose ``dtype`` matches the dtype of
            the weight observer (``torch.quint8`` or ``torch.quint4x2``).
        """
        if hasattr(mod, 'weight_fake_quant'):
            assert type(mod) == torch.ao.nn.qat.Embedding, 'nnq.' + cls.__name__ + '.from_float ' + \
                'with fake quant only works for ' + torch.ao.nn.qat.Embedding.__name__
            weight_observer = mod.weight_fake_quant
            # Not used below; the attribute is still read so a QAT module
            # without it fails here, as it did before.
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.Embedding.__name__
            assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
            from torch.ao.quantization import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight.
        # ``dtype`` must be forwarded: forward() picks the 4-bit or 8-bit kernel
        # from self.dtype, so a quint4x2 weight in a module left at the default
        # quint8 would be fed to the wrong kernel.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
        """Build a quantized module from a reference quantized Embedding.

        Copies the constructor arguments, the quantized weight
        (``get_quantized_weight()``) and ``weight_dtype`` from ``ref_embedding``.
        """
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding

class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    ``max_norm``, ``norm_type``, ``scale_grad_by_freq`` and ``sparse`` are
    accepted for interface compatibility but are not used. ``mode`` is stored
    on the module, but ``forward`` does not read it.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
        super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
        # Pick the kernel matching the bit width of the packed weight. The
        # scale_grad_by_freq and mode arguments are hard-coded to False and 0;
        # self.mode is not consulted here.
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                                                          self.include_last_offset)
        else:
            return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                                                          self.include_last_offset)

    def _get_name(self):
        return 'QuantizedEmbeddingBag'

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user

        Returns:
            EmbeddingBag: a quantized module with the observer's ``dtype`` and
            the float module's ``mode`` and ``include_last_offset``.
        """
        if hasattr(mod, 'weight_fake_quant'):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.EmbeddingBag.__name__
            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight.
        # include_last_offset changes how offsets are interpreted (and thus the
        # number of output bags), so it must carry over from the float module.
        # mode is carried over as from_reference does. getattr defaults keep
        # user-provided modules that lack these attributes working as before.
        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim,
                                      mode=getattr(mod, 'mode', 'sum'),
                                      include_last_offset=getattr(mod, 'include_last_offset', False),
                                      dtype=dtype)
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        """Build a quantized module from a reference quantized EmbeddingBag.

        Copies the constructor arguments, the quantized weight
        (``get_quantized_weight()``) and ``weight_dtype`` from
        ``ref_embedding_bag``.
        """
        qembedding_bag = cls(
            ref_embedding_bag.num_embeddings,
            ref_embedding_bag.embedding_dim,
            ref_embedding_bag.max_norm,
            ref_embedding_bag.norm_type,
            ref_embedding_bag.scale_grad_by_freq,
            ref_embedding_bag.mode,
            ref_embedding_bag.sparse,
            ref_embedding_bag.get_quantized_weight(),
            ref_embedding_bag.include_last_offset,
            ref_embedding_bag.weight_dtype,
        )
        return qembedding_bag