# mypy: allow-untyped-defs
import functools
from typing import Any, Dict, Optional, TYPE_CHECKING
import torch
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    _is_any_annotated,
    FilterFn,
    int8_in_int8_out_ops,
    X86InductorQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
from torch.fx import Node


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

__all__ = [
    "XPUInductorQuantizer",
    "get_default_xpu_inductor_quantization_config",
]


@functools.lru_cache
def get_default_xpu_inductor_quantization_config():
    """
    Return the default QuantizationConfig for the XPU Inductor backend:
    per-tensor affine int8 activations (HistogramObserver) and per-channel
    symmetric int8 weights (PerChannelMinMaxObserver).
    """
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    act_observer_or_fake_quant_ctr = HistogramObserver
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,  # 0 corresponds to the weight shape (oc, ic, kh, kw) of conv
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None  # will use placeholder observer by default
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
        False,
    )
    return quantization_config


class XPUInductorQuantizer(X86InductorQuantizer):
    """
    XPUInductorQuantizer enables quantization for the Intel GPU (XPU)
    backend. The class largely reuses the existing implementation of
    X86InductorQuantizer, as both are intended to take advantage of the
    optimized kernels in the oneDNN library.
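
    A minimal usage sketch, assuming the standard PT2E quantization flow
    (``model`` and ``example_inputs`` are placeholders supplied by the
    caller)::

        from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

        quantizer = XPUInductorQuantizer()
        quantizer.set_global(get_default_xpu_inductor_quantization_config())
        exported = torch.export.export_for_training(model, example_inputs).module()
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)  # calibration
        quantized = convert_pt2e(prepared)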
"""
def __init__(self) -> None:
super().__init__()
"""
Following annotate_xx overrides the impls in base class, as
no XPU implementation for these operators currently. We would
gradually enable the XPU implementation and remove following
overrides. We keep the annotate methods but make the function
body empty, aiming to let `_generate_qdq_quantized_model`
generate qdq around op and graph execute on fp32 dtype for
unspported operators.
"""

    def _annotate_qat_conv2d_fusion_pattern(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ):
        pass

    def _annotate_conv2d_binary(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ) -> None:
        pass

    def _annotate_conv2d_binary_unary(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ) -> None:
        pass

    def _annotate_linear_fusion_pattern(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ):
        pass

    def _annotate_matmul(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ):
        pass

    def _annotate_maxpool2d(
        self,
        node: Node,
        quantization_config: Optional[QuantizationConfig],
    ) -> None:
        """
        Skip the annotation logic for maxpool on the XPU backend, as
        quantized::max_pool2d is only implemented for CPU.
        """
        return

    def _annotate_output_for_int8_in_int8_out_pattern(
        self,
        node: Node,
    ) -> None:
        """
        Share the input observer with the output for int8-in/int8-out ops,
        except max_pool2d, which is skipped on XPU (see _annotate_maxpool2d).
        """
        if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
            if node.target == torch.ops.aten.max_pool2d.default:
                return
            else:
                input_node = node.all_input_nodes[0]
                self._annotate_output_share_observer_as_input(input_node, node)
        return