File: utils.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (474 lines) | stat: -rw-r--r-- 16,059 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474


import copy
import logging
from collections import defaultdict

import numpy as np
from caffe2.python import core, utils
from caffe2.python.fb import hardcode_scale_zp  # type: ignore[import]


logger = logging.getLogger(__name__)
# NOTE(review): this forces DEBUG on the module logger at import time, so the
# "Can't fuse ..." diagnostics below pass this logger's own level check even if
# the application configured a higher level for it earlier; confirm intended.
logger.setLevel(logging.DEBUG)


def pairwise(iterable):
    """Return an iterator over overlapping pairs of consecutive items.

    s -> (s0, s1), (s1, s2), (s2, s3), ...

    The result is a lazy ``zip``; an iterable with fewer than two items
    produces no pairs.
    """
    from itertools import tee

    leading, trailing = tee(iterable, 2)
    # Advance the second stream by one so the two are offset by a single item.
    next(trailing, None)
    return zip(leading, trailing)


def blob_uses(net, blob):
    """Return the indices of ops in ``net.op`` that read ``blob``.

    An op counts as a user when ``blob`` appears among its regular inputs or
    among its control inputs. Indices are in ascending order.
    """
    return [
        index
        for index, op in enumerate(net.op)
        if blob in op.input or blob in op.control_input
    ]


