File: testutils.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import functools
import os
import pathlib
import unittest

import numpy as np
import onnx
import onnxruntime
import torch

from onnxscript import optimizer
from onnxscript._legacy_ir import visitor
from onnxscript.rewriter import onnxruntime as ort_rewriter
from onnxscript.utils import evaluation_utils


class TestBase(unittest.TestCase):
    """The base class for testing ONNX Script functions for internal use."""

    def validate(self, fn):
        """Validate script function translation."""
        return fn.to_function_proto()

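# Usage sketch (illustrative only; `square` below is a hypothetical @script function):
# a subclass of TestBase calls `self.validate` to check that an ONNX Script function
# translates to a FunctionProto without raising.
#
#   class SquareTest(TestBase):
#       def test_square_translates(self):
#           self.validate(square)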

def skip_if_no_cuda(reason: str):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not torch.cuda.is_available() or onnxruntime.get_device() != "GPU":
                raise unittest.SkipTest(f"GPU is not available. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec

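# Usage sketch (illustrative only; the test class and method names are hypothetical):
#
#   class MatMulCudaTest(TestBase):
#       @skip_if_no_cuda("Kernel under test only runs on CUDA.")
#       def test_matmul_cuda(self):
#           ...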

class OpTypeAnalysisVisitor(visitor.ProtoVisitorCore):
    """Visitor that records the (domain, op_type, overload) of every node in a model."""

    def __init__(self):
        super().__init__()
        self.op_types = set()

    def visit_model(self, model: onnx.ModelProto):
        self.op_types = set()
        super().visit_model(model)

    def process_node(self, node: onnx.NodeProto):
        self.op_types.add((node.domain, node.op_type, getattr(node, "overload", "")))
        return super().process_node(node)

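# Usage sketch (illustrative only; `model` is assumed to be an onnx.ModelProto):
#
#   analysis = OpTypeAnalysisVisitor()
#   analysis.visit_model(model)
#   # analysis.op_types now holds the (domain, op_type, overload) of every node.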

def test_onnxruntime_rewrite(
    model_basename: str,
    model_count: int,
    expected_optypes: set[tuple[str, str, str]],
    rtol: float = 1e-2,
    atol: float = 1e-2,
):
    """Optimize and rewrite testdata models for onnxruntime and check numerical parity.

    For each of ``model_count`` models named ``{model_basename}_{index}`` under
    ``testdata/unittest_models``, the model is optimized, rewritten with the onnxruntime
    rewriter, checked to contain every op type in ``expected_optypes``, run on CUDA, and
    compared against the stored expected outputs within ``rtol``/``atol``.
    """
    dir_path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
    unittest_root_dir = dir_path.parent.parent / "testdata" / "unittest_models"
    for model_index in range(model_count):
        model_name = f"{model_basename}_{model_index}"
        model_dir = unittest_root_dir / f"{model_name}"
        model_path = model_dir / f"{model_name}.onnx"
        model = onnx.load(model_path)

        # TODO: Parity issue with randomly generated data. Need investigation.
        # inputs = generate_random_input(model)
        inputs, expected_outputs = evaluation_utils.load_test_data(
            model_dir, [i.name for i in model.graph.input]
        )

        optimized = optimizer.optimize(
            model,
            onnx_shape_inference=False,
            num_iterations=2,
        )
        rewritten = ort_rewriter.rewrite(optimized)
        # NOTE: uncomment this to save the optimized model.
        # onnx.save(rewritten, model_dir / f"{model_name}_opt.onnx")

        # Check expected operator is found.
        optype_analysis = OpTypeAnalysisVisitor()
        optype_analysis.visit_model(rewritten)
        for domain, op_type, overload in expected_optypes:
            if (domain, op_type, overload) not in optype_analysis.op_types:
                raise AssertionError(
                    f"Expected op type {domain}:{op_type}:{overload} not found in rewritten model."
                )

        # Run the rewritten model under onnxruntime on CUDA and compare its outputs
        # against the expected outputs loaded from the test data above.
        providers = ["CUDAExecutionProvider"]
        optimized_session = onnxruntime.InferenceSession(
            rewritten.SerializeToString(), providers=providers
        )
        optimized_outputs = optimized_session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, optimized_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise
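

# Usage sketch (illustrative only; the model name and expected op type below are
# hypothetical and would require matching test data under testdata/unittest_models):
#
#   class AttentionFusionTest(unittest.TestCase):
#       @skip_if_no_cuda("The rewritten model requires CUDA kernels.")
#       def test_attention_fusion(self):
#           test_onnxruntime_rewrite(
#               "attention", 1, {("com.microsoft", "MultiHeadAttention", "")}
#           )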