File: channel_shuffle_dnnlowp_op_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (120 lines) | stat: -rw-r--r-- 3,890 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120


import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, dyndep, utils, workspace
from hypothesis import given, settings


# Load the DNNLOWP operator library before any operator is created.
# NOTE(review): presumably this registers the engine="DNNLOWP" operators
# used by the tests below - confirm against the dnnlowp_ops build target.
dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
# Initialize Caffe2 with an explicit OpenMP thread-count flag.
# NOTE(review): the reason for the value 11 is not visible in this file.
workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])


class DNNLowPChannelShuffleOpsTest(hu.HypothesisTestCase):
    """Compare the DNNLOWP ChannelShuffle operator with a NumPy reference.

    Every test feeds a float32 tensor through a Quantize -> ChannelShuffle ->
    Dequantize net (all with engine="DNNLOWP") and checks the output against
    a channel shuffle computed with NumPy reshape/transpose.
    """

    @staticmethod
    def _make_input(n, channels, order):
        """Return a random float32 tensor of integer values in [0, 255].

        Args:
            n: batch size; may be 0, which yields an empty tensor.
            channels: total number of channels.
            order: "NCHW" or "NHWC"; the tensor is generated as
                (n, channels, 5, 6) and converted when order is "NHWC".
        """
        X = np.round(np.random.rand(n, channels, 5, 6) * 255).astype(np.float32)
        if n != 0:
            # Pin both extremes so the data always spans exactly [0, 255].
            # NOTE(review): presumably this keeps the quantization lossless
            # for the integer-valued input - confirm against Quantize.
            X[0, 0, 0, 0] = 0
            X[0, 0, 0, 1] = 255
        if order == "NHWC":
            X = utils.NCHW2NHWC(X)
        return X

    @staticmethod
    def _channel_shuffle_ref(X, groups, order):
        """Compute the expected channel shuffle of ``X`` with NumPy.

        The channel axis is split into (groups, channels // groups), the two
        axes are swapped, and the result is flattened back to the input
        shape. NHWC input is converted to NCHW for the shuffle and converted
        back before returning.
        """
        if order == "NHWC":
            X = utils.NHWC2NCHW(X)
        batch, channels, height, width = X.shape
        grouped = X.reshape(batch, groups, channels // groups, height, width)
        shuffled = grouped.transpose((0, 2, 1, 3, 4)).reshape(X.shape)
        if order == "NHWC":
            shuffled = utils.NCHW2NHWC(shuffled)
        return shuffled

    def _check_channel_shuffle(self, X, groups, order):
        """Run the quantized net on ``X`` and assert it matches the reference.

        Args:
            X: input tensor already laid out according to ``order``.
            groups: value passed as the operator's ``group`` argument.
            order: "NCHW" or "NHWC".
        """
        net = core.Net("test_net")

        quantize = core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP")
        channel_shuffle = core.CreateOperator(
            "ChannelShuffle",
            ["X_q"],
            ["Y_q"],
            group=groups,
            kernel=1,
            order=order,
            engine="DNNLOWP",
        )
        dequantize = core.CreateOperator("Dequantize", ["Y_q"], ["Y"], engine="DNNLOWP")

        net.Proto().op.extend([quantize, channel_shuffle, dequantize])
        workspace.FeedBlob("X", X)
        workspace.RunNetOnce(net)
        Y = workspace.FetchBlob("Y")

        Y_ref = self._channel_shuffle_ref(X, groups, order)
        np.testing.assert_allclose(Y, Y_ref)

    @given(
        channels_per_group=st.integers(min_value=1, max_value=5),
        groups=st.sampled_from([1, 4, 8, 9]),
        n=st.integers(0, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        **hu.gcs_cpu_only
    )
    @settings(max_examples=10, deadline=None)
    def test_channel_shuffle(self, channels_per_group, groups, n, order, gc, dc):
        """Check small channel counts in both NCHW and NHWC layouts."""
        X = self._make_input(n, channels_per_group * groups, order)
        self._check_channel_shuffle(X, groups, order)

    @given(
        channels_per_group=st.integers(min_value=32, max_value=128),
        n=st.integers(0, 2),
        **hu.gcs_cpu_only
    )
    @settings(max_examples=10, deadline=None)
    def test_channel_shuffle_fast_path(self, channels_per_group, n, gc, dc):
        """Check NHWC with 4 groups and 32-128 channels per group.

        NOTE(review): presumably these sizes hit a specialized code path in
        the operator, as the test name suggests - confirm in the op source.
        """
        order = "NHWC"
        groups = 4
        X = self._make_input(n, channels_per_group * groups, order)
        self._check_channel_shuffle(X, groups, order)