def fuse_first_bn(net, params, removed_tensors, begin_op_index):
    """Fold the first foldable SpatialBN into the Conv/ConvTranspose feeding it.

    Scans ``net.op`` from ``begin_op_index`` for a Conv or ConvTranspose op
    whose first later consumer is a SpatialBN that can be folded into it, then
    rewrites the convolution weight (and bias) so that a single fused op
    computes conv followed by batch norm. At most one fusion is performed.

    Args:
        net: network with an ``op`` list; deep-copied, the argument is not
            modified.
        params: mapping from blob name to weight array; deep-copied, the
            argument is not modified.
        removed_tensors: list to which the names of BN parameter blobs made
            obsolete by the fusion are appended (mutated in place and returned).
        begin_op_index: index of the first op to consider.

    Returns:
        ``(net, params, removed_tensors, next_index)`` where ``next_index`` is
        the index just past the fused op, or ``None`` when nothing was fused.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
        if conv.type not in ["Conv", "ConvTranspose"]:
            continue

        # Only ops after the conv can consume the value it produces. An earlier
        # op reading a same-named blob (or the conv reading its own output in
        # place) must not be picked as the BN candidate: with j <= i the splice
        # below would duplicate ops and corrupt the net.
        uses = [k for k in blob_uses(net, conv.output[0]) if k > i]
        if len(uses) == 0:
            continue

        j = uses[0]
        bn = net.op[j]
        if bn.type != "SpatialBN" or (len(uses) > 1 and conv.output[0] != bn.output[0]):
            if bn.type == "SpatialBN":
                logger.debug("Can't fuse if more than one user {}".format(uses))
            # Can't fuse if more than one user unless SpatialBN is inplace
            # An example of inplace SpatialBN where we want to allow multiple uses:
            # x = Conv(...)
            # ... // no interfering use or def of x (will be checked below)
            # x = SpatialBN(x, ...)
            # ...
            # z = Foo(..., x, ...)
            # ...
            # w = Boo(..., x, ...)
            # Here, we still want to fuse Conv and SpatialBN
            continue

        # There shouldn't be any use or def of conv.output[0] or bn.output[0]
        # between conv and bn
        if any(
            blob in net.op[k].input or blob in net.op[k].output
            for blob in [conv.output[0], bn.output[0]]
            for k in range(i + 1, j)
        ):
            logger.debug(
                "Can't fuse because of the following interferring uses or defs:"
            )
            for k in range(i, j + 1):
                logger.debug(net.op[k])
            continue

        # else, can fuse
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]
        conv_weight = params[conv.input[1]]
        if len(conv.input) > 2:
            conv_bias = params[conv.input[2]]
        else:
            # No conv bias: start from zeros sized like the BN bias; the BN
            # bias blob is reused as the fused conv's bias further below.
            conv_bias = np.zeros(len(params[bn.input[2]])).astype(np.float32)

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        # First, BN computation can be phrased as follows:
        # (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
        # bn_scale + bias
        # Thus, we can rewrite bn_scale as:
        # X * bn_scale * 1.0 / (sqrt(running_var + eps)) + (bias -
        # running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale)
        # Thus, can just have the affine transform
        # X * A + B
        # where
        # A = bn_scale * 1.0 / (sqrt(running_var + eps))
        # B =  (bias - running_mean * (1.0 / sqrt(running_var + eps))
        # * bn_scale)
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # This identity should hold if we have correctly fused
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[bn.output[0]])

        # Now, we have that the computation made is the following:
        # ((X `conv` W) + b) * A + B
        # Then, we can simply fuse this as follows:
        # (X `conv` (W * A)) + b * A + B
        # which is simply
        # (X `conv` Q) + C
        # where

        # Q = W * A
        # C = b * A + B

        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # per-channel BN multiplier 'A' is of size (S,)
        # NOTE(review): the ConvTranspose reshape assumes weight axis 1 spans
        # all output channels; confirm grouped ConvTranspose never reaches here.

        A_ = (
            A.reshape((-1,) + tuple([1] * (conv_weight.ndim - 1)))
            if conv.type == "Conv"
            else A.reshape((1, -1) + tuple([1] * (conv_weight.ndim - 2)))
        )

        C = conv_bias * A + B
        Q = conv_weight * A_

        assert params[conv.input[1]].shape == Q.shape
        if len(conv.input) > 2:
            assert params[conv.input[2]].shape == C.shape
        else:
            assert bn_bias.shape == C.shape

        params[conv.input[1]] = Q
        if len(conv.input) > 2:
            params[conv.input[2]] = C
        else:
            # Reuse the BN bias blob as the fused conv's (new) bias input.
            params[bn.input[2]] = C
            fused_conv.input.append(bn.input[2])

        new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
        del net.op[:]
        # BN scale, mean and variance are now obsolete. The BN bias blob is
        # dropped only when the conv had its own bias; otherwise it now holds
        # the fused conv's bias.
        removed_tensors.append(bn.input[1])
        if len(conv.input) > 2:
            removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        if len(conv.input) > 2:
            del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        return net, params, removed_tensors, i + 1

    return net, params, removed_tensors, None


def fuse_bn(net, params, ignore_failure):
    """Repeatedly fold SpatialBN ops into the Conv/ConvTranspose ops feeding them.

    Calls ``fuse_first_bn`` until it reports that nothing more was fused,
    resuming each scan just after the previously fused op.

    Args:
        net: network whose ops are scanned.
        params: mapping from blob name to weight array.
        ignore_failure: when False, leftover SpatialBN ops are an error.

    Returns:
        ``(net, params, removed_tensors)`` after the last fusion pass, where
        ``removed_tensors`` lists the BN parameter blobs no longer needed.

    Raises:
        Exception: if ``ignore_failure`` is False and a SpatialBN op remains.
    """
    removed_tensors = []
    begin_op_index = 0
    while True:
        net, params, removed_tensors, begin_op_index = fuse_first_bn(
            net, params, removed_tensors, begin_op_index
        )
        if begin_op_index is None:
            break
    if not ignore_failure and any(op.type == "SpatialBN" for op in net.op):
        # Interpolate the net into the message; passing it as a separate
        # Exception argument left the "%s" placeholder unformatted.
        raise Exception("Model contains SpatialBN op after fusion: %s" % (net,))
    return (net, params, removed_tensors)


def fuse_first_scale(net, params, removed_tensors):
    """Fold the first SpatialBN -> Mul -> Add chain into the SpatialBN.

    Looks for three consecutive ops where a SpatialBN's first output is the
    Mul's first input, the Mul's first output is the Add's first input, and
    the second inputs of both the Mul and the Add are blobs in ``params``.
    Since BN(x) = x_hat * scale + bias, the chain equals
    x_hat * (scale * m) + (bias * m + a), so the BN scale/bias params are
    rewritten and the Mul and Add ops are removed. At most one chain is fused.

    NOTE(review): no check is made that the BN or Mul outputs have other
    consumers; such consumers would lose their producer after fusion.

    Args:
        net: network with an ``op`` list; deep-copied, not modified.
        params: mapping from blob name to array; deep-copied, not modified.
        removed_tensors: list to which the Mul/Add parameter blob names are
            appended (mutated in place and returned).

    Returns:
        ``(net, params, removed_tensors)``; ``net`` has the same number of ops
        as the input when nothing was fused.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for i, (bn, mul) in enumerate(zip(net.op, net.op[1:])):
        j = i + 1
        # Check op types before indexing inputs/outputs: ops that take no
        # inputs would otherwise make ``mul.input[0]`` raise IndexError.
        if (
            bn.type != "SpatialBN"
            or mul.type != "Mul"
            or mul.input[0] != bn.output[0]
            or len(net.op) <= j + 1
        ):
            continue

        add = net.op[j + 1]
        # The Add must consume the Mul's result, not some unrelated blob.
        if add.type != "Add" or add.input[0] != mul.output[0]:
            continue

        mul_param = mul.input[1]
        add_param = add.input[1]
        # Only constant (parameter) operands can be folded into the BN.
        if mul_param not in params or add_param not in params:
            continue

        # else, can fuse
        fused_bn = copy.deepcopy(bn)
        fused_bn.output[0] = add.output[0]
        bn_scale = params[bn.input[1]]
        mul_scale = params[mul_param]
        bn_bias = params[bn.input[2]]
        add_bias = params[add_param]

        params[bn.input[1]] = bn_scale * mul_scale
        params[bn.input[2]] = mul_scale * bn_bias + add_bias

        new_ops = net.op[:i] + [fused_bn] + net.op[j + 2 :]
        del net.op[:]
        removed_tensors.append(mul_param)
        removed_tensors.append(add_param)
        del params[mul_param]
        del params[add_param]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_scale(net, params, ignore_failure):
    """Fold every SpatialBN -> Mul -> Add chain into its SpatialBN.

    Applies ``fuse_first_scale`` repeatedly and stops at the first pass that
    leaves the op count unchanged. ``ignore_failure`` is accepted but unused.

    Returns:
        ``(net, params, removed_tensors)`` produced by the final pass.
    """
    removed_tensors = []
    while True:
        candidate_net, candidate_params, removed_tensors = fuse_first_scale(
            net, params, removed_tensors
        )
        if len(candidate_net.op) == len(net.op):
            break
        net, params = candidate_net, candidate_params
    return (candidate_net, candidate_params, removed_tensors)


