File: composable_quantizer.py

from __future__ import annotations

from typing import Dict, List, TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows a model to be quantized with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will be
    applied.
    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPACKQuantizer()  # handles ops not quantized by the previous two quantizers
    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
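        # Snapshot of each node's QuantizationAnnotation object, used to detect
        # when a later quantizer mutates or removes an earlier quantizer's annotation.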
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # check if the annotation has been changed by
                # comparing QuantizationAnnotation object id
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
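        # Apply each quantizer's pre-annotation transform in composition order,
        # threading the (possibly replaced) GraphModule through the chain.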
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
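        # Intentionally a no-op: ComposableQuantizer performs its consistency
        # checks during annotate() via _record_and_validate_annotations.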
        pass
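

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module). It assumes the
# EmbeddingQuantizer and XNNPACKQuantizer shipped with torch.ao.quantization,
# plus a hypothetical user-defined `MyLinearQuantizer`; `exported_model` is a
# GraphModule produced by torch.export.
#
#   from torch.ao.quantization.quantize_pt2e import prepare_pt2e
#   from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
#   from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
#
#   composed = ComposableQuantizer(
#       [EmbeddingQuantizer(), MyLinearQuantizer(), XNNPACKQuantizer()]
#   )
#   prepared_m = prepare_pt2e(exported_model, composed)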