File: README.md

## BackendConfig Overview

BackendConfig allows PyTorch quantization to work with different backends or kernel libraries. These backends may support different sets of quantized operator patterns, and the same operator pattern may require different handling across backends. To make quantization work with different backends and allow maximum flexibility, we strive to make all parts of the quantization flow configurable through BackendConfig. Currently, it is only used by FX graph mode quantization. For more details on how it integrates with the FX graph mode quantization flow, refer to this [README](/torch/ao/quantization/fx/README.md).

BackendConfig configures quantization behavior in terms of operator patterns. For each operator pattern, we need to specify the supported data types for the input and output activations, weights, and biases, as well as the QAT modules, reference quantized modules, etc. that will be used for module swapping during the quantization passes.

Quantized backends can have different support in terms of the following aspects:
* Quantization scheme (symmetric vs asymmetric, per-channel vs per-tensor)
* Data type (float32, float16, int8, uint8, bfloat16, etc.) for input/output/weight/bias
* Quantized (and fused) mapping: Some quantized operators may have different numerics compared to a naive (dequant - float_op - quant) reference implementation. For weighted operators, such as conv and linear, we need to be able to specify custom reference modules and a mapping from the float modules
* QAT mapping: For weighted operators, we need to swap them with the Quantization Aware Training (QAT) versions that add fake quantization to the weights

As an example, here is what fbgemm looks like:
|                                           | fbgemm                                                                |
|-------------------------------------------|-----------------------------------------------------------------------|
| Quantization Scheme                       | activation: per tensor, weight: per tensor or per channel             |
| Data Type                                 | activation: quint8 (with qmin/qmax range restrictions), weight: qint8 |
| Quantized and Fused Operators and Mapping | e.g. torch.nn.Conv2d -> torch.ao.nn.quantized.reference.Conv2d        |
| QAT Module Mapping                        | e.g. torch.nn.Conv2d -> torch.ao.nn.qat.Conv2d                        |

Instead of hardcoding the fusion mappings, float to reference quantized module mappings, fusion patterns, etc., we derive everything from the BackendConfig throughout the code base. This allows PyTorch quantization to work with all first-party (fbgemm and qnnpack) and third-party backends (TensorRT, executorch, etc.) that may differ from the native backends in various aspects. With the recent addition of xnnpack, integrated as part of the qnnpack backend in PyTorch, the BackendConfig is needed to define the new constraints required for xnnpack quantized operators.

## Pattern Specification

The operator patterns used in BackendConfig are float modules, functional operators, PyTorch or Python operators, or a tuple combination of the above. For example:
* torch.nn.Linear
* torch.nn.functional.linear
* torch.add
* operator.add
* (torch.nn.functional.linear, torch.nn.functional.relu)
* (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)

Tuple patterns are treated as sequential patterns, and currently only tuples of 2 or 3 elements are supported.
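
For instance, a 2- or 3-element sequential pattern can be passed directly to `BackendPatternConfig`. A minimal sketch (the dtype configs and other settings shown later in this document are omitted here):
```
import torch
from torch.ao.quantization.backend_config import BackendPatternConfig

# Sequential tuple patterns are written in forward (execution) order.
conv_bn_relu_config = BackendPatternConfig(
    (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU))
linear_relu_config = BackendPatternConfig(
    (torch.nn.functional.linear, torch.nn.functional.relu))
```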

### Advanced Pattern Specification

The above format should satisfy the vast majority of use cases. However, it does not handle more complex scenarios such as graph patterns. For these use cases, the BackendConfig API offers an alternative "reverse nested tuple" pattern format, enabled through `BackendPatternConfig()._set_pattern_complex_format(...)`. Note that this format is deprecated and will be replaced in a future version of PyTorch.
```
operator = module_type | functional | torch op | native op | MatchAllNode
Pattern = (operator, Pattern, Pattern, ...) | operator
```
where the first item for each Pattern is the operator, and the rest are the patterns for the arguments of the operator.
For example, the pattern (nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) would match the following graph:
```
tensor_1            tensor_2
 |                    |
 *(MatchAllNode)  nn.Conv2d
 |                    |
 |             nn.BatchNorm2d
 \                  /
  -- operator.add --
         |
      nn.ReLU
```

During prepare and convert, we match the last node, which serves as the anchor point of the match, and we can recover the rest of the matched pattern by tracing back from that node. E.g. in the example above, we match the `nn.ReLU` node, and `node.args[0]` is the `operator.add` node.
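
As a sketch, the pattern above could be registered through the complex format as follows (the import path for `MatchAllNode` is an assumption here):
```
import operator
import torch.nn as nn
from torch.ao.quantization.backend_config import BackendPatternConfig
from torch.ao.quantization.utils import MatchAllNode

# nn.ReLU consumes the result of operator.add, whose inputs are an arbitrary
# node (MatchAllNode) and a Conv2d -> BatchNorm2d chain.
pattern = (nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))
conv_bn_add_relu_config = BackendPatternConfig()._set_pattern_complex_format(pattern)
```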

## BackendConfig Implementation

The BackendConfig is composed of a list of BackendPatternConfigs, each of which defines the specifications and requirements for an operator pattern. Here is an example usage:

```
import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)

weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float)

def fuse_conv2d_relu(is_qat, conv, relu):
    """Return a fused ConvReLU2d from individual conv and relu modules."""
    return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

# For quantizing Linear
linear_config = BackendPatternConfig(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Linear) \
    .set_qat_module(torch.ao.nn.qat.Linear) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

# For fusing Conv2d + ReLU into ConvReLU2d
conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_fuser_method(fuse_conv2d_relu)

# For quantizing ConvReLU2d
fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Conv2d) \
    .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

backend_config = BackendConfig("my_backend") \
    .set_backend_pattern_config(linear_config) \
    .set_backend_pattern_config(conv_relu_config) \
    .set_backend_pattern_config(fused_conv_relu_config)
```
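
The resulting BackendConfig is then passed to the FX graph mode quantization entry points. A rough usage sketch, assuming the `backend_config` defined above and a toy model:
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs,
                      backend_config=backend_config)
prepared(*example_inputs)  # calibration
quantized = convert_to_reference_fx(prepared, backend_config=backend_config)
```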

### Observer Insertion

Relevant APIs:
* `set_observation_type`

During the prepare phase, we insert observers (or QuantDeQuantStubs in the future) into the graph for this operator pattern based on the observation type, which specifies whether to use different observers for the inputs and the outputs of the pattern. For more detail, see `torch.ao.quantization.backend_config.ObservationType`.
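
For illustration, here is a sketch contrasting the two observation types, assuming a backend where max pooling shares quantization parameters with its input:
```
import torch
from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType

# Weighted ops typically get their own output observer.
linear_config = BackendPatternConfig(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)

# Ops whose outputs reuse the input quantization parameters share the observer.
maxpool_config = BackendPatternConfig(torch.nn.MaxPool2d) \
    .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
```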

### Reference Quantized Patterns

Relevant APIs:
* `set_root_module`
* `set_reference_quantized_module`

During the convert phase, when we construct the reference quantized model, the root modules (e.g. `torch.nn.Linear` for `nni.LinearReLU` or `nniqat.LinearReLU`) will be swapped to the corresponding reference quantized modules (e.g. `torch.ao.nn.quantized.reference.Linear`). This allows custom backends to specify custom reference quantized module implementations to match the numerics of their lowered operators. Since this is a one-to-one mapping, both the root module and the reference quantized module must be specified in the same BackendPatternConfig in order for the conversion to take place.
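
For example, a sketch of this mapping for the fused `nni.LinearReLU` pattern mentioned above (dtype configs and observation type omitted for brevity):
```
import torch
from torch.ao.quantization.backend_config import BackendPatternConfig

# The root module inside the fused module (nn.Linear) is what gets swapped
# to the reference quantized Linear during convert.
fused_linear_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU) \
    .set_root_module(torch.nn.Linear) \
    .set_qat_module(torch.ao.nn.intrinsic.qat.LinearReLU) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
```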

### Fusion

Relevant APIs:
* `set_fuser_method`
* `set_fused_module`
* `_set_root_node_getter`
* `_set_extra_inputs_getter`

As an optimization, operator patterns such as (`torch.nn.Linear`, `torch.nn.ReLU`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.

In FX graph mode quantization, we replace the corresponding nodes in the graph using two helper functions set by the user: `root_node_getter`, which returns the root node (typically the weighted module in the pattern like `torch.nn.Linear`) to replace the matched pattern in the graph, and `extra_inputs_getter`, which returns a list of extra input arguments that will be appended to the existing arguments of the fused module (copied over from the root node). See [this snippet](https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6) for an example usage.
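
As a rough sketch of these two helpers, consider a hypothetical pattern that adds an extra input to a linear output, declared in the complex format described earlier (the `MatchAllNode` import path and the exact shape of `node_pattern` are assumptions here, modeled on the gist above):
```
import operator
import torch
from torch.ao.quantization.backend_config import BackendPatternConfig
from torch.ao.quantization.utils import MatchAllNode

# Hypothetical pattern: out = linear(x) + extra, written in the reverse
# nested tuple format as (operator.add, torch.nn.Linear, MatchAllNode).
def _root_node_getter(node_pattern):
    # node_pattern mirrors the pattern: (add_node, linear_node, extra_input_node)
    add_node, linear_node, extra_input_node = node_pattern
    return linear_node  # the weighted module is the root of the match

def _extra_inputs_getter(node_pattern):
    # append the other input of the add to the fused module's arguments
    add_node, linear_node, extra_input_node = node_pattern
    return [extra_input_node]

linear_add_config = BackendPatternConfig() \
    ._set_pattern_complex_format((operator.add, torch.nn.Linear, MatchAllNode)) \
    ._set_root_node_getter(_root_node_getter) \
    ._set_extra_inputs_getter(_extra_inputs_getter)
```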

### Data Type Restrictions

Relevant APIs:
* `add_dtype_config`
* `set_dtype_configs`

DTypeConfig specifies a set of supported data types for input/output/weight/bias along with the associated constraints, if any. There are two ways of specifying `input_dtype`, `output_dtype`, and `weight_dtype`, as simple `torch.dtype`s or as `DTypeWithConstraints`, e.g.:

```
import torch
from torch.ao.quantization.backend_config import DTypeConfig, DTypeWithConstraints

dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float)

dtype_config_with_constraints = DTypeConfig(
    input_dtype=DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    ),
    output_dtype=DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    ),
    weight_dtype=DTypeWithConstraints(
        dtype=torch.qint8,
        quant_min_lower_bound=-128,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    ),
    bias_dtype=torch.float)
```

During the prepare phase of quantization, we will compare the data types specified in these DTypeConfigs to the ones specified in the matching QConfig for a given operator pattern. If the data types do not match (or the constraints are not satisfied) for all the DTypeConfigs specified for the operator pattern, then we will simply ignore the QConfig and skip quantizing this pattern.

#### Quantization range

The user's QConfig may specify `quant_min` and `quant_max`, which are the min and max restrictions on the quantized values. Here we set the lower bound for the `quant_min` and the upper bound for the `quant_max` to represent the limits of the backend. If a QConfig exceeds these limits in either direction, it will be treated as violating this constraint.
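
For example, a QConfig compatible with the quint8/qint8 constraints shown earlier (quant_min_lower_bound=0 and quant_max_upper_bound=255 for activations) might look like this sketch:
```
import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

qconfig = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8, quant_min=0, quant_max=255),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, quant_min=-128, quant_max=127))
```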

#### Scale range

Similarly, the user's QConfig may specify a minimum value for the quantization scale (currently exposed as `eps`, though this may change in the future to better reflect the semantics). Here we set the lower bound for the scale (`scale_min_lower_bound`) to represent the limits of the backend. If a QConfig's minimum scale value falls below this limit, the QConfig will be treated as violating this constraint. Note that `scale_max_upper_bound` is currently not used, because there is no corresponding mechanism to enforce it on the observer yet.
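
Continuing the sketch above, the observer's `eps` argument is what this lower bound is compared against; e.g. to satisfy a backend with `scale_min_lower_bound=2 ** -12`:
```
import torch
from torch.ao.quantization.observer import MinMaxObserver

act_observer = MinMaxObserver.with_args(
    dtype=torch.quint8, quant_min=0, quant_max=255, eps=2 ** -12)
```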

#### Fixed quantization parameters

For ops with fixed quantization parameters such as `torch.nn.Sigmoid` or `torch.nn.Tanh`, the BackendConfig can specify the specific scale and zero point values as constraints on the input and output activations. The user's QConfigs for these ops must use `FixedQParamsObserver` or `FixedQParamsFakeQuantize` for their activations with matching scale and zero point values, otherwise these QConfigs will be ignored.
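
On the BackendConfig side, these fixed values are expressed through `DTypeWithConstraints`. A minimal sketch for `torch.nn.Sigmoid`, assuming the standard fixed qparams scale=1/256 and zero_point=0 for its output (observation type and other settings omitted):
```
import torch
from torch.ao.quantization.backend_config import (
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
)

# Constrain the sigmoid output to exact fixed quantization parameters.
sigmoid_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=DTypeWithConstraints(
        dtype=torch.quint8,
        scale_exact_match=1.0 / 256.0,
        zero_point_exact_match=0,
    ))

sigmoid_config = BackendPatternConfig(torch.nn.Sigmoid) \
    .add_dtype_config(sigmoid_dtype_config)
```

A matching QConfig would then use something like `FixedQParamsObserver.with_args(scale=1.0 / 256.0, zero_point=0)` (or the corresponding `FixedQParamsFakeQuantize`) for its activations.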