def fuse_first_relu(net, begin_op_index, ignore_op_with_output=None):
    """Fuse the first eligible producer -> Relu pair into one ``<Type>Relu`` op.

    Scans ``net.op`` from ``begin_op_index`` for a Conv, ConvTranspose, Sum or
    SpatialBN op whose first later consumer is a Relu that can be absorbed.
    The producer is replaced by a copy whose type gets a ``Relu`` suffix and
    whose first output is the Relu's output, and the Relu op is dropped. At
    most one fusion is performed.

    Args:
        net: network with an ``op`` list; deep-copied, not modified.
        begin_op_index: index of the first op to consider.
        ignore_op_with_output: optional collection of blob names; producers
            whose first output is in it are never fused.

    Returns:
        ``(net, next_index)`` where ``next_index`` is the index just past the
        fused op, or ``None`` when nothing was fused.
    """
    net = copy.deepcopy(net)

    for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
        if conv.type not in ["Conv", "ConvTranspose", "Sum", "SpatialBN"]:
            continue

        # Only ops after the producer can consume the value it writes. An
        # earlier reader of a same-named blob, or the producer itself when it
        # works in place (e.g. x = Sum(x, y)), must not be taken as the Relu
        # candidate: j <= i either skips a valid fusion or makes the splice
        # below duplicate ops.
        uses = [k for k in blob_uses(net, conv.output[0]) if k > i]
        if (
            len(uses) == 0
            or ignore_op_with_output
            and conv.output[0] in ignore_op_with_output
        ):
            continue

        j = uses[0]
        relu = net.op[j]
        if relu.type != "Relu" or len(uses) > 1 and conv.output[0] != relu.output[0]:
            # Can't fuse if more than one user unless Relu is inplace
            if relu.type == "Relu":
                logger.debug("Can't fuse if more than one user {}".format(uses))
            continue

        # There shouldn't be any use or def of conv.output[0] or
        # relu.output[0] between conv and relu
        if any(
            blob in net.op[k].input or blob in net.op[k].output
            for blob in [conv.output[0], relu.output[0]]
            for k in range(i + 1, j)
        ):
            logger.debug(
                "Can't fuse because of the following interferring uses or defs:"
            )
            for k in range(i, j + 1):
                logger.debug(net.op[k])
            continue

        # else, can fuse
        fused_conv = copy.deepcopy(conv)
        fused_conv.type = conv.type + "Relu"
        fused_conv.output[0] = relu.output[0]

        new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
        del net.op[:]
        net.op.extend(new_ops)
        return net, i + 1
    return net, None


