File: tile_op_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (117 lines) | stat: -rw-r--r-- 3,887 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117





import numpy as np

from hypothesis import given, settings
import hypothesis.strategies as st
import unittest

from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial


class TestTile(serial.SerializedTestCase):
    """Hypothesis-driven tests for the ``Tile`` and ``TileGradient`` operators.

    Each test builds a random float32 tensor, runs the operator through the
    Caffe2 test harness, and compares the result with ``np.tile``.
    """

    @staticmethod
    def _tile_ref(X, tiles, axis):
        """NumPy reference for ``Tile``: repeat ``X`` ``tiles`` times on ``axis``.

        Args:
            X: Input array of any rank.
            tiles: Repetition count; a Python int or a one-element int array.
            axis: Axis to repeat along; a Python int or a one-element int
                array.

        Returns:
            A 1-tuple holding the tiled array (one entry per operator output).
        """
        # Builtin ``int`` replaces the ``np.int`` alias, which was deprecated
        # in NumPy 1.20 and removed in 1.24; it is the same dtype the alias
        # resolved to.
        reps = np.ones(X.ndim, dtype=int)
        reps[axis] = tiles
        return (np.tile(X, reps),)

    @given(M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           tiles=st.integers(min_value=1, max_value=3),
           axis=st.integers(min_value=0, max_value=2),
           **hu.gcs)
    @settings(deadline=10000)
    def test_tile(self, M, K, N, tiles, axis, gc, dc):
        """Tile a 3-D tensor with ``tiles``/``axis`` given as op arguments."""
        X = np.random.rand(M, K, N).astype(np.float32)

        op = core.CreateOperator(
            'Tile', ['X'], 'out',
            tiles=tiles,
            axis=axis,
        )

        # Check against numpy reference.  ``tiles`` and ``axis`` are operator
        # arguments here, not input blobs; they are listed so the reference
        # function receives them.
        # NOTE(review): presumably the harness only feeds as many inputs as
        # the op declares -- confirm against hypothesis_test_util.
        self.assertReferenceChecks(gc, op, [X, tiles, axis],
                                   self._tile_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X], 0, [0])

    @unittest.skipIf(not workspace.has_gpu_support, "No gpu support")
    @given(M=st.integers(min_value=1, max_value=200),
           N=st.integers(min_value=1, max_value=200),
           tiles=st.integers(min_value=50, max_value=100),
           **hu.gcs)
    def test_tile_grad(self, M, N, tiles, gc, dc):
        """Tile a 2-D tensor along axis 1 with large tile counts.

        Also runs ``TileGradient`` on its own across devices.  Skipped when
        the workspace reports no GPU support.
        """
        X = np.random.rand(M, N).astype(np.float32)
        axis = 1

        op = core.CreateOperator(
            'Tile', ['X'], 'out',
            tiles=tiles,
            axis=axis,
        )

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, tiles, axis],
                                   self._tile_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X], [0])

        # Cross-device consistency check of the standalone gradient operator
        # (this is a device check, not a numerical gradient check).
        grad_op = core.CreateOperator(
            'TileGradient', ['dOut'], 'dX',
            tiles=tiles,
            axis=axis,
        )
        # Upstream gradient has the shape of the tiled output: axis 1 grown
        # by a factor of ``tiles``.
        dX = np.random.rand(M, N * tiles).astype(np.float32)
        self.assertDeviceChecks(dc, grad_op, [dX], [0])

    @given(M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           tiles=st.integers(min_value=1, max_value=3),
           axis=st.integers(min_value=0, max_value=2),
           **hu.gcs)
    @settings(deadline=10000)
    def test_tilewinput(self, M, K, N, tiles, axis, gc, dc):
        """Tile a 3-D tensor with ``tiles``/``axis`` given as input blobs."""
        X = np.random.rand(M, K, N).astype(np.float32)

        # One-element int32 arrays stand in for the op arguments.
        tiles_arg = np.array([tiles], dtype=np.int32)
        axis_arg = np.array([axis], dtype=np.int32)

        op = core.CreateOperator(
            'Tile', ['X', 'tiles', 'axis'], 'out',
        )

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, tiles_arg, axis_arg],
                                   self._tile_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, tiles_arg, axis_arg], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X, tiles_arg, axis_arg], 0, [0])


# Run this module's test cases through unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()