File: test_ops_jit.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (301 lines) | stat: -rw-r--r-- 14,004 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Owner(s): ["module: unknown"]

from functools import partial
from textwrap import dedent

import torch

from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
    (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, skipIfSlowGradcheckEnv)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda


# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)

# variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
                       allowed_dtypes=(torch.float, torch.cfloat))



# Tests operators for consistency between JIT and eager, also checks
#   correctness of JIT specific alias schemas and intended
#   autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
#   functionality with original test_jit.py method operator tests
@skipIfSlowGradcheckEnv
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (function, method, inplace)
    #   and runtimes (eager, traced, scripted).
    # Any failure for a (sample, variant) pair is re-raised with the op name,
    #   variant kind, dtype and sample in the message.
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            'function': func, 'method': method,
        }

        # scripting strips the torch.ops prefix from these operators
        # incorrectly; don't bother testing this case.
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # TODO: find better way to standardize on op registration itself..
        # resize_/resize_as_ are only exercised through their torch.Tensor
        #   method, with samples that do not require grad.
        has_fake_function = op.name in ["resize_", 'resize_as_']

        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
                except Exception as e:
                    variant_error_info = dedent(f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """)
                    raise Exception(variant_error_info) from e

        assert tested, "JIT Test does not execute any logic"

    # Checks one (op, sample, variant) combination:
    #   - scripted and traced forward/grad/gradgrad against eager,
    #   - (float32 only) alias annotations, JIT shape analysis and the expected
    #     autodiff / fusion nodes in the last executed graphs.
    # func_type selects the variant kind ('function', 'method' or 'inplace');
    #   has_fake_function disables tracing for ops without a real function form.
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                # Ops whose name ends in '_' (presumably in-place) get a fresh
                #   clone so repeated runs don't observe a mutated input.
                return clone_input_helper(sample.input) if op.name.endswith('_') else sample.input

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we cant run this with tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                           func_type=func_type, aten_name=op.aten_name)

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
            if dtype == torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported,
                         allowed_dtypes=(torch.float,))

    # Checks that every alias of an op (function, method and inplace variants)
    #   is remapped to the op's aten name in both scripted and traced graphs,
    #   i.e. the alias name no longer appears after the original op name.
    @_alias_ops((op for op in op_db if op.aliases))
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings which is then torch scripted.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
        # ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)

        # Positional args keep f-string formatting (a 0-dim tensor formats as a
        #   plain Python scalar), but strings must be quoted, as kwargs already are,
        #   to produce valid script source.
        args_kw = args + \
            [f"'{v}'" if isinstance(v, str) else f"{v}" for v in sample.args] + \
            [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

        # Prepare data for test tracing: non-input arguments are baked into the
        #   trace as constants, and tensors that require grad cannot be inserted
        #   as trace constants, so detach them.
        def detached(v):
            return v.detach() if isinstance(v, torch.Tensor) else v

        trace_args = tuple(detached(v) for v in sample.args)
        trace_kwargs = {k: detached(v) for k, v in sample.kwargs.items()}

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Must be a concrete tuple: it is iterated twice (scripting, then
            #   tracing); a generator would be exhausted after the first loop and
            #   the tracing checks would silently never run.
            variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid undefined value: tensor error in JIT
                # compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                # NOTE: `variant` in the template is resolved from this frame's
                #   locals, so this call must stay in this method.
                scripted = torch.jit.CompilationUnit(script)._fn

                if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                # Only the input is a trace input; `variant` is bound by closure
                #   and the function is traced immediately, so late binding is safe.
                def _fn(t0):
                    return variant(t0, *trace_args, **trace_kwargs)

                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        torch.jit.trace(_fn, (clone_input_helper(sample.input),), check_trace=False)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                # check_trace=False: this test is about graph op names, and
                #   re-running in-place ops for trace checking is unreliable.
                traced = torch.jit.trace(_fn, (clone_input_helper(sample.input),), check_trace=False)
                traced(clone_input_helper(sample.input))
                graph = traced.graph_for(clone_input_helper(sample.input))
                FileCheck().check(op_name).check_not(variant_name).run(graph)


# Instantiate the device-specific TestJit classes into this module's namespace
#   so the test runner can discover them.
instantiate_device_type_tests(TestJit, globals())

if __name__ == "__main__":
    run_tests()