def fuse_relu(net, ignore_failure, ignore_op_with_output=None):
    """Fuse Relu ops into their producers until no further fusion applies.

    Calls ``fuse_first_relu`` repeatedly, resuming each scan just after the
    previously fused op.

    Args:
        net: network whose ops are scanned.
        ignore_failure: when False, leftover Relu ops are an error.
        ignore_op_with_output: optional collection of blob names whose
            producers must not be fused.

    Returns:
        The network after the last fusion pass.

    Raises:
        Exception: if ``ignore_failure`` is False and a Relu op remains.
    """
    begin_op_index = 0
    while True:
        net, begin_op_index = fuse_first_relu(
            net, begin_op_index, ignore_op_with_output
        )
        if begin_op_index is None:
            break
    if not ignore_failure and any(op.type == "Relu" for op in net.op):
        # Interpolate the net into the message; passing it as a separate
        # Exception argument left the "%s" placeholder unformatted.
        raise Exception("Model contains Relu op after fusion: %s" % (net,))
    return net


def last_producer(ops, blob):
    """Return the index of the last op in ``ops`` whose first output is ``blob``.

    Ops with no outputs are skipped.

    Raises:
        ValueError: if no op in ``ops`` writes ``blob`` as its first output.
    """
    for i in reversed(range(len(ops))):
        outputs = ops[i].output
        if outputs and outputs[0] == blob:
            return i
    # Format the message; passing ``blob`` as a second argument left "%s" raw.
    raise ValueError("Failed to find last producer of blob, %s" % (blob,))


def swap_first_concat_relu(net, ignore_op_with_output=None):
    """Move the Relu of the first adjacent Concat -> Relu pair before the Concat.

    For the first adjacent pair where a Relu reads the Concat's first output
    (and that output is not in ``ignore_op_with_output``), the Concat takes
    over the Relu's output and the Relu op is removed. For each Concat input
    ``blob``, the last earlier op producing ``blob`` is renamed to write
    ``blob + "_pre_relu"``, and a copy of the Relu that maps it back to
    ``blob`` is inserted right after that producer. This is valid because
    Relu is elementwise: Relu(Concat(a, b)) == Concat(Relu(a), Relu(b)).

    NOTE(review): other readers of the Concat output, or of a renamed input
    blob between its producer and the Concat, are not checked for; they would
    observe changed values. A Concat input with no producer among the earlier
    ops (e.g. an external input) makes ``last_producer`` raise ValueError.

    Returns:
        A deep copy of ``net`` with at most one swap applied.
    """
    net = copy.deepcopy(net)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        # Check op types before indexing inputs/outputs: ops that take no
        # inputs would otherwise make ``next_.input[0]`` raise IndexError.
        if current.type != "Concat" or next_.type != "Relu":
            continue

        if next_.input[0] != current.output[0]:
            continue

        if ignore_op_with_output and current.output[0] in ignore_op_with_output:
            continue

        # else, can swap
        concat = copy.deepcopy(current)
        relu = copy.deepcopy(next_)
        pre_ops = copy.deepcopy(net.op[:i])
        post_ops = copy.deepcopy(net.op[j + 1 :])

        # Delete the Relu after Concat: the Concat now writes the Relu's output.
        concat.output[0] = relu.output[0]

        # Insert Relu after each op that produces inputs to Concat
        for blob in concat.input:
            k = last_producer(pre_ops, blob)
            producer = pre_ops[k]
            assert producer.output[0] == blob
            # The producer now writes "<blob>_pre_relu" and a Relu copy maps it
            # back to "<blob>", so the Concat still reads "<blob>".
            producer.output[0] = blob + "_pre_relu"

            new_relu = copy.deepcopy(relu)
            new_relu.input[0] = producer.output[0]
            new_relu.output[0] = blob

            pre_ops = pre_ops[: k + 1] + [new_relu] + pre_ops[k + 1 :]

        new_ops = pre_ops + [concat] + post_ops
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net


