File: test_quantize_fx_lite_script_module.py

# Owner(s): ["oncall: mobile"]

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.utils.bundled_inputs
from torch.ao.quantization import (
    default_qconfig,
    float_qparams_weight_only_qconfig,
)

# graph mode quantization based on fx
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_fx,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantization import (
    QuantizationLiteTestCase,
    LinearModelWithSubmodule,
)
from torch.testing._internal.common_utils import run_tests


class TestLiteFuseFx(QuantizationLiteTestCase):

    # Tests from:
    # ./caffe2/test/quantization/fx/test_quantize_fx.py
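    #
    # Each test quantizes a small model with prepare_fx/convert_fx and then
    # hands it to QuantizationLiteTestCase._compare_script_and_mobile, which
    # (as the name suggests) scripts the model, round-trips it through the
    # lite interpreter, and checks that both paths agree on the outputs.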

    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        model = M().eval()
        indices = torch.randint(low=0, high=10, size=(20,))

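        # (qconfig, expected module after convert) pairs: only the
        # float_qparams weight-only qconfig can quantize nn.Embedding;
        # None and default_qconfig leave the float module in place.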
        configs = [
            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
            (None, ns.call_module(nn.Embedding)),
            (default_qconfig, ns.call_module(nn.Embedding)),
        ]

        for qconfig, node in configs:
            # `node` records the module type expected after conversion; the
            # lite comparison below only checks numerics, so it is kept here
            # as documentation.
            qconfig_dict = {"": qconfig}
            m = prepare_fx(model, qconfig_dict, example_inputs=(indices,))
            m = convert_fx(m)
            self._compare_script_and_mobile(m, input=indices)

    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]}
        data = torch.randn(1, 1, 1, 1)
        m = prepare_fx(m, qconfig_dict, example_inputs=(data,))
        m = convert_fx(m)
        # conv1's qconfig is overridden to None, so the first conv stays
        # float and only the second conv is quantized
        self._compare_script_and_mobile(m, input=data)

    def test_submodule(self):
        # test quantizing complete module, submodule and linear layer
        configs = [
            {},
            {"module_name": [("subm", None)]},
            {"module_name": [("fc", None)]},
        ]
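        # {} quantizes the whole model; ("subm", None) and ("fc", None)
        # override the global qconfig so that the submodule (or the final
        # linear layer, respectively) is left unquantized.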
        for config in configs:
            model = LinearModelWithSubmodule().eval()
            qconfig_dict = {
                "": torch.ao.quantization.get_default_qconfig("qnnpack"),
                **config,
            }
            x = torch.randn(5, 5)
            model = prepare_fx(model, qconfig_dict, example_inputs=(x,))
            quant = convert_fx(model)

            self._compare_script_and_mobile(quant, input=x)


if __name__ == "__main__":
    run_tests()