1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
from .fx import Fuser # noqa: F401
from .fx import Quantizer # noqa: F401
from torch._fx import GraphModule # type: ignore
from .fx.utils import graph_pretty_str # noqa: F401
def _check_is_graph_module(model):
    """Validate that ``model`` is a ``GraphModule``.

    Args:
        model: object to validate

    Raises:
        ValueError: if ``model`` is not an instance of ``GraphModule``; the
            message includes the actual type that was received
    """
    if isinstance(model, GraphModule):
        return
    message = (
        'input model must be a GraphModule, '
        'please run torch._fx.symbolic_trace on your model before using '
        'quantize_fx. Got type:'
    )
    raise ValueError(message + str(type(model)))
def fuse_fx(graph_module, inplace=False):
    r"""Fuse modules in preparation for quantization.

    Args:
        graph_module: GraphModule object from symbolic tracing
            (torch._fx.symbolic_trace)
        inplace: forwarded unchanged to ``Fuser.fuse``

    Return:
        The result of ``Fuser.fuse`` for ``graph_module``.

    Raises:
        ValueError: if ``graph_module`` is not a GraphModule
    """
    _check_is_graph_module(graph_module)
    return Fuser().fuse(graph_module, inplace)
def _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant):
    """Fuse ``graph_module`` and run the matching ``Quantizer`` prepare step.

    Args:
        graph_module: GraphModule to prepare
        qconfig_dict: quantization configuration passed to the quantizer
        inplace: passed to ``fuse_fx``
        is_dynamic_quant: when truthy ``Quantizer.prepare_dynamic`` is used,
            otherwise ``Quantizer.prepare``

    Return:
        The result of the selected prepare method.

    Raises:
        ValueError: if ``graph_module`` is not a GraphModule
    """
    _check_is_graph_module(graph_module)
    fused = fuse_fx(graph_module, inplace)
    quantizer = Quantizer()
    if is_dynamic_quant:
        prepare_fn = quantizer.prepare_dynamic
    else:
        prepare_fn = quantizer.prepare
    # The prepare step always runs with inplace=True; presumably the caller's
    # ``inplace`` choice has already been honored by fuse_fx above.
    return prepare_fn(fused, qconfig_dict, inplace=True)
def prepare_fx(graph_module, qconfig_dict, inplace=False):
    r"""Prepare a model for post training static quantization or
    quantization aware training, not for public use.

    Args:
        graph_module: model from symbolic_tracing (torch._fx.symbolic_trace),
            must be an eval model
        qconfig_dict: see :func:`~torch.quantization.quantize_fx`
        inplace: forwarded to the fusion step

    Return:
        A GraphModule with observer or fake quant modules, ready for
        calibration or quantization aware training
    """
    # The trailing False selects the non-dynamic prepare path.
    return _prepare_fx(graph_module, qconfig_dict, inplace, False)
def prepare_static_fx(graph_module, qconfig_dict, inplace=False):
    r"""Prepare an eval-mode model for post training static quantization.

    Asserts that ``graph_module.training`` is falsy and then delegates to
    :func:`prepare_fx`.

    Args:
        graph_module: model to prepare; must be in eval mode
        qconfig_dict: quantization configuration, forwarded to ``prepare_fx``
        inplace: forwarded to ``prepare_fx``

    Return:
        The result of ``prepare_fx``.
    """
    assert not graph_module.training, (
        'prepare_static_fx only works for models in '
        'eval mode')
    return prepare_fx(graph_module, qconfig_dict, inplace)
def prepare_qat_fx(graph_module, qconfig_dict, inplace=False):
    r"""Prepare a model for quantization aware training.

    Asserts that ``graph_module.training`` is truthy and then delegates to
    :func:`prepare_fx`.

    Args:
        graph_module: model from symbolic_tracing (torch._fx.symbolic_trace),
            must be a train model
        qconfig_dict: see :func:`~torch.quantization.quantize_fx`
        inplace: forwarded to ``prepare_fx``

    Return:
        A GraphModule with observer or fake quant modules, ready for
        calibration or quantization aware training
    """
    assert graph_module.training, (
        'prepare_qat_fx only works for models in '
        'train mode')
    return prepare_fx(graph_module, qconfig_dict, inplace)
def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False):
    r"""Prepare a model for post training dynamic quantization.

    Args:
        graph_module: GraphModule to prepare
        qconfig_dict: quantization configuration, forwarded to the quantizer
        inplace: forwarded to the fusion step

    Return:
        The result of ``_prepare_fx`` run with the dynamic prepare path.
    """
    return _prepare_fx(
        graph_module, qconfig_dict, inplace, is_dynamic_quant=True)
def _convert_fx(graph_module, inplace, debug, is_dynamic_quant):
    """Validate ``graph_module`` and hand it to ``Quantizer.convert``.

    Args:
        graph_module: GraphModule to convert
        inplace: forwarded to ``Quantizer.convert``
        debug: forwarded to ``Quantizer.convert``
        is_dynamic_quant: forwarded to ``Quantizer.convert``

    Return:
        The result of ``Quantizer.convert``.

    Raises:
        ValueError: if ``graph_module`` is not a GraphModule
    """
    _check_is_graph_module(graph_module)
    return Quantizer().convert(graph_module, inplace, debug, is_dynamic_quant)
def convert_fx(graph_module, inplace=False, debug=False):
    r"""Convert a calibrated or trained model to a quantized model.

    Args:
        graph_module: GraphModule to convert
        inplace: forwarded to the quantizer's convert step
        debug: forwarded to the quantizer's convert step

    Return:
        The result of ``_convert_fx`` run with the non-dynamic path.
    """
    # The trailing False is ``is_dynamic_quant``.
    return _convert_fx(graph_module, inplace, debug, False)