def swap_concat_relu(net, ignore_op_with_output=None):
    """Apply ``swap_first_concat_relu`` until the network stops changing.

    Args:
        net: network whose Concat -> Relu pairs are rewritten.
        ignore_op_with_output: optional collection of Concat output names that
            must not be swapped.

    Returns:
        The network after the last pass that changed nothing.
    """
    while True:
        next_net = swap_first_concat_relu(net, ignore_op_with_output)
        # Detect the fixed point by comparing the nets, not their op counts: a
        # swap on a single-input Concat removes one Relu and inserts one, so
        # the count stays the same even though the net changed.
        if next_net == net:
            return next_net
        net = next_net


def add_version_to_conv_bias(net, init_net):
    """
    In architectures such as FPN (https://arxiv.org/abs/1612.03144), several
    Conv ops share the same weight and bias and run at different scales of the
    input. Since 'bias_scale = input_scale * weight_scale', a bias blob shared
    by multiple Conv ops would need a different bias scale for each of them.
    To allow that, each shared bias blob is duplicated before the int8
    rewrite.

    Conv-like ops are those whose type contains "Conv" and that have at least
    three inputs (input 2 being the bias). For a bias used by more than one
    of them, the first user keeps the original name and each later user is
    rewired to "<bias>_v<k>" (k = 1, 2, ...). A copy of the op in ``init_net``
    that produces the bias is appended for each new name, and the new name is
    appended to ``net``'s external inputs. Both nets are modified in place.
    """
    conv_ops = [
        op for op in net._net.op if "Conv" in op.type and len(op.input) >= 3
    ]

    users_per_bias = defaultdict(int)
    for op in conv_ops:
        users_per_bias[op.input[2]] += 1
    shared_biases = {name for name, count in users_per_bias.items() if count > 1}

    # Op in init_net producing each shared bias (the last one wins if several).
    producers = {
        fill.output[0]: fill
        for fill in init_net._net.op
        if fill.output[0] in shared_biases
    }

    next_version = defaultdict(int)
    for op in conv_ops:
        bias = op.input[2]
        if bias not in shared_biases:
            continue

        version = next_version[bias]
        next_version[bias] = version + 1
        if version == 0:
            continue  # the first user keeps the original bias blob

        versioned = bias + "_v" + str(version)
        clone = copy.deepcopy(producers[bias])
        clone.output[0] = versioned
        init_net._net.op.extend([clone])
        op.input[2] = versioned
        net._net.external_input.append(versioned)


def add_quantization_param_args_(op, q_param):
    """Append ``Y_scale`` and ``Y_zero_point`` arguments from ``q_param`` to ``op``."""
    scale_arg = utils.MakeArgument("Y_scale", q_param.scale)
    zero_point_arg = utils.MakeArgument("Y_zero_point", q_param.zero_point)
    op.arg.extend([scale_arg, zero_point_arg])


