File: ops_test.py

package info (click to toggle)
onnxscript 0.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (337 lines) | stat: -rw-r--r-- 13,113 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Test op correctness by comparing with PyTorch results.

Usage:

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil

    To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention

## Environment variables

1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.

2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
errors.
"""

from __future__ import annotations

import os
import unittest
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import onnx
import onnxruntime as ort
import parameterized
import torch
from torch.testing._internal import common_device_type
from torch.testing._internal.opinfo import core as opinfo_core
from torch.utils import _pytree as pytree

import onnxscript
from tests.function_libs.torch_lib import (
    error_reproduction,
    ops_test_common,
    ops_test_data,
)

# Every dtype listed here is tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
    torch.float16,
    torch.float32,
    # Uncomment the entries below when they actually need to be tested
    # torch.bfloat16,
    # torch.float64,
    torch.bool,
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)
# Complex dtypes exercised by the complex-op tests.
# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
    """Return the entries of ``TESTED_DTYPES`` that are not among ``dtypes``.

    The relative order of ``TESTED_DTYPES`` is preserved in the result.
    """
    remaining = [candidate for candidate in TESTED_DTYPES if candidate not in dtypes]
    return tuple(remaining)


def _should_skip_xfail_test_sample(
    op_name: str, sample, dtype: torch.dtype, device_type: str
) -> Tuple[Optional[str], Optional[str]]:
    """Find the skip/xfail behavior registered for one test sample.

    Args:
        op_name: Name of the op under test.
        sample: The sample passed to each candidate entry's ``matcher``.
        dtype: The dtype the sample is being tested with.
        device_type: The device type the sample is being tested on.

    Returns:
        ``(test_behavior, reason)`` from the first applicable entry in
        ``ops_test_data.SKIP_XFAIL_SUBTESTS`` whose matcher accepts the sample,
        or ``(None, None)`` when no entry applies.
    """
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None

    # A linear scan over SKIP_XFAIL_SUBTESTS is fine because the list is small.
    for meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
        if meta.op_name != op_name:
            continue
        assert meta.matcher is not None, "Matcher must be defined"
        if not meta.enabled_if:
            # A disabled entry never skips or xfails anything
            continue
        if not (meta.dtypes is None or dtype in meta.dtypes):
            # Entry is restricted to other dtypes
            continue
        if meta.device_type is not None and meta.device_type != device_type:
            # Entry is restricted to another device type
            continue
        if meta.matcher(sample):
            return meta.test_behavior, meta.reason

    return None, None


class TestFunctionValidity(unittest.TestCase):
    """Static checks run once per entry of ``ops_test_data.TESTED_TORCHLIB_OPS``."""

    @parameterized.parameterized.expand(
        [(op_info.op.name, op_info) for op_info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_script_function_passes_checker(
        self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
    ):
        # Only OnnxFunction instances can be turned into a function proto;
        # anything else is skipped.
        candidate = torchlib_op_info.op
        if not isinstance(candidate, onnxscript.OnnxFunction):
            self.skipTest("Traced functions does not have a function proto")
        onnx.checker.check_function(  # type: ignore[attr-defined]
            candidate.to_function_proto()
        )

    @parameterized.parameterized.expand(
        [(op_info.op.name, op_info) for op_info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo):
        # The op must expose a schema whose name matches the op's own name.
        tested_op = torchlib_op_info.op
        op_schema = tested_op.op_schema
        self.assertIsNotNone(op_schema)
        self.assertEqual(op_schema.name, tested_op.name)


def run_test_output_match(
    test_suite: unittest.TestCase,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
    function_executor: Callable,
    tested_op_mapping: dict[
        str,
        ops_test_data.TorchLibOpInfo,
    ],
):
    """Base test method for testing each opset, used by instantiate_device_type_tests.

    For every sample input of ``op``, the op is called in PyTorch and the
    corresponding torchlib function is run through ``function_executor``. Both
    results are flattened and compared output by output, each sample inside
    its own subtest.

    Args:
        test_suite: The test class instance.
        device: The PyTorch device. instantiate_device_type_tests provides this.
        dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
        op: The OpInfo instance. instantiate_device_type_tests provides this.
        function_executor: The function executor. It is called as
            ``function_executor(test_name, reference_outputs)`` and must return
            a callable that takes ``(function, args, kwargs)`` and returns the
            output of the function.
        tested_op_mapping: The mapping of op name to the tested op.

    Raises:
        AssertionError: If the number of flattened outputs differs, or an
            output is not close to the PyTorch result. When there is more than
            one output the error message names the mismatching output index.
    """
    # Sample inputs defined by the OpInfo for this device and dtype
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )

    torchlib_op_info = tested_op_mapping[op.name]
    # Obtain the input_wrangler that manipulates the OpInfo inputs
    # to match the aten operator signature
    # An example is nn.functional.upsample_nearest2d, which has a different signature
    # than the aten operator upsample_nearest2d
    onnx_function = torchlib_op_info.op
    input_wrangler = torchlib_op_info.input_wrangler
    # Skip the whole test when the op schema does not accept this dtype.
    # Complex dtypes are exempt from the schema check.
    if (
        not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
        and dtype not in COMPLEX_TYPES
    ):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )

    # Obtain the tolerance for the op
    rtol, atol = torchlib_op_info.get_tolerance(dtype)
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            sample_num=i,
            inputs=repr(
                [
                    f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                    if isinstance(inp, torch.Tensor)
                    else inp
                    for inp in inputs
                ]
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
            # The device type is read from the first positional argument; fall back
            # to "cpu" when there are no args or the first arg has no `.device`.
            try:
                device_type = cpu_sample.args[0].device.type
            except (AttributeError, IndexError):
                device_type = "cpu"
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, cpu_sample, dtype, device_type
            )

            # Everything below runs under the skip/xfail behavior chosen for this sample
            with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
                kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
                if input_wrangler:
                    input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
                # Reference result from PyTorch, computed on the original (unconverted) inputs
                torch_output = op(*inputs, **cpu_sample.kwargs)

                # A complex tensor output is viewed as real (after resolving conjugation)
                # before comparison
                if isinstance(torch_output, torch.Tensor) and torch.is_complex(torch_output):
                    torch_output = torch.view_as_real(torch_output.resolve_conj())

                reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
                if (
                    op.name.startswith("split")
                    or op.name.startswith("chunk")
                    or op.name.startswith("unbind")
                    or op.name
                    in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
                ):
                    # Hack for handling split, chunk and unbind which relies on SplitToSequence op.
                    # Split returns a Sequence that should be treated as a single
                    # value. So we wrap it into a list.
                    # TODO(justinchuby): Find a more general solution
                    reference_torch_outputs = [reference_torch_outputs]

                test_name = test_suite.id()
                # The executor is first bound to the test name and reference outputs,
                # then invoked with the function and its converted inputs
                function_output = function_executor(test_name, reference_torch_outputs)(
                    onnx_function, input_onnx, kwargs_onnx
                )
                # Finally we re-flatten everything
                # TODO: add pytree structure comparison.
                flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
                flattened_function_outputs, _ = pytree.tree_flatten(function_output)

                # There must be at least one output, and both sides must agree on the count
                assert flattened_torch_outputs
                assert len(flattened_torch_outputs) == len(flattened_function_outputs)

                # NOTE: the loop variables rebind the outer `torch_output` and
                # `function_output` names; both were already flattened above.
                for j, (torch_output, function_output) in enumerate(
                    zip(flattened_torch_outputs, flattened_function_outputs)
                ):
                    if not isinstance(function_output, np.ndarray):
                        # An onnxscript tensor
                        function_output = function_output.value

                    actual = torch.tensor(function_output)
                    expected = (
                        torch_output
                        if isinstance(torch_output, torch.Tensor)
                        else torch.tensor(torch_output)
                    )

                    # NOTE(review): indexing COMPARE_SHAPE_ONLY_OPS by op.name presumably
                    # relies on it returning a container for every op name (e.g. a
                    # defaultdict) — confirm against ops_test_data.
                    if (
                        op.name in ops_test_data.NONDETERMINISTIC_OPS
                        or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
                    ):
                        # Check shape and dtype only for ops that are known to be
                        # nondeterministic, or for output indices registered as
                        # shape-only comparisons
                        test_suite.assertEqual(actual.shape, expected.shape)
                        test_suite.assertEqual(actual.dtype, expected.dtype)
                        continue

                    # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
                    try:
                        torch.testing.assert_close(
                            actual,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            equal_nan=True,
                            check_device=False,
                        )
                    except AssertionError as e:
                        # Optionally write a reproduction report before re-raising
                        if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
                            error_reproduction.create_mismatch_report(
                                test_name, i, inputs, cpu_sample.kwargs, actual, expected, e
                            )
                        # With several outputs, say which one mismatched
                        if len(flattened_torch_outputs) > 1:
                            raise AssertionError(f"Output {j} mismatch") from e
                        raise


class TestOutputConsistencyFullGraph(unittest.TestCase):
    """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        # Seed torch, NumPy and onnxruntime, in that order, with the same value.
        for seed_rng in (torch.manual_seed, np.random.seed, ort.set_seed):
            seed_rng(42)

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            opinfo
            for opinfo in ops_test_data.OPS_DB
            if opinfo.name in ops_test_data.TESTED_OPS
        ],
        allowed_dtypes=TESTED_DTYPES,
    )
    def test_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        # Base test method for testing each op by running the full ONNX graph.
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            function_executor=ops_test_common.graph_executor,
            tested_op_mapping=ops_test_data.TORCHLIB_OPINFO_MAPPING,
        )

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_complex_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            opinfo
            for opinfo in ops_test_data.OPS_DB
            if opinfo.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
        ],
        allowed_dtypes=COMPLEX_TYPES,
    )
    def test_complex_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Base test method for testing each op by running the full ONNX graph."""
        # Same flow as above, restricted to COMPLEX_TYPES and the complex function mapping.
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            function_executor=ops_test_common.graph_executor,
            tested_op_mapping=ops_test_data.COMPLEX_FUNCTION_MAPPING,
        )


# Instantiate the device-specific variants of TestOutputConsistencyFullGraph,
# limited to the "cpu" and "cuda" device types. The helper is handed this
# module's globals().
common_device_type.instantiate_device_type_tests(
    TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
)

# Allow running this file directly as a unittest script.
if __name__ == "__main__":
    unittest.main()