File: dnnlowp_test_utils.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (451 lines) | stat: -rw-r--r-- 15,030 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451


import collections

import numpy as np
from caffe2.python import utils, workspace
from caffe2.quantization.server import dnnlowp_pybind11
from hypothesis import assume


def check_quantized_results_close(outputs, ref=None, symmetric=False, atol_scale=0.53):
    """Assert that quantized results are close to the floating-point result.

    ``outputs[0][0]`` is the floating-point (reference) result, and ``o[0]``
    for every ``o`` in ``outputs[1:]`` is a quantized result. Each quantized
    result is compared element-wise against the reference with
    ``np.testing.assert_allclose`` using a purely absolute tolerance
    (``rtol=0``).

    The tolerance is ``atol_scale`` quantization steps. One step is the
    range of ``ref``, widened to include 0, divided into 255 intervals. When
    ``symmetric`` is true, the range is ``[-m, m]`` with ``m`` the largest
    magnitude of that widened range. In exact arithmetic half a step would
    suffice; the default 0.53 (about 1/1.9) leaves headroom for
    floating-point rounding. The bound assumes the inputs themselves carry
    no quantization error.

    Args:
        outputs: sequence whose items are indexed with ``[0]`` to obtain the
            result array.
        ref: array whose range determines the quantization step. Defaults to
            ``outputs[0][0]``. When it is empty, nothing is checked.
        symmetric: derive the step from a range symmetric around 0.
        atol_scale: absolute tolerance expressed in quantization steps.

    Raises:
        AssertionError: if any quantized result is out of tolerance.
    """
    ref = outputs[0][0] if ref is None else ref
    if ref.size == 0:
        return

    # Range of ref, always widened so that it contains zero.
    lo = min(np.min(ref), 0)
    hi = max(np.max(ref), 0)
    span = 2 * max(abs(hi), abs(lo)) if symmetric else hi - lo
    atol = span / 255 * atol_scale

    for quantized in outputs[1:]:
        np.testing.assert_allclose(quantized[0], outputs[0][0], atol=atol, rtol=0)


def pairwise(iterable):
    """Return overlapping consecutive pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    Empty and single-element inputs produce no pairs.
    """
    from itertools import tee

    leading, trailing = tee(iterable)
    # Shift the second copy by one element so zip lines up neighbours.
    next(trailing, None)
    return zip(leading, trailing)


def avoid_vpmaddubsw_overflow_fc(
    batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
):
    """Adjust FC weights in place so fbgemm's vpmaddubsw cannot overflow.

    For every batch row ``i``, output channel ``j`` and pair of input channels
    ``(k, k + 1)`` (``k`` even; a trailing odd channel is not paired), this
    computes ``x0 * w0 + x1 * w1``, where:

    * activations are shifted as ``X - X_min``;
    * weights are shifted as ``W - 128 - W_min``.

    That pair sum must lie in ``[-2**15, 2**15)``. When it does not, the
    second weight ``W[j, k + 1]`` is rewritten in place so that the sum is
    pulled back to the violated bound. A second pass re-checks every pair.

    ``X`` is indexed as ``X[batch, channel]`` and ``W`` as
    ``W[output_channel, channel]``. ``X_max`` and ``W_max`` are accepted for
    signature symmetry but are not used.

    Raises:
        AssertionError: if a pair sum is still out of range after adjustment.
    """
    int16_min = -(1 << 15)
    int16_max = (1 << 15) - 1
    paired_channels = input_channels // 2 * 2

    def shifted_terms(i, j, k):
        """Return shifted (x0, x1, w0, w1) for channels k and k + 1."""
        return (
            X[i, k] - X_min,
            X[i, k + 1] - X_min,
            W[j, k] - 128 - W_min,
            W[j, k + 1] - 128 - W_min,
        )

    for i in range(batch_size):
        for j in range(output_channels):
            for k in range(0, paired_channels, 2):
                x0, x1, w0, w1 = shifted_terms(i, j, k)
                total = x0 * w0 + x1 * w1
                if total < int16_min:
                    target = int16_min
                elif total > int16_max:
                    target = int16_max
                else:
                    continue
                # Solve x0 * w0 + x1 * w1_adjusted == target for w1_adjusted.
                w1_adjusted = (target - float(x0) * w0) / x1
                W[j, k + 1] = int(w1_adjusted) + 128 + W_min

    # Second pass: re-read the (possibly adjusted) weights and confirm.
    for i in range(batch_size):
        for j in range(output_channels):
            for k in range(0, paired_channels, 2):
                x0, x1, w0, w1 = shifted_terms(i, j, k)
                assert int16_min <= x0 * w0 + x1 * w1 < int16_max + 1


