File: fuser_method_mappings.py

import torch.nn as nn
import torch.nn.intrinsic as nni

from typing import Union, Callable, Tuple, Dict, Optional, Type
from torch.ao.quantization.utils import Pattern

from torch.ao.quantization.utils import get_combined_dict
from torch.ao.quantization.utils import MatchAllNode
import itertools

def fuse_conv_bn(is_qat, conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        conv: Module instance of type Conv1d/Conv2d/Conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn(False, m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }

    if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
        fused_module_class = fused_module_class_map.get(type(conv), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

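# Illustrative sketch (not part of the upstream file; the helper name below is
# hypothetical), assuming ``torch`` is importable: in post-training mode
# (is_qat=False) the BN statistics are folded into the conv weights, so the
# fused module should be numerically equivalent to conv followed by bn in
# eval mode.
def _demo_fuse_conv_bn():
    import torch
    conv = nn.Conv2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    fused = fuse_conv_bn(False, conv, bn)  # returns a plain nn.Conv2d
    x = torch.randn(1, 3, 16, 16)
    assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)
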
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
        or post training quantization fusion
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn_relu(False, m1, b1, r1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module: Optional[Type[nn.Sequential]] = None
    if is_qat:
        map_to_fused_module_train = {
            nn.Conv1d: nni.ConvBnReLU1d,
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
    else:
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval.get(type(conv), None)
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

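# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): with is_qat=True the modules stay separate inside an
# intrinsic wrapper (here nni.ConvBnReLU2d) so QAT can keep updating the BN
# statistics; all three modules must be in training mode.
def _demo_fuse_conv_bn_relu_qat():
    conv = nn.Conv2d(3, 8, 3).train()
    bn = nn.BatchNorm2d(8).train()
    relu = nn.ReLU().train()
    fused = fuse_conv_bn_relu(True, conv, bn, relu)
    assert isinstance(fused, nni.ConvBnReLU2d)
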
def fuse_linear_bn(is_qat, linear, bn):
    r"""Given the linear and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_linear_bn(False, m1, b1)
    """
    assert(linear.training == bn.training),\
        "Linear and BN both must be in the same mode (train or eval)."

    if is_qat:
        assert bn.num_features == linear.out_features,\
            "Output features of Linear must match num_features of BatchNorm1d"
        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert bn.track_running_stats,\
            "Only support fusing BatchNorm1d with tracking_running_stats set to True"
        return nni.LinearBn1d(linear, bn)
    else:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

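# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical), assuming ``torch`` is importable: post-training Linear+BN1d
# fusion folds the BN statistics into the linear weights.
def _demo_fuse_linear_bn():
    import torch
    linear = nn.Linear(20, 10).eval()
    bn = nn.BatchNorm1d(10).eval()
    fused = fuse_linear_bn(False, linear, bn)  # returns a plain nn.Linear
    x = torch.randn(4, 20)
    assert torch.allclose(fused(x), bn(linear(x)), atol=1e-5)
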
def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Given ConvTranspose and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the conv transpose.
            The batch norm N should match the ConvTranspose N

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(False, m1, b1)
    """
    assert(convt.training == bn.training),\
        "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError("Fusing ConvTranspose+BatchNorm not yet supported in QAT.")
    else:
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)

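# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): the eval-mode path reuses fuse_conv_bn_eval with
# transpose=True to fold the BN into the transposed conv.
def _demo_fuse_convtranspose_bn():
    import torch
    convt = nn.ConvTranspose2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    fused = fuse_convtranspose_bn(False, convt, bn)
    x = torch.randn(1, 3, 8, 8)
    assert torch.allclose(fused(x), bn(convt(x)), atol=1e-5)
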
def sequential_wrapper2(sequential):
    """ Given a sequential class for two modules, return a function that takes
    is_qat, and then two modules as argument, that ignores the is_qat flag
    and always returns the sequential that combines the two input modules
    """
    def fuser_method(is_qat, m1, m2):
        return sequential(m1, m2)
    return fuser_method

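# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): the wrapper adapts a two-module sequential class
# (e.g. nni.ConvReLU2d) to the (is_qat, m1, m2) fuser method signature
# used throughout this file.
def _demo_sequential_wrapper2():
    fuser = sequential_wrapper2(nni.ConvReLU2d)
    fused = fuser(False, nn.Conv2d(3, 8, 3), nn.ReLU())  # is_qat is ignored
    assert isinstance(fused, nni.ConvReLU2d)
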
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}

def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    ''' Get fuser method for the given list of module types;
    raises an AssertionError if no fuser method is found
    '''
    if additional_fuser_method_mapping is None:
        additional_fuser_method_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD,
                                     additional_fuser_method_mapping)
    fuser_method = all_mappings.get(op_list, None)
    assert fuser_method is not None, "did not find fuser method for: {}".format(op_list)
    return fuser_method

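# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): look up the fuser method registered for an exact tuple of
# module types in DEFAULT_OP_LIST_TO_FUSER_METHOD.
def _demo_get_fuser_method():
    method = get_fuser_method((nn.Conv2d, nn.BatchNorm2d, nn.ReLU))
    assert method is fuse_conv_bn_relu
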
def reverse_sequential_wrapper2(sequential):
    """ Given a sequential class for two modules, return a function that takes
    is_qat, and then two modules as argument, that ignores the is_qat flag
    and always returns the sequential that combines the two input modules, with
    the order of two inputs reversed
    """
    def fuser_method(is_qat, m1, m2):
        return sequential(m2, m1)
    return fuser_method

def reverse2(f):
    """Returns a fuser method that takes its two module arguments in reversed order."""
    def reversed(is_qat, x, y):
        return f(is_qat, y, x)
    return reversed

def reverse3(f):
    """Adapts a three-module fuser method to the nested FX pattern layout
    ``(outer, (middle, inner))``, e.g. ``(relu, (bn, conv))``."""
    def reversed(is_qat, x, w):
        y, z = w
        return f(is_qat, z, y, x)
    return reversed

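# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): FX patterns are stored outermost op first, e.g.
# (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)), so reverse3 unpacks the nested
# tuple and calls the underlying fuser method in (conv, bn, relu) order.
def _demo_reverse3():
    conv, bn, relu = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()
    fused = reverse3(fuse_conv_bn_relu)(True, relu, (bn, conv))
    assert isinstance(fused, nni.ConvBnReLU2d)
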
DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
    (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn),
    (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
    (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
    (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
    (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
    (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
    (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
    (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
    (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
    (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
    (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
    (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
    (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
    (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
    (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
    (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
}

def get_valid_patterns(op_pattern):
    """
    Returns a list of valid patterns generated from the op_pattern.
    Since MatchAllNode can match any type of node, a pattern like
    (torch.nn.Conv2d, torch.add) should also match keys such as
    (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)

    Example Input:
    (torch.add, (torch.nn.ReLU, torch.nn.Conv2d))

    Example Output:
    [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
     (torch.add, (torch.nn.ReLU, MatchAllNode)),
     (torch.add, (MatchAllNode, torch.nn.Conv2d)),
     (torch.add, (MatchAllNode, MatchAllNode)),
     (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
     (MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
     (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
     (MatchAllNode, (MatchAllNode, MatchAllNode)),
    ]
    """
    result = []
    if isinstance(op_pattern, (tuple, list)):
        sub_combs = []
        for sub_pattern in op_pattern:
            sub_combs.append(get_valid_patterns(sub_pattern))
        result = list(itertools.product(*sub_combs))
    else:
        result = [op_pattern, MatchAllNode]
    return result

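# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): a two-element pattern expands to the 2x2 cartesian product
# of each position with MatchAllNode.
def _demo_get_valid_patterns():
    pats = get_valid_patterns((nn.ReLU, nn.Conv2d))
    assert len(pats) == 4
    assert (nn.ReLU, nn.Conv2d) in pats
    assert (MatchAllNode, MatchAllNode) in pats
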
def get_fuser_method_new(
        op_pattern: Pattern,
        fuser_method_mapping: Optional[Dict[Pattern, Union[nn.Sequential, Callable]]] = None):
    """ This will be made defult after we deparate the get_fuser_method
    Would like to implement this first and have a separate PR for deprecation
    """
    if fuser_method_mapping is None:
        fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD

    op_patterns = get_valid_patterns(op_pattern)
    fuser_method = None
    # Use a distinct loop variable so the error message below reports the
    # caller's original pattern, not the last MatchAllNode-widened variant.
    for valid_pattern in op_patterns:
        fuser_method = fuser_method_mapping.get(valid_pattern, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, "did not find fuser method for: {}".format(op_pattern)
    return fuser_method
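
# Illustrative sketch (not part of the upstream file; the helper name is
# hypothetical): an exact pattern match is tried before the
# MatchAllNode-widened variants, and the returned method takes its arguments
# in pattern order (outermost op first).
def _demo_get_fuser_method_new():
    method = get_fuser_method_new((nn.BatchNorm2d, nn.Conv2d))
    conv, bn = nn.Conv2d(3, 8, 3).eval(), nn.BatchNorm2d(8).eval()
    fused = method(False, bn, conv)  # pattern order: bn first, then conv
    assert isinstance(fused, nn.Conv2d)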