# FX Graph Mode Quantization Design Doc
<!---
```
float_model            QConfigMapping           BackendConfig
    \                          |                        /
     \                         |                      /
      \                        |                    /
(prepare_fx/prepare_qat_fx)                        /
--------------------------------------------------------
|                         Fuse                         |
|                  QAT Module Swap                     |
|                 Insert Observers                     |
--------------------------------------------------------
                              |
                      Calibrate/Train
                              |
(convert_fx)                  |
---------------------------------------------------------
|                         Convert                       |
|                        Lowering                       |
---------------------------------------------------------
                              |
                       Quantized Model
```
-->

```mermaid
---
title: High Level FX Graph Mode Quantization Flow
---
flowchart TD
    classDef nofs fill:none,stroke:none
    classDef sub fill:#D6EAF8,stroke:none
    float_model:::nofs --> prepare_fx:::sub
    QConfigMapping:::nofs --> prepare_fx
    BackendConfig:::nofs --> prepare_fx
    subgraph prepare_fx["`_(prepare_fx/prepare_qat_fx)_`"]
    Fuse:::nofs --> swap[QAT Module Swap]:::nofs --> obs[Insert Observers]:::nofs
    end
    prepare_fx --> Calibrate/Train:::nofs --> convert_fx:::sub
    subgraph convert_fx["`_(convert_fx)_`"]
    Convert:::nofs --> Lowering:::nofs
    end
    convert_fx --> qm[Quantized Model]:::nofs
```

Please refer to [TODO: link] for definitions of the terminology used here.

## Overview
The FX graph representation is quite close to Python eager mode: it preserves many eager mode constructs such as modules, functionals, and torch ops. The implementation therefore reuses several building blocks and utilities from eager mode quantization, including QConfig, QConfig propagation (which might be removed), fused modules, QAT modules, quantized modules, and the QAT module swapping utility. The overall flow also matches eager mode quantization exactly; the only difference is that transformations such as fusion and stub insertion are fully automated and controlled by QConfigMapping and BackendConfig.

## High Level Flow with Simple Example

`prepare_fx`:
```
Floating Point Model --> (1.1 `_fuse_fx`) --> Fused Model
                     --> (1.2 QAT Module Swap) --> Model with QAT modules
                     --> (1.3 Insert Observers) --> Prepared Model
```

`convert_fx`:
```
Prepared Model --> (3.1 `convert_to_reference`) --> Reference Quantized Model
               --> (3.2 Lower to Native Backend) --> Quantized Model
```

In the following, I'll first describe each step in detail and then talk about the corresponding settings in BackendConfig. This doc follows the terminology defined in the (draft) README.md of quantization syntax transforms.

### 0. Original Model

```
import torch

class LinearReLUModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 10).float()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))
```

### 1.1 Fusion
```
fused: GraphModule(
  (linear): LinearReLU(
    (0): Linear(in_features=5, out_features=10, bias=True)
    (1): ReLU()
  )
)

def forward(self, x):
    linear = self.linear(x);  x = None
    return linear
```

In this example, we:

* Identify the (Linear - ReLU) subgraph by searching through the model graph
* For each identified subgraph, replace the `root_node` (typically the weighted module in the pattern, like Linear) with a fused module by calling the fuser_method for this pattern. A fused module is a Sequential of a few modules, e.g. nni.LinearReLU is a Sequential of a linear and a relu module
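To make the matching step concrete, here is a toy sketch in plain Python. It does not use the FX `Graph` API; it assumes a hypothetical, simplified graph (a list of `Node` records in topological order, each with a single input) and only handles the (Linear, ReLU) pattern. The fused node keeps the root (linear) node's name, mirroring the fused `forward` shown above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Node:
    name: str
    op: str               # e.g. "placeholder", "linear", "relu", "output"
    input: Optional[str]  # name of the single producer node (toy assumption)

def fuse_linear_relu_nodes(nodes: List[Node]) -> List[Node]:
    """Replace each linear -> relu pair, where relu is the only user of linear,
    with one fused node that keeps the root (linear) node's name."""
    users: Dict[str, List[Node]] = {}
    for n in nodes:
        if n.input is not None:
            users.setdefault(n.input, []).append(n)

    result: List[Node] = []
    skip = set()                 # relu nodes absorbed into a fused node
    rename: Dict[str, str] = {}  # absorbed relu name -> fused node name
    for n in nodes:              # nodes are assumed to be in topological order
        if n.name in skip:
            continue
        inp = rename.get(n.input, n.input)  # rewire users of an absorbed relu
        n_users = users.get(n.name, [])
        if n.op == "linear" and len(n_users) == 1 and n_users[0].op == "relu":
            skip.add(n_users[0].name)
            rename[n_users[0].name] = n.name
            result.append(Node(n.name, "linear_relu", inp))
        else:
            result.append(Node(n.name, n.op, inp))
    return result

graph = [
    Node("x", "placeholder", None),
    Node("linear", "linear", "x"),
    Node("relu", "relu", "linear"),
    Node("output", "output", "relu"),
]
fused = fuse_linear_relu_nodes(graph)
print([(n.name, n.op, n.input) for n in fused])
# [('x', 'placeholder', None), ('linear', 'linear_relu', 'x'), ('output', 'output', 'linear')]
```

The check that relu is the only user of linear matters: if the linear output were consumed elsewhere, fusing would drop the pre-ReLU value that the other user needs.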

`backend_config` configurations relevant to this step are:

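```
import torch
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig

def fuse_linear_relu(is_qat, linear, relu):
    return nni.LinearReLU(linear, relu)

linear_relu_fusion_config = (
    BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
    .set_fuser_method(fuse_linear_relu)
    # my_root_node_getter / my_extra_inputs_getter are user-defined callables
    # (see the example gist linked below)
    ._set_root_node_getter(my_root_node_getter)
    ._set_extra_inputs_getter(my_extra_inputs_getter)
)
```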


`BackendPatternConfig` takes a pattern that specifies the fusion pattern we want to search for; the pattern format is described in https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

* `set_dtype_configs`: dtype_configs are used to check the qconfig for the pattern against the target backend, to see whether the qconfig is supported there. Currently this is not used in fusion, but we can add this check in the future, or remove it and always fuse these patterns.
* `set_fuser_method`: specifies the fuser method to use for the pattern. A fuser method takes the matched objects and fuses them into a fused module.
* `_set_root_node_getter`: sets a function that takes a node pattern and returns the root node in the pattern.
* `_set_extra_inputs_getter`: all input args of the root node are copied over to the fused module; if there are extra inputs, this function returns a list of them given the pattern.

Example usage of `root_node_getter` and `extra_inputs_getter`: https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6

### 1.2 QAT Module Swap
```
GraphModule(
  (linear): LinearReLU(
    in_features=5, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
  )
)

def forward(self, x):
    linear = self.linear(x);  x = None
    return linear
```

In this step we swap the fused module for its QAT counterpart, e.g. nn.intrinsic.LinearReLU instances are swapped to nn.intrinsic.qat.LinearReLU modules, which fake quantize the weight of the linear.
For modules that have corresponding QAT modules, we call the eager mode `convert` function with a mapping from float module to QAT module, which swaps every float module (and fused module) with its QAT module. This step is exactly the same as in eager mode quantization; it is just called inside the `prepare_fx/prepare_qat_fx` function.
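The swap itself is a plain recursive replacement driven by a float-to-QAT mapping. The sketch below mimics that recursion with hypothetical stand-in classes rather than real `torch.nn` modules; the `from_float` classmethod follows the convention QAT modules use, but the placeholder `weight_fake_quant` string is purely illustrative.

```python
from typing import Dict

class Module:
    """Minimal stand-in for torch.nn.Module: just a dict of named children."""
    def __init__(self, **children):
        self.children = dict(children)

class LinearReLU(Module):        # stands in for nni.LinearReLU
    pass

class QATLinearReLU(Module):     # stands in for nniqat.LinearReLU
    @classmethod
    def from_float(cls, mod: Module) -> "QATLinearReLU":
        qat = cls(**mod.children)
        qat.weight_fake_quant = "FakeQuantize"  # placeholder for a real fake-quant module
        return qat

def swap_modules(module: Module, mapping: Dict[type, type]) -> Module:
    """Recursively replace every child whose exact type is a key in `mapping`."""
    for name, child in list(module.children.items()):
        if type(child) in mapping:
            module.children[name] = mapping[type(child)].from_float(child)
        elif isinstance(child, Module):
            swap_modules(child, mapping)
    return module

model = Module(linear=LinearReLU(linear="Linear", relu="ReLU"))
swap_modules(model, {LinearReLU: QATLinearReLU})
print(type(model.children["linear"]).__name__)  # QATLinearReLU
```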

`backend_config` configurations relevant in this step are:
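```
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization.backend_config import BackendPatternConfig

linear_relu_qat_swap_config = (
    BackendPatternConfig(nni.LinearReLU)
    .set_qat_module(nniqat.LinearReLU)
)
```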

The pattern used to initialize BackendPatternConfig is the original or fused floating point module class.
`set_qat_module` sets the QAT module class corresponding to the module class specified in the pattern.

### 1.3 QuantDeQuantStub and Observer/FakeQuantize Insertion
```
GraphModule(
  (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
  (linear): LinearReLU(
    (0): Linear(in_features=5, out_features=10, bias=True)
    (1): ReLU()
  )
  (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
)

def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    linear = self.linear(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(linear);  linear = None
    return activation_post_process_1
```

Note: activation_post_process_0 and activation_post_process_1 will be updated with QuantDeQuantStub

QuantDeQuantStubs are inserted based on the `qconfig_mapping` provided by users. We also have a `backend_config` that specifies the configs supported by the backend. In this step, we:
* Check whether `qconfig_mapping` is compatible with `backend_config`. If the user requested a qconfig that is not compatible with `backend_config`, we don't insert observers for the operator; the config is simply ignored.
* Insert observers for the input and output of the subgraph, based on the `qconfig_mapping` (what the user requested) and the `backend_config` (how the operator should be observed in a backend).
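A minimal sketch of the compatibility check, assuming dtypes are plain strings and that a request is supported when it matches one of the backend's dtype combinations on every requested key. The names (`is_supported`, `validate_qconfigs`) and the dict shapes are hypothetical, and the real check may consider more fields than dtypes.

```python
from typing import Dict, List, Tuple

DTypes = Dict[str, str]  # e.g. {"input": "quint8", "weight": "qint8", "output": "quint8"}

def is_supported(requested: DTypes, dtype_configs: List[DTypes]) -> bool:
    """True if some supported combination agrees with every requested dtype."""
    return any(all(cfg.get(k) == v for k, v in requested.items()) for cfg in dtype_configs)

def validate_qconfigs(
    node_requests: Dict[str, Tuple[str, DTypes]],
    backend_config: Dict[str, List[DTypes]],
) -> Dict[str, DTypes]:
    """Keep only the requests the backend supports; the rest are ignored,
    meaning no observers would be inserted for those nodes."""
    return {
        node: requested
        for node, (pattern, requested) in node_requests.items()
        if is_supported(requested, backend_config.get(pattern, []))
    }

backend_config = {
    "LinearReLU": [{"input": "quint8", "weight": "qint8", "output": "quint8"}],
}
node_requests = {
    "linear_relu_1": ("LinearReLU", {"input": "quint8", "weight": "qint8", "output": "quint8"}),
    "linear_relu_2": ("LinearReLU", {"input": "quint8", "weight": "qint8", "output": "float32"}),
}
print(validate_qconfigs(node_requests, backend_config))
# {'linear_relu_1': {'input': 'quint8', 'weight': 'qint8', 'output': 'quint8'}}
```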

Note: We could also insert QStub and DQStub in this step when users request to change the interface dtype for the model, standalone modules, or custom modules.

Detailed walkthrough for this step in `prepare_qat_fx` (inserting QDQStub and FakeQuantize modules):
```
# fused and qat swapped model
# graph 1:
input - qat_linear_relu - output
              |
          FakeQuantize
(need to be updated with QDQStub + FakeQuantize)
              |
           weight

# qconfig_mapping (simplified, shown as dict)
{'qat_linear_relu': QConfig(
  weight=MinMaxObserver.with_args(dtype=torch.qint8),
  activation=HistogramObserver.with_args(dtype=torch.quint8),
)}

# backend_config (simplified)
{
  'pattern': nniqat.LinearReLU,
  'dtype_configs': [{'input': torch.quint8, 'output': torch.quint8, 'weight': torch.qint8}],
}
```

step 1: assign a qconfig to each op (please see [TODO: link] for details)

step 2: determine which qconfigs are valid according to the backend configuration (please see [TODO: link] for details)
(we should add a warning here)

step 3: for subgraphs with validated qconfigs, insert the needed qstub/dqstub/qdqstub

To explain what happens in this step, let's first define some terms. We view the computation graph shown above as a graph consisting of nodes and edges. Each node is an FX Node that represents some computation, for example linear, and each edge is a connection between two nodes. Each edge can be viewed both as the output of the previous Node and as the input of the next Node.

The end goal for this step is to insert QDQStubs at edges so that we produce a reference quantized model graph, where each QDQStub represents a quantize operator followed by a dequantize operator.

```
# graph 2:
input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - output
                                      |
                                FakeQuantize
                  (need to be updated with QDQStub + FakeQuantize)
                                      |
                                    weight
```
Note: weight + FakeQuantize is a part of qat_linear_relu

The overall logic to insert QDQStub1 and QDQStub2 in place is the following:

0. For each node in the original graph, we compute the target_dtype for its input and output based on the qconfig. For graph 1, configured with the qconfig_mapping above, we have:
```
# node_name_to_target_dtype_info =
# {
#     # this is placeholder node in FX Graph
#     "input" : {"input_activation": torch.float32, "output_activation": torch.float32},
#     "qat_linear_relu": {"input_activation": torch.quint8, "output_activation": torch.quint8, "weight": ...},
#     # this is the return node in FX Graph
#     "output": {"input_activation": torch.float32, "output_activation": torch.float32}
# }
```
Note: this map is generated before we insert qdqstub to graph1, and will not change in the process.

1. Inserting QDQStub1 (for the input of qat_linear_relu)
   We need to look at the edge between the `input` Node and the `qat_linear_relu` Node and decide whether to insert a
   QDQStub at this edge, which would serve as an input argument for the `qat_linear_relu` Node (and also as the output of the `input` Node).
   To decide whether to insert a QDQStub here, we figure out:

   (1). The target dtype for the output of the `input` Node, which is torch.float32

   (2). The target dtype for the input of the `qat_linear_relu` Node, which is torch.quint8

   There is a mismatch here and (2) is a quantized dtype, so we need to insert a QDQStub at the edge.

   We also need to attach an observer/fakequant module to the QDQStub we inserted here.
2. Inserting QDQStub2 (for the output of qat_linear_relu)
   The logic for inserting a QDQStub for the output is much simpler. Since we assume all modules/functions in the graph produce fp32 output
   by default (we can add additional checks and extend this to work for other dtypes once we have type inference ready),
   we just need to look at the target output dtype for the qat_linear_relu Node; if it is a quantized dtype (quint8, qint8, float16),
   we insert a QDQStub here.
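The two rules above can be sketched as a small planning function over `node_name_to_target_dtype_info`. This is a toy under stated assumptions: dtypes are strings, the graph is a list of `(name, input names)` pairs, float16 is counted as quantized as in the text, and the function returns a plan instead of mutating an FX graph. `plan_qdq_stubs` is a hypothetical name.

```python
from typing import Dict, List, Tuple

QUANTIZED_DTYPES = {"quint8", "qint8", "float16"}  # float16 counted as "quantized", as in the text

def plan_qdq_stubs(
    nodes: List[Tuple[str, List[str]]],
    target_dtype_info: Dict[str, Dict[str, str]],
) -> List[Tuple[str, str]]:
    """nodes: (name, input names) in topological order.
    Returns (edge, dtype) for each QDQStub to insert; "name->*" marks a stub
    on the output of `name`, shared by all of its users."""
    stubs = []
    for name, inputs in nodes:
        info = target_dtype_info[name]
        # Input side: insert only if the producer's target output dtype differs
        # from this node's target input dtype AND the latter is quantized.
        for producer in inputs:
            produced = target_dtype_info[producer]["output_activation"]
            wanted = info["input_activation"]
            if produced != wanted and wanted in QUANTIZED_DTYPES:
                stubs.append((f"{producer}->{name}", wanted))
        # Output side: every node is assumed to produce fp32, so insert
        # whenever the target output dtype is quantized.
        if info["output_activation"] in QUANTIZED_DTYPES:
            stubs.append((f"{name}->*", info["output_activation"]))
    return stubs

fp32 = {"input_activation": "float32", "output_activation": "float32"}
info = {
    "input": fp32,
    "qat_linear_relu": {"input_activation": "quint8", "output_activation": "quint8"},
    "output": fp32,
}
nodes = [("input", []), ("qat_linear_relu", ["input"]), ("output", ["qat_linear_relu"])]
print(plan_qdq_stubs(nodes, info))
# [('input->qat_linear_relu', 'quint8'), ('qat_linear_relu->*', 'quint8')]
```

Because the input-side rule compares against the producer's target output dtype, a consumer whose producer already gets a quantized output stub of the same dtype does not get a second stub on that edge.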

Question: How do we avoid inserting duplicate QDQStubs?
e.g. when a single input is used by multiple ops:
```
input --- linear1 ---
     \--- linear2 ---
```
how do we make sure we insert only one QDQStub for the input of both linear1 and linear2?
```
input - QDQStub --- linear1 -
               \--- linear2 -
```

Currently, before inserting a QDQStub, we look at all users of the `input` Node and make sure there is no existing QDQStub
with the same target_dtype. That is, if we already inserted a QDQStub with dtype quint8 for linear1, and linear2 is also connected to `input`, then when we request another quint8 QDQStub while processing the linear2 Node, we detect that the desired QDQStub already exists and do nothing.
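One way to sketch this reuse check, assuming a hypothetical `users` table (node name to a list of user records) and string dtypes. The stub name is derived from the producer and dtype, and rewiring the consumers to the stub is left out to keep the toy short.

```python
from typing import Dict, List

def get_or_insert_qdq_stub(users: Dict[str, List[dict]], producer: str, dtype: str) -> str:
    """Return a QDQStub on `producer`'s output with `dtype`, reusing an existing one.

    `users` maps a node name to its users; each user is a dict with "name",
    "kind" ("op" or "qdq_stub") and "dtype". Rewiring consumers is omitted."""
    for user in users.setdefault(producer, []):
        if user["kind"] == "qdq_stub" and user["dtype"] == dtype:
            return user["name"]  # an equivalent stub already exists: do nothing
    name = f"{producer}_qdq_{dtype}"
    users[producer].append({"name": name, "kind": "qdq_stub", "dtype": dtype})
    users[name] = []
    return name

users = {"input": [
    {"name": "linear1", "kind": "op", "dtype": None},
    {"name": "linear2", "kind": "op", "dtype": None},
]}
stub_for_linear1 = get_or_insert_qdq_stub(users, "input", "quint8")
stub_for_linear2 = get_or_insert_qdq_stub(users, "input", "quint8")  # reused
print(stub_for_linear1, stub_for_linear2)  # input_qdq_quint8 input_qdq_quint8
```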

Question: What is the logic for keeping the output in float32?
Let's say the output of the `qat_linear_relu` Node is configured as float32, both in qconfig_mapping and backend_config:
```
# qconfig_mapping (simplified, shown as dict)
{'qat_linear_relu': QConfig(
  weight=MinMaxObserver.with_args(dtype=torch.qint8),
  input_activation=HistogramObserver.with_args(dtype=torch.quint8),
  output_activation=PlaceholderObserver.with_args(dtype=torch.float32),
)}

# backend_config (simplified)
{
  'pattern': nniqat.LinearReLU,
  'dtype_configs': [{'input': torch.quint8, 'output': torch.float32, 'weight': torch.qint8}],
}
```

When we try to insert the output QDQStub for `qat_linear_relu`, we look at the target output dtype for this node (`node_name_to_target_dtype_info["qat_linear_relu"]["output_activation"]`), find that it is float, which is not a quantized dtype, and
do nothing.
Note that this does not prevent operators following `qat_linear_relu` from inserting a QDQStub at the output of `qat_linear_relu`, since we are dealing with an `edge` of the graph here, and an `edge` connects two nodes, which means
the output of `qat_linear_relu` is also the input of the node following `qat_linear_relu`.

`backend_config` configurations used in this step:
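```
import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization.backend_config import BackendPatternConfig, DTypeConfig, ObservationType

linear_relu_qat_config = (
    BackendPatternConfig(nniqat.LinearReLU)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .set_dtype_configs([
        DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8,
                    weight_dtype=torch.qint8, bias_dtype=torch.float32),
    ])
)
```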

The pattern in this case is the same as before: it defines the subgraph we are dealing with.

`set_observation_type`: sets the observation type for the pattern. Currently there are only two types:

`OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` means the output observer instance will be different from the input, which is the most common type of observer placement.

`OUTPUT_SHARE_OBSERVER_WITH_INPUT` means the output observer is shared with input, they will be the same instance. This is useful for operators like cat.
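A toy illustration of the difference between the two observation types, using strings in place of the real `ObservationType` enum values and a hypothetical `ToyObserver` that only tracks min/max. With `OUTPUT_SHARE_OBSERVER_WITH_INPUT` both sides hold the same object, so the input and output of an op like cat end up with identical quantization parameters.

```python
class ToyObserver:
    """Tracks a running min/max (stand-in for an observer instance)."""
    def __init__(self):
        self.min_val, self.max_val = float("inf"), float("-inf")

    def __call__(self, values):
        self.min_val = min(self.min_val, *values)
        self.max_val = max(self.max_val, *values)
        return values

def make_output_observer(observation_type: str, input_observer: ToyObserver) -> ToyObserver:
    if observation_type == "OUTPUT_SHARE_OBSERVER_WITH_INPUT":
        return input_observer          # same instance -> identical qparams
    if observation_type == "OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT":
        return ToyObserver()           # independent instance
    raise ValueError(f"unknown observation type: {observation_type}")

# cat-like op: the output shares the input observer, so the recorded range
# covers both the input values and the output values
in_obs = ToyObserver()
out_obs = make_output_observer("OUTPUT_SHARE_OBSERVER_WITH_INPUT", in_obs)
in_obs([-1.0, 2.0])
out_obs([-1.0, 2.0, 3.0])
print(out_obs is in_obs, in_obs.min_val, in_obs.max_val)  # True -1.0 3.0
```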

`set_dtype_configs`: sets a list of supported (activation, weight, bias, etc.) dtype combinations for qconfigs for the pattern. Note that we represent different modes of quantization (static/dynamic/`weight_only`) purely through this combination. For example, fbgemm static quantization can be represented as:
```
{
  "input_activation": torch.quint8,
  "weight": torch.qint8,
  "output_activation": torch.quint8
}
```

Note: the dtype config will be used to configure support for dynamic quantization as well.

Note: we may extend this to support more fine-grained configurations of args, kwargs, attributes and outputs in the future.

Note: we are referring to observers here, which are an implementation detail. We could change this to talk about quantization parameters instead, e.g. `QParamsType.OUTPUT_USE_DIFFERENT_QPARAMS_AS_INPUT` and `QParamsType.OUTPUT_USE_SAME_QPARAMS_AS_INPUT`.

### 2. Calibration/Training
After we insert observers, we run the model to calibrate the observers or to fine-tune it. This step is identical to eager mode quantization. Afterwards, the observer/fakequantize modules contain enough information to determine quantization parameters from the observed data.
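As a rough illustration of what "enough information" means, the sketch below records a running min/max during calibration and turns it into affine per-tensor qparams for the quint8 range [0, 255], using scale = (max - min) / (qmax - qmin) and zero_point = qmin - round(min / scale). It is a simplified stand-in in the spirit of `MinMaxObserver`, not its implementation (for example, the real observer handles eps and other qschemes differently); `ToyMinMaxObserver` is a hypothetical name.

```python
class ToyMinMaxObserver:
    """Records the running min/max of everything passed through it during calibration."""
    def __init__(self, qmin: int = 0, qmax: int = 255):  # quint8 range
        self.qmin, self.qmax = qmin, qmax
        self.min_val, self.max_val = float("inf"), float("-inf")

    def __call__(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))
        return batch  # observers pass data through unchanged

    def calculate_qparams(self):
        # Include 0.0 in the range so that zero is exactly representable.
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        if scale == 0.0:
            scale = 1.0  # all-zero data: arbitrary fallback for this toy
        zero_point = self.qmin - round(lo / scale)
        zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, zero_point

obs = ToyMinMaxObserver()
for batch in ([0.5, -1.0], [3.0, 2.0]):  # stand-in calibration data
    obs(batch)
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)  # scale is 4 / 255, zero_point is 64
```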

### 3.1 Conversion to Reference Quantized Model
```
quantized: GraphModule(
  (linear): LinearReLU(
    (0): QuantizedLinear(Reference)(in_features=5, out_features=10, bias=True)
    (1): ReLU()
  )
)

def forward(self, x):
    linear_input_scale_0 = self.linear_input_scale_0
    linear_input_zero_point_0 = self.linear_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8);  x = linear_input_scale_0 = linear_input_zero_point_0 = None
    dequantize = quantize_per_tensor.dequantize();  quantize_per_tensor = None
    linear = self.linear(dequantize);  dequantize = None
    linear_scale_0 = self.linear_scale_0
    linear_zero_point_0 = self.linear_zero_point_0
    quantize_per_tensor_1 = torch.quantize_per_tensor(linear, linear_scale_0, linear_zero_point_0, torch.quint8);  linear = linear_scale_0 = linear_zero_point_0 = None
    dequantize_1 = quantize_per_tensor_1.dequantize();  quantize_per_tensor_1 = None
    return dequantize_1
```

After we insert observers and calibrate/train the model, we convert it to a reference quantized model. A reference quantized model uses reference patterns to represent quantized operators; it serves as the standard interface for quantized operators between PyTorch quantization and backend lowering passes. For more details, please take a look at this [RFC](https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md). This pass is fairly straightforward; what we do is:

(1). For each QDQStub (attached to Observer or FakeQuantize modules) in the graph, we convert it to calls to quantize and dequantize functions based on the attributes of the attached Observer or FakeQuantize module (e.g. qscheme, dtype, etc.)

(2). For weighted modules like linear/conv, we convert them to the corresponding reference quantized modules.
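The quantize/dequantize pair that replaces each QDQStub follows the usual affine mapping, q = clamp(round(x / scale) + zero_point, qmin, qmax) and x_hat = (q - zero_point) * scale. The plain-Python sketch below (assuming quint8, i.e. qmin = 0 and qmax = 255, and hypothetical helper names) shows the round trip a reference model performs before the float op; for values inside the representable range the reconstruction error is at most scale / 2.

```python
from typing import List

def quantize_per_tensor(x: List[float], scale: float, zero_point: int,
                        qmin: int = 0, qmax: int = 255) -> List[int]:
    """Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]

def dequantize(q: List[int], scale: float, zero_point: int) -> List[float]:
    """Map back to float: x_hat = (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]

scale, zero_point = 4 / 255, 64  # qparams for an observed range of [-1.0, 3.0]
x = [-1.0, 0.0, 0.5, 3.0]
q = quantize_per_tensor(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)  # what the float op sees in a reference model
print(q)  # [0, 64, 96, 255]
```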

Example:
```
# graph 1
input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - output
                                      |
                                FakeQuantize
                  (need to be updated with QDQStub + FakeQuantize)
                                      |
                                    Weight

Note: weight + FakeQuantize is a part of qat_linear_relu module

# graph 2
input - quantize - dequantize - reference_linear_relu - quantize - dequantize - output
                                        |
                                   dequantize
                                        |
                                    quantize
                                        |
                                      weight
```
Note: weight + quantize + dequantize is a part of reference_linear_relu module

To decide which quantize node we want to use, we’ll look at:

(1). dtype of attached Observer/FakeQuantize module

(2). qscheme of attached Observer/FakeQuantize module

(3). (optionally) other attributes of attached Observer/FakeQuantize module

The quantize operators we can currently choose from are: (quantize_per_tensor, quantize_per_channel, to, quantize_per_tensor_dynamic)
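The selection can be pictured as a small decision function over observer attributes. The rules below are an assumption consistent with the operator list above, not the exact conditions used by `convert_fx`; dtypes and qschemes are plain strings and `is_dynamic` is a flag assumed for this sketch.

```python
def choose_quantize_op(dtype: str, qscheme: str, is_dynamic: bool = False) -> str:
    """Pick a quantize op from observer attributes (toy decision rules)."""
    if is_dynamic:
        return "quantize_per_tensor_dynamic"  # qparams computed at runtime
    if dtype == "float16":
        return "to"                           # fp16 "quantization" is a dtype cast
    if qscheme in ("per_channel_affine", "per_channel_symmetric"):
        return "quantize_per_channel"
    if qscheme in ("per_tensor_affine", "per_tensor_symmetric"):
        return "quantize_per_tensor"
    raise ValueError(f"unsupported combination: {dtype}, {qscheme}")

print(choose_quantize_op("quint8", "per_tensor_affine"))      # quantize_per_tensor
print(choose_quantize_op("qint8", "per_channel_symmetric"))   # quantize_per_channel
```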

`backend_config` configurations used in this step:
```
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.quantized.reference as nnqr
from torch.ao.quantization.backend_config import BackendPatternConfig

linear_relu_convert_config = (
    BackendPatternConfig(nniqat.LinearReLU)
    .set_root_module(nn.Linear)
    .set_reference_quantized_module(nnqr.Linear)
    .set_fused_module(nni.LinearReLU)
)
```

The pattern in this case is the same as before: it defines the subgraph we are dealing with.

`set_root_module`: sets a module class for the root of the pattern, e.g. nn.Linear for nni.LinearReLU/nniqat.LinearReLU. It is used to identify the modules that need to be swapped to reference quantized modules.

`set_reference_quantized_module`: sets the reference quantized module class corresponding to the root module class, e.g. when root_module is nn.Linear, this will be nn.quantized.reference.Linear. It is used to swap the root module to a reference quantized module.

Note: we are only swapping `root_module` here. For example, in the current example the original module is `nniqat.LinearReLU`. When converting weighted modules (step (2)), we first convert `nniqat.LinearReLU` to a float module, in this case the fused LinearReLU module `nni.LinearReLU`, and then swap the root_module (`nn.Linear`) with the reference quantized module (`nnqr.Linear`), so we end up with an `nni.LinearReLU` module, which is a sequential module of an `nnqr.Linear` and an `nn.ReLU`.

Basically, the corresponding reference quantized module for both `nniqat.LinearReLU` and `nni.LinearReLU` would be an `nni.LinearReLU` Sequential module (originally `nn.Linear` + `nn.ReLU`) with `nn.Linear` replaced by `nnqr.Linear`: `nni.LinearReLU(nnqr.Linear, nn.ReLU)`.

`set_fused_module`: the fused module class corresponding to the pattern, used to identify fused modules that need to be converted to reference quantized modules.

### 3.2 Lower to PyTorch Native Backend
```
GraphModule(
  (linear): QuantizedLinearReLU(in_features=5, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
)

def forward(self, x):
    linear_input_scale_0 = self.linear_input_scale_0
    linear_input_zero_point_0 = self.linear_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8);  x = linear_input_scale_0 = linear_input_zero_point_0 = None
    linear = self.linear(quantize_per_tensor);  quantize_per_tensor = None
    dequantize_1 = linear.dequantize();  linear = None
    return dequantize_1
```

Currently, PyTorch has native quantized backends, fbgemm and qnnpack, so we need a lowering pass to lower the reference quantized model to a model that uses native quantized operators in PyTorch. This pass does the following:

1. Recognize reference patterns like "dequantize - `float_op` - quantize" in the graph and replace them with quantized modules (under the torch.nn.quantized namespace) or operators (under the torch.ops.quantized namespace, or the torch namespace)
In general there are three types of patterns:

* Static quantization:
```
dequantize -> float_op -> quantize_per_tensor
```

* Dynamic quantization:
```
quantize_per_tensor_dynamic -> dequantize -> float_op
```

* Weight only quantization:
```
                                       input - float_op - output
      weight - quantize_per_tensor - dequantize /
```

2. Prepack and fold the weights for quantized linear and quantized conv operators
3. The lowering pass also keeps some patterns for quantized operators unfused, since users may explicitly request that some operators stay in float by setting their qconfig to None
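A toy version of step 1 for the static pattern, operating on a linear chain of op names rather than an FX graph. The `lowerable` mapping from float op to quantized op is a hypothetical stand-in for the backend's lowering tables; ops without an entry are left untouched, so their dequantize/quantize calls stay in place.

```python
from typing import Dict, List

def lower_static_patterns(ops: List[str], lowerable: Dict[str, str]) -> List[str]:
    """Replace each `dequantize -> float_op -> quantize_per_tensor` window with a
    single quantized op; ops without a quantized counterpart stay in float."""
    out: List[str] = []
    i = 0
    while i < len(ops):
        window = ops[i:i + 3]
        if (len(window) == 3 and window[0] == "dequantize"
                and window[1] in lowerable and window[2] == "quantize_per_tensor"):
            out.append(lowerable[window[1]])
            i += 3
        else:
            out.append(ops[i])
            i += 1
    return out

reference = ["quantize_per_tensor", "dequantize", "linear_relu", "quantize_per_tensor", "dequantize"]
print(lower_static_patterns(reference, {"linear_relu": "QuantizedLinearReLU"}))
# ['quantize_per_tensor', 'QuantizedLinearReLU', 'dequantize']
```

The result has the same shape as the lowered `forward` in 3.2: quantize the input, run the quantized module, dequantize the output.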

There are no configurations related to lowering in `backend_config`, since it is the backend developer's responsibility to implement the lowering pass, and each backend developer may have their own configurations. So end to end, `backend_config` together with `qconfig_mapping` controls which reference quantized model FX Graph Mode Quantization produces, not the lowered model.

However, for some operator-based backends, like the current PyTorch native backends fbgemm and qnnpack, we could interpret `backend_config` in terms of operator configurations as well. For example, configuring `input_dtype=quint8`, `weight_dtype=qint8`, `output_dtype=torch.quint8` for nn.Linear says that the quantized linear takes a `quint8` activation and a `qint8` weight as input and outputs a `quint8` activation. But there is no guarantee that this interpretation will always hold in the future, especially as we add new flavors of quantized operators.

## Extensibility

FX graph mode quantization can be extended to work with different backends, which may have different sets of supported quantized operator patterns and different requirements for each pattern. For more detail, please refer to the [BackendConfig README](/torch/ao/quantization/backend_config/README.md).