def _conv_vpmaddubsw_pairs(
    strides,
    pads,
    kernels,
    dilations,
    sizes,
    input_channels,
    output_channels,
    batch_size,
    X,
    X_min,
):
    """Yield ``(w0_idx, w1_idx, x0, x1)`` for every adjacent pair of filter taps.

    Walks every output position ``(b, *o_spatial, oc)``. Within each position
    it walks every overlapping pair of consecutive entries of
    ``np.ndindex(kernels + (input_channels,))``.

    ``w0_idx`` / ``w1_idx`` are indices into ``W`` (layout
    ``(output_channels, *kernels, input_channels)``). ``x0`` / ``x1`` are the
    matching activations from ``X`` (layout
    ``(batch, *sizes, input_channels)``) shifted by ``X_min``. Input positions
    that fall into the padding yield ``-X_min``, the shifted value of a 0
    input.

    Weights are deliberately not read here. Because the generator is lazy,
    callers may mutate ``W`` while iterating and later pairs observe those
    changes.
    """
    ndim = len(sizes)
    # Allow kernels to be any sequence; it is concatenated with tuples below.
    kernels = tuple(kernels)
    dkernels = tuple(dilations[i] * (kernels[i] - 1) + 1 for i in range(ndim))
    size_cols = tuple(
        (sizes[i] + 2 * pads[i] - dkernels[i]) // strides[i] + 1 for i in range(ndim)
    )
    # The tap order is identical for every output position, so compute it once.
    taps = list(np.ndindex(kernels + (input_channels,)))
    tap_pairs = list(zip(taps, taps[1:]))

    def shifted_input(b, o_spatial, f, ic):
        """Return X at the position read by kernel tap ``f`` minus X_min."""
        pos = tuple(
            strides[i] * o_spatial[i] - pads[i] + dilations[i] * f[i]
            for i in range(ndim)
        )
        if all(0 <= pos[i] < sizes[i] for i in range(ndim)):
            return X[(b,) + pos + (ic,)] - X_min
        # padding
        return -X_min

    for out_idx in np.ndindex((batch_size,) + size_cols + (output_channels,)):
        b = out_idx[0]
        oc = out_idx[-1]
        o_spatial = out_idx[1:-1]
        for tap0, tap1 in tap_pairs:
            yield (
                (oc,) + tap0,
                (oc,) + tap1,
                shifted_input(b, o_spatial, tap0[:-1], tap0[-1]),
                shifted_input(b, o_spatial, tap1[:-1], tap1[-1]),
            )


def avoid_vpmaddubsw_overflow(
    strides,
    pads,
    kernels,
    dilations,
    sizes,
    input_channels,
    output_channels,
    batch_size,
    X,
    X_min,
    X_max,
    W,
    W_min,
    W_max,
):
    """Adjust conv weights in place so fbgemm's vpmaddubsw cannot overflow.

    ``strides``, ``pads``, ``kernels``, ``dilations`` and ``sizes`` are
    per-spatial-dimension sequences of the same length (2 for 2D conv, 3 for
    3D conv, and so on).

    Every adjacent pair of (kernel tap, input channel) entries at every output
    position is checked, with activations shifted as ``X - X_min`` and weights
    as ``W - 128 - W_min``. The sum ``x0 * w0 + x1 * w1`` must lie in
    ``[-2**15, 2**15)``. When it does not, the second weight of the pair is
    rewritten in place so the sum is pulled back to the violated bound. A
    second pass re-checks every pair.

    FIXME: this assumes fbgemm is used only for NHWC and that im2col is done
    so that input_channels is the fastest-moving dimension.

    ``X_max`` and ``W_max`` are accepted for signature symmetry but are not
    used.

    Raises:
        AssertionError: if a pair sum is still out of range after adjustment.
    """
    pair_args = (
        strides,
        pads,
        kernels,
        dilations,
        sizes,
        input_channels,
        output_channels,
        batch_size,
        X,
        X_min,
    )
    lower = -(1 << 15)
    upper = 1 << 15  # exclusive bound

    for w0_idx, w1_idx, x0, x1 in _conv_vpmaddubsw_pairs(*pair_args):
        # Read W inside the loop so earlier adjustments are taken into account.
        w0 = W[w0_idx] - 128 - W_min
        w1 = W[w1_idx] - 128 - W_min
        acc = x0 * w0 + x1 * w1
        if acc < lower:
            w1_adjusted = (lower - float(x0) * w0) / x1
            W[w1_idx] = int(w1_adjusted) + 128 + W_min
        elif acc >= upper:
            w1_adjusted = (upper - 1 - float(x0) * w0) / x1
            W[w1_idx] = int(w1_adjusted) + 128 + W_min

    # Go through the same pairs again to double check there is no overflow.
    for w0_idx, w1_idx, x0, x1 in _conv_vpmaddubsw_pairs(*pair_args):
        w0 = W[w0_idx] - 128 - W_min
        w1 = W[w1_idx] - 128 - W_min
        acc = x0 * w0 + x1 * w1
        assert lower <= acc < upper, f"vpmaddubsw overflow at W{w1_idx}: {acc}"