# Post training static quantization and quantization aware training share the
# same conversion entry point, so both names are plain aliases of convert_fx.
convert_static_fx = convert_fx
convert_qat_fx = convert_fx
def convert_dynamic_fx(graph_module, inplace=False, debug=False):
    r"""Convert a model to a quantized model via the dynamic quantization path.

    Args:
        graph_module: GraphModule to convert
        inplace: forwarded to the quantizer's convert step
        debug: forwarded to the quantizer's convert step

    Return:
        The result of ``_convert_fx`` run with ``is_dynamic_quant`` set.
    """
    # The trailing True is ``is_dynamic_quant``.
    return _convert_fx(graph_module, inplace, debug, True)
def _quantize_fx(model, qconfig_dict, run_fn=None, run_args=None, inplace=False,
                 debug=False, is_dynamic_quant=False):
    """Run the post training quantization flow: prepare, calibrate, convert.

    Args:
        model: eval-mode model to quantize
        qconfig_dict: quantization configuration, forwarded to the prepare step
        run_fn: calibration callable, invoked as ``run_fn(model, *run_args)``;
            only used (and required) for static quantization
        run_args: positional arguments for ``run_fn``; only used (and
            required) for static quantization
        inplace: forwarded to the prepare step
        debug: forwarded to the convert step
        is_dynamic_quant: select the dynamic flow, which skips calibration

    Return:
        The converted model.

    Raises:
        AssertionError: if ``model.training`` is truthy, or if the static flow
            is selected and ``run_fn`` or ``run_args`` is falsy
    """
    assert not model.training, (
        'quantize_fx is only used for post training '
        'quantization(eval mode), for quantization aware training please use '
        'prepare_qat_fx and convert_qat_fx.')

    if is_dynamic_quant:
        prepared = prepare_dynamic_fx(model, qconfig_dict, inplace)
        # inplace is True since the inplace option is already applied in previous step
        return convert_dynamic_fx(prepared, inplace=True, debug=debug)

    assert run_fn, "Must provide calibration function for post training static quantization"
    assert run_args, "Must provide calibration dataset for post training static quantization"
    prepared = prepare_fx(model, qconfig_dict, inplace)
    # Calibration: run_fn is called for its effect on the prepared model; its
    # return value is ignored.
    run_fn(prepared, *run_args)
    # inplace is True since the inplace option is already applied in previous step
    return convert_fx(prepared, inplace=True, debug=debug)
def quantize_static_fx(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
    r"""Quantize the input float symbolically traced GraphModule model with
    post training static quantization

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        `model`: input float GraphModule (from torch._fx.symbolic_trace),
        must be in eval mode
        `qconfig_dict`: qconfig_dict is a dictionary with the following configurations:
        qconfig_dict = {
            # optional, global config
            "": qconfig?,

            # optional, used for module and function types
            # could also be split into module_types and function_types if we prefer
            "object_type": [
                (torch.nn.Conv2d, qconfig?),
                (torch.nn.functional.add, qconfig?),
                ...,
            ],

            # optional, used for module names
            "module_name": [
                ("foo.bar", qconfig?)
                ...,
            ],

            # optional, matched in order, first match takes precedence
            "module_name_regex": [
                ("foo.*bar.*conv[0-9]+", qconfig?)
                ...,
            ],
            # priority (in increasing order): global, object_type, module_name_regex, module_name
            # qconfig == None means fusion and quantization should be skipped for anything
            # matching the rule
        }
        `run_fn`: a calibration function for calibrating the prepared model,
        called as `run_fn(prepared_model, *run_args)`
        `run_args`: positional arguments for `run_fn`
        `inplace`: carry out model transformations in-place, the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserve weight attribute)

    Return:
        The quantized model produced by the convert step.

    Raises:
        AssertionError: if `model` is in training mode, or if `run_fn` or
        `run_args` is falsy
        ValueError: if `model` is not a GraphModule

    Example:
    ```python
    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization import quantize_static_fx

    graph_module = torch._fx.symbolic_trace(float_model.eval())
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    quantized_model = quantize_static_fx(
        graph_module,
        {'': qconfig},
        calibrate,
        [data_loader_test])
    ```
    """
    return _quantize_fx(
        model, qconfig_dict, run_fn, run_args, inplace, debug, is_dynamic_quant=False)
def quantize_dynamic_fx(model, qconfig_dict, inplace=False, debug=False):
    r"""Quantize the input float symbolically traced GraphModule model with
    post training dynamic quantization.

    Currently only qint8 quantization of torch.nn.Linear is supported.
    No calibration function or data is needed for dynamic quantization.

    Args:
        `model`: input float GraphModule (from torch._fx.symbolic_trace),
        must be in eval mode
        `qconfig_dict`: quantization configuration dictionary, please see
        the detailed description in :func:`quantize_static_fx`
        `inplace`: carry out model transformations in-place, the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserve weight attribute)

    Return:
        The quantized model produced by the convert step.

    Raises:
        AssertionError: if `model` is in training mode
        ValueError: if `model` is not a GraphModule

    Example:
    ```python
    import torch
    from torch.quantization import per_channel_dynamic_qconfig
    from torch.quantization import quantize_dynamic_fx

    graph_module = torch._fx.symbolic_trace(float_model.eval())
    quantized_model = quantize_dynamic_fx(
        graph_module,
        {'': per_channel_dynamic_qconfig})
    ```
    """
    return _quantize_fx(
        model, qconfig_dict, inplace=inplace, debug=debug, is_dynamic_quant=True)
|