from __future__ import annotations

from typing import Dict, List, TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer

if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]

class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will be
    applied.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPackQuantizer()  # to handle ops not quantized by previous two quantizers
    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        # Tracks the QuantizationAnnotation recorded for each node so that
        # later quantizers can be validated against earlier ones.
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # Check whether the annotation has been changed by comparing
                # the QuantizationAnnotation object ids.
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # No composite-level validation; each composed quantizer can implement
        # its own checks.
        pass
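
# A minimal, self-contained usage sketch (illustrative only, not part of the
# module's API). The two no-op quantizers below are hypothetical stand-ins; a
# real workflow would compose backend quantizers (e.g. an embedding quantizer
# plus XNNPackQuantizer, as in the class docstring) and then pass the composed
# quantizer to prepare_pt2e.
if __name__ == "__main__":
    import torch

    class _NoopQuantizerA(Quantizer):
        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            # A real quantizer would attach "quantization_annotation" entries
            # to node.meta here.
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass

    class _NoopQuantizerB(_NoopQuantizerA):
        pass

    # Quantizers run in list order; ComposableQuantizer raises a RuntimeError
    # if a later quantizer mutates or removes an earlier one's annotations.
    composed = ComposableQuantizer([_NoopQuantizerA(), _NoopQuantizerB()])
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    composed.annotate(gm)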