import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub

# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.ELU: nnq.ELU,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.GroupNorm: nnq.GroupNorm,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}

# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU,
}

# Map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nn.GRUCell: nnqd.GRUCell,
}

# Module types excluded from qconfig propagation
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
# Module types included in qconfig propagation in addition to the mapped types
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}

# Mapping from floating point functions or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
}

def register_static_quant_module_mapping(
        float_source_module_class, static_quant_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `static_quant_target_module_class`.
    `static_quant_target_module_class` must have `from_float` defined as a class method.
    The mapping is used in the convert step of post training static quantization to
    convert a float module to a statically quantized module.
    '''
    assert hasattr(static_quant_target_module_class, 'from_float'), \
        'from_float must be defined in the quantized module class'
    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class
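
# Example (illustrative sketch, not executed at import time): registering a
# user-defined float/quantized module pair. ``MyFloatModule`` and
# ``MyQuantizedModule`` are hypothetical names used only for illustration; the
# quantized class is assumed to implement a ``from_float`` classmethod.
#
#   register_static_quant_module_mapping(MyFloatModule, MyQuantizedModule)
#   assert get_static_quant_module_class(MyFloatModule) is MyQuantizedModule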

def get_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization
    '''
    return STATIC_QUANT_MODULE_MAPPINGS

def get_static_quant_module_class(float_module_class):
    ''' Get the statically quantized module class corresponding to
    the floating point module class
    '''
    static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        'Floating point module class {}'.format(float_module_class) + \
        ' does not have a corresponding quantized module class'
    return static_quant_module_class
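
# Example (based on the default mapping above): ``nn.Linear`` resolves to its
# statically quantized counterpart ``nnq.Linear``.
#
#   assert get_static_quant_module_class(nn.Linear) is nnq.Linear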

def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `qat_target_module_class`.
    `qat_target_module_class` must have `from_float` defined as a class method.
    The mapping is used in the prepare step of quantization aware training to swap
    a float module to a qat module.
    '''
    assert hasattr(qat_target_module_class, 'from_float'), \
        'from_float must be defined in the qat module class'
    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class

def get_qat_module_mappings():
    ''' Get module mapping for quantization aware training
    '''
    return QAT_MODULE_MAPPINGS

def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`.
    `dynamic_quant_target_module_class` must have `from_float` defined as a class method.
    The mapping is used in the convert step of post training dynamic quantization
    to swap a float module to a dynamically quantized module.
    '''
    assert hasattr(dynamic_quant_target_module_class, 'from_float'), \
        'from_float must be defined in the dynamically quantized module class'
    DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class

def get_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization
    '''
    return DYNAMIC_QUANT_MODULE_MAPPINGS
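
# Example (based on the default mapping above): post training dynamic
# quantization swaps ``nn.LSTM`` for its dynamically quantized counterpart.
#
#   assert get_dynamic_quant_module_mappings()[nn.LSTM] is nnqd.LSTM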

def get_qconfig_propagation_list():
    ''' Get the list of module types that we will attach the qconfig
    attribute to in the prepare step
    '''
    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
        (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
         set(QAT_MODULE_MAPPINGS.keys()) |
         set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
         _INCLUDE_QCONFIG_PROPAGATE_LIST) -
        _EXCLUDE_QCONFIG_PROPAGATE_LIST
    )
    return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
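
# Example (based on the default mappings above): the propagation list covers
# the mapped float module types plus ``nn.Sequential``, minus ``DeQuantStub``.
#
#   propagation_list = get_qconfig_propagation_list()
#   assert nn.Linear in propagation_list
#   assert DeQuantStub not in propagation_list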

def get_compare_output_module_list():
    ''' Get the list of module class types whose outputs we will record
    in the numeric suite
    '''
    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
        set(STATIC_QUANT_MODULE_MAPPINGS.values())
        | set(QAT_MODULE_MAPPINGS.values())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
        | set(STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(QAT_MODULE_MAPPINGS.keys())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
        | _INCLUDE_QCONFIG_PROPAGATE_LIST
    ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
    return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

def register_quantized_operator_mapping(float_op, quantized_op):
    ''' Register a mapping from `float_op` (a torch op or a functional op) to `quantized_op`.
    The mapping is used in the convert step of FX based graph mode quantization
    to convert a float op to a quantized op.
    '''
    FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op

def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator
    '''
    quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
    assert quantized_op is not None, \
        'Operator {} does not have a corresponding quantized op'.format(float_op)
    return quantized_op
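
# Example (based on the default operator mapping above): ``F.elu`` resolves to
# the quantized elu op registered in FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.
#
#   quantized_elu = get_quantized_operator(F.elu)  # torch._ops.ops.quantized.elu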