def choose_quantization_params(tensor_min, tensor_max, preserve_sparsity=False):
    """Pick uint8 quantization parameters for the range [tensor_min, tensor_max].

    When ``preserve_sparsity`` is set and the range straddles zero, the range
    is first widened symmetrically onto [-128, 127] steps of a common scale,
    and the zero point of the result is forced to 128. Otherwise the choice
    is delegated unchanged to ``hardcode_scale_zp.choose_quantization_params``.
    """

    def straddles_zero(lo, hi):
        # Only relevant when sparsity must be preserved.
        return preserve_sparsity and lo < 0 and hi > 0

    if straddles_zero(tensor_min, tensor_max):
        qmin, qmax = -(255 // 2 + 1), 255 // 2
        common_scale = max(abs(tensor_min / qmin), abs(tensor_max / qmax))
        tensor_min, tensor_max = common_scale * qmin, common_scale * qmax

    q_param = hardcode_scale_zp.choose_quantization_params(tensor_min, tensor_max)

    # Re-evaluated on the (possibly widened) range, as before.
    if straddles_zero(tensor_min, tensor_max):
        q_param = hardcode_scale_zp.QuantizationParam(q_param.scale, 128)

    return q_param


def add_quantization_param_args(op, tensor, preserve_sparsity=False):
    """Choose quantization params for ``tensor`` and attach them to ``op``.

    The tensor's min and max (0 and 0 for an empty tensor) are passed to
    ``choose_quantization_params``. The result is added to ``op`` as
    ``Y_scale``/``Y_zero_point`` arguments and returned.
    """
    if tensor.size == 0:
        lo = hi = 0
    else:
        lo, hi = tensor.min(), tensor.max()

    q_param = choose_quantization_params(lo, hi, preserve_sparsity)
    add_quantization_param_args_(op, q_param)
    return q_param


def create_int8_given_tensor_fill(tensor, out_blob_name, preserve_sparsity=False):
    """
    Create Int8GivenTensorFill op that quantizes the given tensor and outputs
    an Int8Tensor with out_blob_name.

    Values are quantized as round(x / scale) + zero_point, clamped to
    [0, 255], and stored as raw uint8 bytes together with the tensor shape.
    Returns ``(op, q_param)``.
    """
    fill_op = core.CreateOperator("Int8GivenTensorFill", [], out_blob_name)
    q_param = add_quantization_param_args(fill_op, tensor, preserve_sparsity)

    rounded = np.around(tensor / q_param.scale).astype(np.int32)
    quantized = np.clip(rounded + q_param.zero_point, 0, 255)

    values_arg = utils.MakeArgument("values", quantized.astype(np.uint8).tobytes())
    shape_arg = utils.MakeArgument("shape", quantized.shape)
    fill_op.arg.extend([values_arg, shape_arg])
    return fill_op, q_param


def create_int8_bias_tensor_fill(tensor, out_blob_name, x_q_param, w_q_param):
    """
    Similar to create_int8_given_tensor_fill, but for bias blobs to be stored
    as int32.

    The bias is quantized with scale ``x_q_param.scale * w_q_param.scale`` and
    zero point 0, i.e. round(x / scale) cast to int32. Returns the
    ``Int8GivenIntTensorFill`` op carrying the values, the shape and the
    ``Y_scale``/``Y_zero_point`` arguments.

    NOTE(review): values outside the int32 range are not saturated before the
    ``astype(np.int32)`` cast; confirm biases are always small enough.
    """
    scale = x_q_param.scale * w_q_param.scale
    quantized_tensor = np.around(tensor / scale).astype(np.int32)
    op = core.CreateOperator("Int8GivenIntTensorFill", [], out_blob_name)
    op.arg.extend(
        [
            # Values are passed flattened; the original dimensions are
            # recorded separately in ``shape``.
            utils.MakeArgument("values", quantized_tensor.reshape(-1)),
            utils.MakeArgument("shape", quantized_tensor.shape),
        ]
    )
    q_param = hardcode_scale_zp.QuantizationParam(scale, 0)
    add_quantization_param_args_(op, q_param)
    return op