File: test_backend_config.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (363 lines) | stat: -rw-r--r-- 16,475 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# Owner(s): ["oncall: quantization"]

import torch
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from torch.testing._internal.common_quantization import QuantizationTestCase

from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType,
)
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer


class TestBackendConfig(QuantizationTestCase):
    """Unit tests for the BackendConfig API.

    Covers the three config layers imported from
    ``torch.ao.quantization.backend_config``:

    * ``DTypeConfig`` -- round-tripping through ``from_dict`` / ``to_dict``.
    * ``BackendPatternConfig`` -- each setter, plus ``from_dict`` / ``to_dict``.
    * ``BackendConfig`` -- naming, registering pattern configs, and dict
      round-trips.

    Shared fixtures live in class attributes and ``_get_*`` helpers. This lets
    the builder-style and dict-style descriptions of the same config be
    checked against each other.
    """

    # =============
    #  DTypeConfig
    # =============

    dtype_config1 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float
    )

    dtype_config2 = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float,
        is_dynamic=True
    )

    activation_dtype_with_constraints = DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    )

    weight_dtype_with_constraints = DTypeWithConstraints(
        dtype=torch.qint8,
        quant_min_lower_bound=-128,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    )

    dtype_config3 = DTypeConfig(
        input_dtype=activation_dtype_with_constraints,
        output_dtype=activation_dtype_with_constraints,
        weight_dtype=weight_dtype_with_constraints,
    )

    # "Legacy" dict form: dtypes are plain torch.dtype values. from_dict must
    # still accept this form (see test_dtype_config_from_dict), but to_dict
    # does not produce it.
    dtype_config_dict1_legacy = {
        "input_dtype": torch.quint8,
        "output_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
    }

    dtype_config_dict2_legacy = {
        "input_dtype": torch.float16,
        "output_dtype": torch.float,
        "is_dynamic": True,
    }

    # Current dict form: input/output/weight dtypes are wrapped in
    # DTypeWithConstraints. This is the form to_dict() is expected to emit
    # (see test_dtype_config_to_dict).
    dtype_config_dict1 = {
        "input_dtype": DTypeWithConstraints(dtype=torch.quint8),
        "output_dtype": DTypeWithConstraints(torch.quint8),
        "weight_dtype": DTypeWithConstraints(torch.qint8),
        "bias_dtype": torch.float,
    }

    dtype_config_dict2 = {
        "input_dtype": DTypeWithConstraints(dtype=torch.float16),
        "output_dtype": DTypeWithConstraints(dtype=torch.float),
        "is_dynamic": True,
    }

    dtype_config_dict3 = {
        "input_dtype": activation_dtype_with_constraints,
        "output_dtype": activation_dtype_with_constraints,
        "weight_dtype": weight_dtype_with_constraints,
    }

    def test_dtype_config_from_dict(self):
        # Both the legacy and the current dict forms must parse to the same
        # DTypeConfig objects.
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1_legacy), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2_legacy), self.dtype_config2)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict3), self.dtype_config3)

    def test_dtype_config_to_dict(self):
        # to_dict() emits the current (DTypeWithConstraints) form, never the
        # legacy one.
        self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1)
        self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2)
        self.assertEqual(self.dtype_config3.to_dict(), self.dtype_config_dict3)

    # ======================
    #  BackendPatternConfig
    # ======================

    _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU)

    # Fixture for _set_num_tensor_args_to_observation_type: maps a count of
    # tensor arguments to an ObservationType.
    _num_tensor_args_to_observation_type = {
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    # Fixture for _set_input_type_to_index: maps an argument role name to a
    # positional index (exercised against torch.addmm below).
    _input_type_to_index = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }
    _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)

    def _extra_inputs_getter(self, p):
        """Fixture getter passed to ``_set_extra_inputs_getter``.

        Returns a 1-tuple holding a random 3x3 tensor. ``p`` is not used here;
        presumably it exists to match the signature the config expects.
        """
        return (torch.rand(3, 3),)

    def _get_backend_op_config1(self):
        """Build a ``(ReLU, Linear)`` pattern config via the public setters.

        Must stay equivalent to ``_get_backend_pattern_config_dict1``;
        ``test_backend_op_config_to_dict`` compares the two.
        """
        return BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(self.dtype_config1) \
            .add_dtype_config(self.dtype_config2) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(nnqat.Linear) \
            .set_reference_quantized_module(nnqr.Linear) \
            .set_fused_module(nni.LinearReLU) \
            .set_fuser_method(self._fuser_method)

    def _get_backend_op_config2(self):
        """Build a ``torch.add`` pattern config via the private ``_set_*`` setters.

        ``observation_type`` is deliberately left at its default. Must stay
        equivalent to ``_get_backend_pattern_config_dict2``.
        """
        return BackendPatternConfig(torch.add) \
            .add_dtype_config(self.dtype_config2) \
            ._set_root_node_getter(_default_root_node_getter) \
            ._set_extra_inputs_getter(self._extra_inputs_getter) \
            ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \
            ._set_input_type_to_index(self._input_type_to_index) \
            ._set_input_output_observed(False) \
            ._set_overwrite_output_fake_quantize(self._fake_quantize) \
            ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)

    def _get_backend_pattern_config_dict1(self):
        """Dict form of ``_get_backend_op_config1`` (public keys only)."""
        return {
            "pattern": (torch.nn.ReLU, torch.nn.Linear),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
            "root_module": torch.nn.Linear,
            "qat_module": nnqat.Linear,
            "reference_quantized_module_for_root": nnqr.Linear,
            "fused_module": nni.LinearReLU,
            "fuser_method": self._fuser_method,
        }

    def _get_backend_pattern_config_dict2(self):
        """Dict form of ``_get_backend_op_config2`` (internal keys included)."""
        return {
            "pattern": torch.add,
            # Not set explicitly on _get_backend_op_config2; this is the default
            # value (see test_backend_op_config_set_observation_type).
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict2],
            "root_node_getter": _default_root_node_getter,
            "extra_inputs_getter": self._extra_inputs_getter,
            "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type,
            "input_type_to_index": self._input_type_to_index,
            "input_output_observed": False,
            "overwrite_output_fake_quantize": self._fake_quantize,
            "overwrite_output_observer": default_fixed_qparams_range_0to1_observer
        }

    def test_backend_op_config_set_observation_type(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        # Default before any setter call.
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)

    def test_backend_op_config_add_dtype_config(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(len(conf.dtype_configs), 0)
        conf.add_dtype_config(self.dtype_config1)
        conf.add_dtype_config(self.dtype_config2)
        # add_dtype_config appends; insertion order is preserved.
        self.assertEqual(len(conf.dtype_configs), 2)
        self.assertEqual(conf.dtype_configs[0], self.dtype_config1)
        self.assertEqual(conf.dtype_configs[1], self.dtype_config2)

    def test_backend_op_config_set_root_module(self):
        conf = BackendPatternConfig(nni.LinearReLU)
        self.assertIsNone(conf.root_module)
        conf.set_root_module(torch.nn.Linear)
        self.assertEqual(conf.root_module, torch.nn.Linear)

    def test_backend_op_config_set_qat_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf.qat_module)
        conf.set_qat_module(nnqat.Linear)
        self.assertEqual(conf.qat_module, nnqat.Linear)

    def test_backend_op_config_set_reference_quantized_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf.reference_quantized_module)
        conf.set_reference_quantized_module(nnqr.Linear)
        self.assertEqual(conf.reference_quantized_module, nnqr.Linear)

    def test_backend_op_config_set_fused_module(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf.fused_module)
        conf.set_fused_module(nni.LinearReLU)
        self.assertEqual(conf.fused_module, nni.LinearReLU)

    def test_backend_op_config_set_fuser_method(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf.fuser_method)
        conf.set_fuser_method(self._fuser_method)
        self.assertEqual(conf.fuser_method, self._fuser_method)

    def test_backend_op_config_set_root_node_getter(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf._root_node_getter)
        conf._set_root_node_getter(_default_root_node_getter)
        self.assertEqual(conf._root_node_getter, _default_root_node_getter)

    def test_backend_op_config_set_extra_inputs_getter(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf._extra_inputs_getter)
        conf._set_extra_inputs_getter(self._extra_inputs_getter)
        self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter)

    def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
        conf = BackendPatternConfig(torch.add)
        self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0)
        conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
        self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)

    def test_backend_op_config_set_input_type_to_index(self):
        conf = BackendPatternConfig(torch.addmm)
        self.assertEqual(len(conf._input_type_to_index), 0)
        conf._set_input_type_to_index(self._input_type_to_index)
        self.assertEqual(conf._input_type_to_index, self._input_type_to_index)

    def test_backend_op_config_set_input_output_observed(self):
        conf = BackendPatternConfig(torch.nn.Embedding)
        self.assertIsNone(conf._input_output_observed)
        conf._set_input_output_observed(False)
        self.assertEqual(conf._input_output_observed, False)

    def test_backend_op_config_set_overwrite_output_fake_quantize(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertIsNone(conf._overwrite_output_fake_quantize)
        conf._set_overwrite_output_fake_quantize(self._fake_quantize)
        self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize)

    def test_backend_op_config_set_overwrite_output_observer(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertIsNone(conf._overwrite_output_observer)
        conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
        self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_from_dict(self):
        # Public keys only: every internal field must stay at its default.
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf1 = BackendPatternConfig.from_dict(conf_dict1)
        self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
        self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertEqual(conf1.root_module, torch.nn.Linear)
        self.assertEqual(conf1.qat_module, nnqat.Linear)
        self.assertEqual(conf1.reference_quantized_module, nnqr.Linear)
        self.assertEqual(conf1.fused_module, nni.LinearReLU)
        self.assertEqual(conf1.fuser_method, self._fuser_method)
        self.assertIsNone(conf1._root_node_getter)
        self.assertIsNone(conf1._extra_inputs_getter)
        self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0)
        self.assertEqual(len(conf1._input_type_to_index), 0)
        self.assertIsNone(conf1._input_output_observed)
        self.assertIsNone(conf1._overwrite_output_fake_quantize)
        self.assertIsNone(conf1._overwrite_output_observer)
        # Test temporary/internal keys: every public module field must stay unset.
        conf_dict2 = self._get_backend_pattern_config_dict2()
        conf2 = BackendPatternConfig.from_dict(conf_dict2)
        self.assertEqual(conf2.pattern, torch.add)
        self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertIsNone(conf2.root_module)
        self.assertIsNone(conf2.qat_module)
        self.assertIsNone(conf2.reference_quantized_module)
        self.assertIsNone(conf2.fused_module)
        self.assertIsNone(conf2.fuser_method)
        self.assertEqual(conf2._root_node_getter, _default_root_node_getter)
        self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter)
        self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)
        self.assertEqual(conf2._input_type_to_index, self._input_type_to_index)
        self.assertEqual(conf2._input_output_observed, False)
        self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize)
        self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_to_dict(self):
        # The builder-style fixtures must serialize to their dict-style twins.
        conf1 = self._get_backend_op_config1()
        conf2 = self._get_backend_op_config2()
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf_dict2 = self._get_backend_pattern_config_dict2()
        self.assertEqual(conf1.to_dict(), conf_dict1)
        self.assertEqual(conf2.to_dict(), conf_dict2)

    # ===============
    #  BackendConfig
    # ===============

    def test_backend_config_set_name(self):
        conf = BackendConfig("name1")
        self.assertEqual(conf.name, "name1")
        conf.set_name("name2")
        self.assertEqual(conf.name, "name2")

    def test_backend_config_set_backend_pattern_config(self):
        conf = BackendConfig("name1")
        self.assertEqual(len(conf.configs), 0)
        backend_op_config1 = self._get_backend_op_config1()
        backend_op_config2 = self._get_backend_op_config2()
        # `configs` is keyed by each pattern config's pattern.
        conf.set_backend_pattern_config(backend_op_config1)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
        })
        conf.set_backend_pattern_config(backend_op_config2)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
            torch.add: backend_op_config2
        })

    def test_backend_config_from_dict(self):
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        conf = BackendConfig.from_dict(conf_dict)
        self.assertEqual(conf.name, "name1")
        self.assertEqual(len(conf.configs), 2)
        key1 = (torch.nn.ReLU, torch.nn.Linear)
        key2 = torch.add
        self.assertIn(key1, conf.configs)
        self.assertIn(key2, conf.configs)
        # Round-trip each parsed pattern config back to its source dict.
        self.assertEqual(conf.configs[key1].to_dict(), op_dict1)
        self.assertEqual(conf.configs[key2].to_dict(), op_dict2)

    def test_backend_config_to_dict(self):
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        # set_backend_pattern_config returns the BackendConfig, so calls chain.
        conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2)
        # "configs" is expected as a list in registration order.
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        self.assertEqual(conf.to_dict(), conf_dict)

if __name__ == '__main__':
    # Running this module standalone is unsupported: abort immediately and
    # point the user at the aggregating test entry point instead.
    _usage_message = (
        "This _test file is not meant to be run directly, use:\n\n"
        "\tpython _test/_test_quantization.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)