import torch
import torch.nn as nn
from torch.ao.sparsity.sparsifier.utils import module_to_fqn, fqn_to_module
from typing import Dict, List, Optional

SUPPORTED_MODULES = {
    nn.Embedding,
    nn.EmbeddingBag
}

def _fetch_all_embeddings(model):
    """Fetches Embedding and EmbeddingBag modules from the model.

    Returns a list of ``(fully_qualified_name, module)`` pairs.
"""
embedding_modules = []
stack = [model]
while stack:
module = stack.pop()
for _, child in module.named_children():
fqn_name = module_to_fqn(model, child)
if type(child) in SUPPORTED_MODULES:
embedding_modules.append((fqn_name, child))
else:
stack.append(child)
return embedding_modules


def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization only to embeddings & embedding bags.
    The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to the model
        - sparsify_first (bool)
            if True, sparsifies first and then quantizes;
            otherwise, quantizes first and then sparsifies.
        - select_embeddings (List of Embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules are sparsified & quantized.
        - sparse_config (Dict)
            config that will be passed to the constructor of the data sparsifier object.

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized.
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
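
    Example (a minimal usage sketch; ``MyModel`` is a placeholder for any module
    containing embeddings, and the sparsifier class, import path and config
    values below are illustrative assumptions rather than requirements of this
    function)::

        >>> from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier
        >>> model = MyModel()
        >>> post_training_sparse_quantize(model,
        ...                               data_sparsifier_class=DataNormSparsifier,
        ...                               sparsify_first=True,
        ...                               sparsity_level=0.8)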
"""
data_sparsifier = data_sparsifier_class(**sparse_config)
# if select_embeddings is None, perform it on all embeddings
if select_embeddings is None:
embedding_modules = _fetch_all_embeddings(model)
else:
embedding_modules = []
assert isinstance(select_embeddings, List), "the embedding_modules must be a list of embedding modules"
for emb in select_embeddings:
assert type(emb) in SUPPORTED_MODULES, "the embedding_modules list must be an embedding or embedding bags"
fqn_name = module_to_fqn(model, emb)
assert fqn_name is not None, "the embedding modules must be part of input model"
embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)
    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scale & zero_points
        quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
                                            'dequant_weights': {}, 'axis': {},
                                            'dtype': {}}

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach data to sparsifier
            data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name])

        # sparsify the dequantized weights and apply `squash_mask`
        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # requantize the sparsified weights using the saved per-channel scales and zero-points
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
                                                            scales=quantize_params['scales'][name],
                                                            zero_points=quantize_params['zero_points'][name],
                                                            dtype=quantize_params['dtype'][name],
                                                            axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]