# mypy: allow-untyped-defs
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .fake_quantize import * # noqa: F403
from .fuse_modules import fuse_modules, fuse_modules_qat # noqa: F403
from .fuser_method_mappings import * # noqa: F403
from .observer import * # noqa: F403
from .pt2e._numeric_debugger import ( # noqa: F401
compare_results,
CUSTOM_KEY,
extract_results_from_loggers,
generate_numeric_debug_handle,
NUMERIC_DEBUG_HANDLE_KEY,
prepare_for_propagation_comparison,
)
from .pt2e.export_utils import (
_allow_exported_model_train_eval as allow_exported_model_train_eval,
_move_exported_model_to_eval as move_exported_model_to_eval,
_move_exported_model_to_train as move_exported_model_to_train,
)
from .qconfig import * # noqa: F403
from .qconfig_mapping import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantization_mappings import * # noqa: F403 # type: ignore[no-redef]
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403

# ensure __module__ is set correctly for public APIs
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
for _f in [
compare_results,
extract_results_from_loggers,
generate_numeric_debug_handle,
prepare_for_propagation_comparison,
]:
_f.__module__ = "torch.ao.quantization"

__all__ = [
"DeQuantStub",
"FakeQuantize",
"FakeQuantizeBase",
"FixedQParamsFakeQuantize",
"FixedQParamsObserver",
"FusedMovingAvgObsFakeQuantize",
"HistogramObserver",
"MatchAllNode",
"MinMaxObserver",
"MovingAverageMinMaxObserver",
"MovingAveragePerChannelMinMaxObserver",
"NoopObserver",
"ObserverBase",
"ObserverOrFakeQuantize",
"Pattern",
"PerChannelMinMaxObserver",
"PlaceholderObserver",
"QConfig",
"QConfigAny",
"QConfigDynamic",
"QConfigMapping",
"QuantStub",
"QuantType",
"QuantWrapper",
"RecordingObserver",
"ReuseInputObserver",
"UniformQuantizationObserverBase",
"add_quant_dequant",
"convert",
"convert_dynamic_jit",
"convert_jit",
"default_affine_fixed_qparams_fake_quant",
"default_affine_fixed_qparams_observer",
"default_debug_observer",
"default_dynamic_fake_quant",
"default_dynamic_quant_observer",
"default_embedding_fake_quant",
"default_embedding_fake_quant_4bit",
"default_eval_fn",
"default_fake_quant",
"default_fixed_qparams_range_0to1_fake_quant",
"default_fixed_qparams_range_0to1_observer",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_fixed_qparams_range_neg1to1_observer",
"default_float_qparams_observer",
"default_float_qparams_observer_4bit",
"default_fused_act_fake_quant",
"default_fused_per_channel_wt_fake_quant",
"default_fused_wt_fake_quant",
"default_histogram_fake_quant",
"default_histogram_observer",
"default_observer",
"default_per_channel_weight_fake_quant",
"default_per_channel_weight_observer",
"default_placeholder_observer",
"default_reuse_input_observer",
"default_symmetric_fixed_qparams_fake_quant",
"default_symmetric_fixed_qparams_observer",
"default_weight_fake_quant",
"default_weight_observer",
"disable_fake_quant",
"disable_observer",
"enable_fake_quant",
"enable_observer",
"fuse_conv_bn",
"fuse_conv_bn_jit",
"fuse_conv_bn_relu",
"fuse_convtranspose_bn",
"fuse_linear_bn",
"fuse_modules",
"fuse_modules_qat",
"fused_per_channel_wt_fake_quant_range_neg_127_to_127",
"fused_wt_fake_quant_range_neg_127_to_127",
"get_combined_dict",
"get_default_compare_output_module_list",
"get_default_custom_config_dict",
"get_default_dynamic_quant_module_mappings",
"get_default_dynamic_sparse_quant_module_mappings",
"get_default_float_to_quantized_operator_mappings",
"get_default_qat_module_mappings",
"get_default_qat_qconfig",
"get_default_qat_qconfig_dict",
"get_default_qat_qconfig_mapping",
"get_default_qconfig",
"get_default_qconfig_dict",
"get_default_qconfig_mapping",
"get_default_qconfig_propagation_list",
"get_default_static_quant_module_mappings",
"get_default_static_quant_reference_module_mappings",
"get_default_static_sparse_quant_module_mappings",
"get_dynamic_quant_module_class",
"get_embedding_qat_module_mappings",
"get_embedding_static_quant_module_mappings",
"get_fuser_method",
"get_fuser_method_new",
"get_observer_state_dict",
"get_quantized_operator",
"get_static_quant_module_class",
"load_observer_state_dict",
"move_exported_model_to_eval",
"move_exported_model_to_train",
"allow_exported_model_train_eval",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",
"prepare",
"prepare_dynamic_jit",
"prepare_jit",
"prepare_qat",
"propagate_qconfig_",
"qconfig_equals",
"quantize",
"quantize_dynamic",
"quantize_dynamic_jit",
"quantize_jit",
"quantize_qat",
"script_qconfig",
"script_qconfig_dict",
"swap_module",
"weight_observer_range_neg_127_to_127",
"generate_numeric_debug_handle",
"CUSTOM_KEY",
"NUMERIC_DEBUG_HANDLE_KEY",
"prepare_for_propagation_comparison",
"extract_results_from_loggers",
"compare_results",
]


def default_eval_fn(model, calib_data):
r"""Define the default evaluation function.
Default evaluation function takes a torch.utils.data.Dataset or a list of
input Tensors and run the model on the dataset
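
    Example (illustrative sketch; assumes ``model`` is any callable module and
    ``calib_data`` is a list of ``(input, target)`` pairs)::

        >>> # xdoctest: +SKIP
        >>> model = torch.nn.Linear(4, 2)
        >>> calib_data = [(torch.randn(8, 4), torch.zeros(8)) for _ in range(3)]
        >>> default_eval_fn(model, calib_data)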
"""
    # Run the model over the calibration data; outputs and targets are ignored.
    for data, target in calib_data:
        model(data)


class _DerivedObserverOrFakeQuantize(ObserverBase):
r"""This observer is used to describe an observer whose quantization parameters
are derived from other observers
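
    A minimal usage sketch (illustrative only; ``derive_qparams`` below is a
    hypothetical callable that reuses the qparams of a single wrapped
    ``MinMaxObserver``)::

        >>> # xdoctest: +SKIP
        >>> weight_obs = MinMaxObserver()
        >>> def derive_qparams(obs_or_fqs):
        ...     return obs_or_fqs[0].calculate_qparams()
        >>> derived = _DerivedObserverOrFakeQuantize(
        ...     dtype=torch.quint8,
        ...     obs_or_fqs=[weight_obs],
        ...     derive_qparams_fn=derive_qparams,
        ... )
        >>> _ = weight_obs(torch.randn(4, 4))
        >>> scale, zero_point = derived.calculate_qparams()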
"""

    def __init__(
self,
dtype: torch.dtype,
obs_or_fqs: List[ObserverOrFakeQuantize],
derive_qparams_fn: Callable[
[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]
],
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
qscheme: Optional[torch.qscheme] = None,
ch_axis: Optional[int] = None,
):
super().__init__(dtype)
self.obs_or_fqs = obs_or_fqs
self.derive_qparams_fn = derive_qparams_fn
self.quant_min = quant_min
self.quant_max = quant_max
self.qscheme = qscheme
self.ch_axis = ch_axis
from .utils import is_per_channel
if is_per_channel(self.qscheme):
assert (
self.ch_axis is not None
), "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Derived observers do not collect statistics themselves; the input is
        # passed through unchanged.
        return x

    def calculate_qparams(self):  # type: ignore[override]
        # Delegate to the user-provided callable, which derives (scale, zero_point)
        # from the wrapped observers/fake quantizes.
        return self.derive_qparams_fn(self.obs_or_fqs)