import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import (
    get_native_backend_config,
    ObservationType,
)
from torch.ao.quantization.quantization_types import (
    Pattern,
    NodePattern,
    QuantizerCls,
)
from torch.ao.quantization.utils import (
    activation_dtype,
    get_combined_dict,
)

from ..backend_config import BackendConfig
from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler

from typing import Dict, Any, Callable, Optional


def get_quantize_handler_cls(
        observation_type,
        dtype_configs,
        num_tensor_args_to_observation_type,
        overwrite_output_fake_quantizer,
        overwrite_output_observer,
        input_output_observed):

    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
                self,
                node_pattern: NodePattern,
                modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs
            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
            self.overwrite_output_observer = overwrite_output_observer
            self.input_output_observed_ = input_output_observed

        def is_general_tensor_value_op(self) -> bool:
            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

        # TODO: change this to output activation
        def get_activation_ctr(
            self,
            qconfig: Any,
            pattern: Pattern,
            is_training: bool,
        ) -> Optional[Callable]:
            """
            Returns the constructor for the activation observer which should be
            used for the pattern matched to this handler. Some handlers override
            this to a different value than what is specified in the qconfig.
            """
            act_dtype = activation_dtype(qconfig)
            # TODO: change to is_qat
            if is_training:
                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
                    return self.overwrite_output_fake_quantizer
            else:
                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
                    return self.overwrite_output_observer
            return qconfig.activation

        # This is temporary, and will be removed soon
        def input_output_observed(self):
            return self.input_output_observed_

    return ConfigurableQuantizeHandler
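
# A minimal usage sketch (illustrative only; the argument values below are
# assumptions, not defaults defined in this module): generate a handler class
# for a pattern, then instantiate it for a matched node.
#
#     handler_cls = get_quantize_handler_cls(
#         observation_type=ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
#         dtype_configs=[],
#         num_tensor_args_to_observation_type={},
#         overwrite_output_fake_quantizer=None,
#         overwrite_output_observer=None,
#         input_output_observed=True)
#     handler = handler_cls(node_pattern, named_modules)  # node_pattern/named_modules are hypothetical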


def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
    """
    Note: a quantize handler is just a holder for check methods (e.g.
    should_insert_observer_for_output); it could be an enum instead. We can
    refactor this once the fbgemm/qnnpack path is fully converted to the new
    BackendConfig path. This is not exposed to backend developers.
    """
    pattern_to_quantize_handlers = {}
    for pattern, config in backend_config.configs.items():
        observation_type = config.observation_type
        dtype_configs = config.dtype_configs
        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
        overwrite_fake_quantizer = config._overwrite_output_fake_quantize
        overwrite_observer = config._overwrite_output_observer
        input_output_observed = config._input_output_observed
        if input_output_observed is None:
            input_output_observed = True
        pattern_to_quantize_handlers[pattern] = \
            get_quantize_handler_cls(
                observation_type,
                dtype_configs,
                num_tensor_args_to_observation_type,
                overwrite_fake_quantizer,
                overwrite_observer,
                input_output_observed)
    return pattern_to_quantize_handlers
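
# Illustrative lookup (a sketch; using torch.nn.Conv2d as a key is an
# assumption about the native config's pattern format, not something this
# module guarantees):
#
#     native_map = get_pattern_to_quantize_handlers(get_native_backend_config())
#     conv_handler_cls = native_map.get(torch.nn.Conv2d)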


# TODO: move this to torch/ao/quantization/backend_config/utils.py
def get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config.configs.items():
        if config.fuser_method is not None:
            # TODO: is this logic right?
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
    return fusion_pattern_to_fuse_handlers
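
# Hedged sketch (the native-config contents are an assumption here): every
# pattern that carries a fuser_method maps to the same DefaultFuseHandler.
#
#     fuse_map = get_fusion_pattern_to_fuse_handler_cls(get_native_backend_config())
#     assert all(h is DefaultFuseHandler for h in fuse_map.values())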


# TODO: remove when all uses are changed to backend_config
def get_native_quant_patterns(additional_quant_patterns: Optional[Dict[Pattern, QuantizerCls]] = None) -> Dict[Pattern, QuantizerCls]:
    """
    Return a map from pattern to quantize handlers based on the default
    patterns and the native backend_config. The returned map is sorted such
    that longer patterns are encountered first when iterating through it.
    """
    patterns = get_default_quant_patterns()
    if additional_quant_patterns is not None:
        patterns = get_combined_dict(patterns, additional_quant_patterns)
    # TODO: currently we just extend the quantize handlers generated from
    # `get_native_backend_config`; in the future we can just assign the
    # backend_config once everything is defined
    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
        patterns[pattern] = quantize_handler
    return sorted_patterns_dict(patterns)
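
# Example extension (a sketch; MyCustomOp and MyQuantizeHandler are
# hypothetical names, not part of this module): register an extra pattern and
# rely on sorted_patterns_dict so longer patterns still match first.
#
#     patterns = get_native_quant_patterns({MyCustomOp: MyQuantizeHandler})
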
get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"

__all__ = [
    "get_fusion_pattern_to_fuse_handler_cls",
    "get_native_quant_patterns",
    "get_pattern_to_quantize_handlers",
]