File: test_fusion_passes.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (121 lines) | stat: -rw-r--r-- 4,915 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]

# torch
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase

class TestFusionPasses(QuantizationTestCase):
    def test_quantized_add_relu_fusion(self):
        class MAdd(torch.nn.Module):
            def __init__(self):
                super(MAdd, self).__init__()

            def forward(self, x, y):
                a = torch.ops.quantized.add(x, y, 1., 0)
                relu_out = torch.relu(a)
                return relu_out

        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)

        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # Must inline the graph.
        # In this test case since we are directly calling ops
        # it does not matter, however if we are calling nn
        # modules we have to inline graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check("quantized::add_relu") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

        class MAddOut(torch.nn.Module):
            def __init__(self):
                super(MAddOut, self).__init__()

            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(qA.shape,
                                           scale=scale,
                                           zero_point=zero_point,
                                           dtype=torch.quint8)
        # Check quantized add + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # Must inline the graph.
        # In this test case since we are directly calling ops
        # it does not matter, however if we are calling nn
        # modules we have to inline graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_out") \
                   .check("quantized::add_relu_out") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

        class MAddScalar(torch.nn.Module):
            def __init__(self):
                super(MAddScalar, self).__init__()

            def forward(self, x, y : float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_scalar(") \
                   .check("quantized::add_scalar_relu") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, 3.)
        self.assertEqual(ref_output, output)

        class MAddScalarOut(torch.nn.Module):
            def __init__(self):
                super(MAddScalarOut, self).__init__()

            def forward(self, x, y : float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(qA.shape,
                                           scale=scale,
                                           zero_point=zero_point,
                                           dtype=torch.quint8)
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3., qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_scalar_out") \
                   .check("quantized::add_scalar_relu_out") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, 3., qC)
        self.assertEqual(ref_output, output)