File: test_model_zoo.py

package info (click to toggle)
onnx 1.17.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 52,856 kB
  • sloc: python: 73,992; cpp: 53,539; makefile: 50; sh: 48; javascript: 1
file content (127 lines) | stat: -rw-r--r-- 4,799 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import argparse
import gc
import os
import shutil
import sys
import time

import config

import onnx
from onnx import hub, version_converter

# Models below this opset predate shape-inference support, so they are run
# through the plain checker only (see the opset branch in main()).
MIN_SHAPE_INFERENCE_OPSET = 4


def skip_model(error_message: str, skip_list: list[str], model_name: str) -> None:
    """Record *model_name* as skipped and report *error_message* on stdout."""
    skip_list.append(model_name)
    print(error_message)


def main():
    """Validate every model in the ONNX hub (optionally filtered by directory).

    For each model: download it, run ``onnx.checker`` (with full shape
    inference when the opset is recent enough), and — when applicable — try
    upgrading it by one opset version with the version converter. Each
    model's on-disk copy is removed afterwards to save CI disk space.
    Exits with status 1 if any model failed; skips do not affect the status.
    """
    parser = argparse.ArgumentParser(description="Test settings")
    # default: test all models in the repo
    # if test_dir is specified, only test files under that specified path
    parser.add_argument(
        "--test_dir",
        required=False,
        default="",
        type=str,
        help="Directory path for testing. e.g., text, vision",
    )
    # Bug fix: parse_args() was never called, so --test_dir was silently
    # ignored (and unknown flags were never rejected). Parse it and apply
    # the documented directory filter.
    args = parser.parse_args()

    model_list = hub.list_models()
    if args.test_dir:
        # Keep only models stored under the requested directory prefix.
        model_list = [
            m for m in model_list if m.model_path.startswith(args.test_dir)
        ]
    print(f"=== Running ONNX Checker on {len(model_list)} models ===")

    # run checker on each model; collect (name, exception) pairs for failures
    failed_messages: list[tuple[str, Exception]] = []
    skip_models: list[str] = []
    for m in model_list:
        start = time.time()
        model_name = m.model
        model_path = m.model_path
        print(f"-----------------Testing: {model_name}-----------------")
        try:
            model = hub.load(model_name)
            # 1) Test onnx checker and shape inference
            if model.opset_import[0].version < MIN_SHAPE_INFERENCE_OPSET:
                # Ancient opset version does not have defined shape inference function
                onnx.checker.check_model(model)
                print(f"[PASS]: {model_name} is checked by onnx checker. ")
            else:
                # stricter onnx.checker with onnx.shape_inference (full_check=True)
                onnx.checker.check_model(model, True)
                print(
                    f"[PASS]: {model_name} is checked by onnx checker with shape_inference. "
                )

                # 2) Test onnx version converter with upgrade functionality
                original_version = model.opset_import[0].version
                # VERSION_TABLE rows are (release, ir_version, opset, ...);
                # index 2 of the last row is the newest default-domain opset.
                latest_opset_version = onnx.helper.VERSION_TABLE[-1][2]
                if original_version < latest_opset_version:
                    if model_path in config.SKIP_VERSION_CONVERTER_MODELS:
                        skip_model(
                            f"[SKIP]: model {model_name} is in the skip list for version converter. ",
                            skip_models,
                            model_name,
                        )
                    elif model_path.endswith("-int8.onnx"):
                        skip_model(
                            f"[SKIP]: model {model_name} is a quantized model using non-official ONNX domain. ",
                            skip_models,
                            model_name,
                        )
                    else:
                        # Only step one opset forward — enough to exercise the
                        # converter without chasing the newest operators.
                        converted = version_converter.convert_version(
                            model, original_version + 1
                        )
                        onnx.checker.check_model(converted, True)
                        print(
                            f"[PASS]: {model_name} can be version converted by original_version+1. "
                        )
                elif original_version == latest_opset_version:
                    skip_model(
                        f"[SKIP]: {model_name} is already the latest opset version. ",
                        skip_models,
                        model_name,
                    )
                else:
                    raise RuntimeError(  # noqa: TRY301
                        f"{model_name} has unsupported opset_version {original_version}. "
                    )

            # remove the model to save space in CIs
            # NOTE(review): relies on the private hub._ONNX_HUB_DIR attribute;
            # there is no public accessor for the cache location.
            full_model_path = os.path.join(hub._ONNX_HUB_DIR, model_path)
            parent_dir = os.path.dirname(full_model_path)
            if os.path.exists(parent_dir):
                shutil.rmtree(parent_dir)

        except Exception as e:  # noqa: BLE001
            # Broad catch is intentional: one bad model must not stop the sweep.
            print(f"[FAIL]: {e}")
            failed_messages.append((model_name, e))
        end = time.time()
        print(f"--------------Time used: {end - start} secs-------------")
        # enable gc collection to prevent MemoryError by loading too many large models
        gc.collect()

    if not failed_messages:
        print(
            f"{len(model_list)} models have been checked. {len(skip_models)} models were skipped."
        )
    else:
        print(
            f"In all {len(model_list)} models, {len(failed_messages)} models failed, {len(skip_models)} models were skipped"
        )
        for model_name, error in failed_messages:
            print(f"{model_name} failed because: {error}")
        sys.exit(1)


# Script entry point: run the full model-zoo sweep when invoked directly.
if __name__ == "__main__":
    main()