def generate_convnd_inputs(
    strides,
    pads,
    kernels,
    dilations,
    sizes,
    group,
    input_channels_per_group,
    output_channels_per_group,
    batch_size,
    order,
    groupwise_quantization=False,
    preserve_activation_sparsity=False,
    preserve_weight_sparsity=False,
):
    """Generate random ``(X, W, b)`` inputs for an N-D grouped convolution test.

    ``strides``, ``pads``, ``kernels``, ``dilations`` and ``sizes`` are
    per-spatial-dimension sequences of the same length (2 for 2D conv, 3 for
    3D conv, and so on). Hypothesis ``assume`` rejects inconsistent lengths
    and inputs smaller than the dilated kernel.

    X and W are integer-valued float32 tensors with an implicit scale of 1,
    so they are represented exactly after quantization. X_min/X_max and
    W_min/W_max are planted explicitly so the observed ranges are known.
    ``avoid_vpmaddubsw_overflow`` is then applied group by group.

    Args:
        group: number of convolution groups.
        input_channels_per_group, output_channels_per_group: channels per
            group. Both equal to 1 means depthwise convolution.
        batch_size: number of images in X (may be 0).
        order: ``"NCHW"`` converts X and W from channels-last layout with
            ``utils.NHWC2NCHW``; any other value leaves them channels-last.
        groupwise_quantization: plant W_min/W_max in every group and, unless
            ``preserve_weight_sparsity``, shift group ``g``'s weights by ``g``
            so each group has a different range.
        preserve_activation_sparsity: use X_min = 0 instead of -77.
        preserve_weight_sparsity: draw W from [-128, 100] instead of
            [-100, 155].

    Returns:
        ``(X, W, b)``. Before the optional NCHW conversion, X has shape
        ``(batch_size, *sizes, input_channels)`` and W has shape
        ``(output_channels, *kernels, input_channels_per_group)``. b is a
        float32 vector of length ``output_channels``.
    """
    dim = len(sizes)
    assume(all(len(a) == dim for a in [strides, pads, kernels, dilations]))
    assume(all(sizes[d] >= dilations[d] * (kernels[d] - 1) + 1 for d in range(dim)))
    input_channels = input_channels_per_group * group
    output_channels = output_channels_per_group * group
    depthwise_convolution = (
        input_channels_per_group == 1 and output_channels_per_group == 1
    )

    assert input_channels > 1
    assert output_channels > 1

    # X and W have scale 1, so exactly represented after quantization
    X_min = 0 if preserve_activation_sparsity else -77
    X_max = X_min + 255
    X_range = X_max - X_min
    if depthwise_convolution and groupwise_quantization:
        # For depthwise convolution, it's not enough to set input channel 0
        # to all X_min to avoid overflow from vpmaddubsw
        X_range /= 2
    X = np.round(
        np.random.rand(*((batch_size,) + tuple(sizes) + (input_channels,))) * X_range
        + X_min
    )
    X = X.astype(np.float32)
    if (
        batch_size != 0
        and depthwise_convolution
        and groupwise_quantization
        and not preserve_activation_sparsity
    ):
        # Put X_max in a position not to be paired with any padded value.
        # Put X_min to all positions that can be paired with the X_max value.
        #
        # This is an example of a pattern for 3x3x3
        #  .   .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .  min
        #
        #  .   .   .   .   .
        #  .   .   .   .  min
        #  .  min max min  .
        # min  .   .   .   .
        #  .   .   .   .   .
        #
        # min  .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .   .
        #  .   .   .   .   .

        # Make sure we have enough dimension
        assert X.shape[1] >= 3
        assert all(X.shape[d + 1] >= kernels[d] + 2 for d in range(1, dim))

        # Take subtensor we want to manipulate (batch 0, channel 0). Basic
        # indexing with ints and slices returns a view, so writes reach X.
        X_sub = X[(0,) * (X.ndim - dim - 1) + (slice(None),) * dim + (0,)]

        # Put X_max in the middle of the subtensor
        X_sub[(1,) + tuple(kernels[d] // 2 + 1 for d in range(1, dim))] = X_max

        # The point-wise (fancy) indices below MUST be tuples. Since NumPy
        # 1.23 a plain list of index lists is treated as a single integer-array
        # index along the first axis, which would overwrite whole slices.

        # Put X_min to the positions that can be paired with X_max across
        # the slowest moving dimension
        X_sub[tuple([[0, 2]] + [[kernels[d] + 1, 0] for d in range(1, dim)])] = X_min

        # Put X_min to other positions that can be paired with X_max
        for d1 in range(1, dim):
            X_sub[
                tuple(
                    [[1]]
                    + [[kernels[d2] // 2 + 1] for d2 in range(1, d1)]
                    + [[kernels[d1] // 2, kernels[d1] // 2 + 2]]
                    + [[kernels[d2] + 1, 0] for d2 in range(d1 + 1, dim)]
                )
            ] = X_min
    else:
        # input channel 0 is all X_min to avoid overflow from vpmaddubsw when
        # multiplied with W_min and W_max
        X[..., 0] = X_min
        if batch_size != 0:
            X[(0,) * (X.ndim - 1) + (1,)] = X_max

    if preserve_weight_sparsity:
        W_min = -128
        W_max = 100
    else:
        W_min = -100
        W_max = W_min + 255
    W = np.round(
        np.random.rand(
            *((output_channels,) + tuple(kernels) + (input_channels_per_group,))
        )
        * (W_max - W_min)
        + W_min
    )
    W = W.astype(np.float32)
    if groupwise_quantization:
        for g in range(group):
            W[(g * output_channels_per_group,) + (0,) * (W.ndim - 1)] = W_min
            if depthwise_convolution:
                W[(g * output_channels_per_group, 1) + (0,) * (W.ndim - 2)] = W_max
            else:
                assert output_channels_per_group > 1
                W[(g * output_channels_per_group + 1,) + (0,) * (W.ndim - 1)] = W_max

            # Make sure each group has different ranges to really see the effect
            # of group-wise quantization.
            if not preserve_weight_sparsity:
                W[
                    g * output_channels_per_group : (g + 1) * output_channels_per_group,
                ] += g
    else:
        W[(0,) + (0,) * (W.ndim - 1)] = W_min
        W[(1,) + (0,) * (W.ndim - 1)] = W_max

    # The per-group shift by g above must be reflected in the W range passed
    # to the overflow check.
    different_range_per_group = groupwise_quantization and not preserve_weight_sparsity
    for g in range(group):
        # Slices are views, so weight adjustments land in W.
        avoid_vpmaddubsw_overflow(
            strides,
            pads,
            kernels,
            dilations,
            sizes,
            input_channels_per_group,
            output_channels_per_group,
            batch_size,
            X[..., g * input_channels_per_group : (g + 1) * input_channels_per_group],
            X_min,
            X_max,
            W[g * output_channels_per_group : (g + 1) * output_channels_per_group,],
            W_min + (g if different_range_per_group else 0),
            W_max + (g if different_range_per_group else 0),
        )

    if order == "NCHW":
        X = utils.NHWC2NCHW(X)
        W = utils.NHWC2NCHW(W)

    b = np.random.randn(output_channels).astype(np.float32)

    return X, W, b


def generate_conv_inputs(
    stride,
    pad,
    kernel,
    dilation,
    size,
    group,
    input_channels_per_group,
    output_channels_per_group,
    batch_size,
    order,
    groupwise_quantization=False,
    preserve_activation_sparsity=False,
    preserve_weight_sparsity=False,
):
    """2D convenience wrapper around ``generate_convnd_inputs``.

    Each spatial hyper-parameter (``stride``, ``pad``, ``kernel``,
    ``dilation``, ``size``) is a single scalar that is applied to both
    spatial dimensions. All other arguments are forwarded unchanged.

    Returns:
        The ``(X, W, b)`` tuple produced by ``generate_convnd_inputs``.
    """
    # Replicate every scalar into a pair, one entry per spatial dimension.
    stride_2d, pad_2d, kernel_2d, dilation_2d, size_2d = (
        (value, value) for value in (stride, pad, kernel, dilation, size)
    )
    return generate_convnd_inputs(
        stride_2d,
        pad_2d,
        kernel_2d,
        dilation_2d,
        size_2d,
        group,
        input_channels_per_group,
        output_channels_per_group,
        batch_size,
        order,
        groupwise_quantization=groupwise_quantization,
        preserve_activation_sparsity=preserve_activation_sparsity,
        preserve_weight_sparsity=preserve_weight_sparsity,
    )


def run_conv_or_fc(
    test_case,
    init_net,
    net,
    X,
    W,
    b,
    op_type,
    engine,
    order,
    gc,
    outputs,
    scale=None,
    zero_point=None,
    x_scale=None,
    x_zero_point=None,
):
    """Run a Conv or FC test net and append its ``Y`` results to ``outputs``.

    A truthy ``order`` marks a Conv run: every appended ``Output`` namedtuple
    carries an ``order`` field in addition to ``Y``, ``op_type`` and
    ``engine``. Otherwise (FC) the ``order`` field is omitted.

    Phase 1 uses ``test_case.ws``:
    * feeds ``X``, ``W`` and ``b`` with device option ``gc``;
    * creates the ``quant_param`` / ``X_quant_param`` blobs, each only when
      both its scale and zero point are given;
    * runs ``init_net`` if provided;
    * runs ``net`` once for the default engine ``""`` and twice for any
      other engine.

    ``ws.run`` re-creates the operator every time, so this covers several
    nets sharing one workspace.

    Phase 2 (non-default engines only) repeats the setup in the global
    ``workspace`` and runs one ``workspace.CreateNet`` instance twice, so the
    same operator object is reused.

    DNNLOWP ops are run more than once so that both their first run (which
    does caching) and their subsequent runs are exercised.
    """
    is_conv = bool(order)
    field_names = ["Y", "op_type", "engine"] + (["order"] if is_conv else [])
    Output = collections.namedtuple("Output", field_names)

    def record(Y):
        """Append one fetched result to ``outputs``."""
        if is_conv:
            outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
        else:
            outputs.append(Output(Y=Y, op_type=op_type, engine=engine))

    # (blob name, scale, zero point) for every quantization-parameter blob to create.
    quant_param_specs = [
        (blob_name, s, zp)
        for blob_name, s, zp in (
            ("quant_param", scale, zero_point),
            ("X_quant_param", x_scale, x_zero_point),
        )
        if s is not None and zp is not None
    ]

    # Phase 1: per-test workspace.
    for blob_name, value in (("X", X), ("W", W), ("b", b)):
        test_case.ws.create_blob(blob_name).feed(value, device_option=gc)
    for blob_name, s, zp in quant_param_specs:
        with workspace.WorkspaceGuard(test_case.ws):
            dnnlowp_pybind11.CreateInt8QuantParamsBlob(blob_name, float(s), int(zp))

    if init_net:
        test_case.ws.run(init_net)
    for _ in range(1 if engine == "" else 2):
        test_case.ws.run(net)
        record(test_case.ws.blobs["Y"].fetch())

    if engine == "":
        return

    # Phase 2: global workspace; CreateNet + RunNet reuses the same operator.
    for blob_name, value in (("X", X), ("W", W), ("b", b)):
        workspace.FeedBlob(blob_name, value)
    for blob_name, s, zp in quant_param_specs:
        dnnlowp_pybind11.CreateInt8QuantParamsBlob(blob_name, float(s), int(zp))

    if init_net:
        workspace.RunNetOnce(init_net)
    workspace.CreateNet(net)
    for _ in range(2):
        workspace.RunNet(net)
        record(workspace.FetchBlob("Y"))