File: test_pt_onnx_trt.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)
###################################################################################################
# ATTENTION! This test will most probably fail if you install only TensorRT 6.0.1.
# That is because it ships with an older version of the ONNX parser that does not support some
# required features. To make it work, please use the newer version from https://github.com/onnx/onnx-tensorrt
# Just clone it and build it like this:
#
# ~/pt/third_party/onnx-tensorrt$ mkdir build/
# ~/pt/third_party/onnx-tensorrt$ cd build/
# ~/pt/third_party/onnx-tensorrt/build$ cmake ..
# ~/pt/third_party/onnx-tensorrt/build$ make
# ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu
#
# This note applies to the 6.0.1 release only. September 18th, 2019.
###################################################################################################
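# The tests below can be run directly (for example `python test_pt_onnx_trt.py`), since this
# module calls unittest.main() when executed as a script.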

import os
import unittest

from PIL import Image
import numpy as np
import torch
import torchvision.models as models

import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit

import tensorrt as trt
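# A single logger shared by the TensorRT builder, parser, and engines; the WARNING severity
# reports warnings and errors only.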
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def allocate_buffers(engine):
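    # Allocate page-locked (pinned) host buffers for the engine's input (binding 0) and
    # output (binding 1), matching device buffers of the same size, and a CUDA stream on
    # which the asynchronous copies and inference are queued.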
    h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)),
                                    dtype=trt.nptype(trt.float32))
    h_output = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)),
                                     dtype=trt.nptype(trt.float32))
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)
    stream = cuda.Stream()
    return h_input, d_input, h_output, d_output, stream

def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
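    # Resize the test image to the network's (C, H, W) input shape, convert it to a flat
    # CHW float32 array, normalize it according to normalization_hint, and copy the result
    # into the page-locked input buffer.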
    def normalize_image(image):
        c, h, w = input_shape
        image_arr = np.asarray(image.resize((w, h), Image.ANTIALIAS)).transpose([2, 0, 1])\
            .astype(trt.nptype(trt.float32)).ravel()
        if normalization_hint == 0:
            # Roughly ImageNet-style normalization, using a single mean (0.45) and std (0.225)
            # for all three channels.
            return (image_arr / 255.0 - 0.45) / 0.225
        elif normalization_hint == 1:
            # Simple shift of the [0, 256) pixel range to approximately [-0.5, 0.5).
            return (image_arr / 256.0 - 0.5)
    np.copyto(pagelocked_buffer, normalize_image(Image.open(test_image)))
    return test_image

class Test_PT_ONNX_TRT(unittest.TestCase):
    def __enter__(self):
        return self

    def setUp(self):
        data_path = os.path.join(os.path.dirname(__file__), 'data')
        self.image_files = ["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"]
        for index, f in enumerate(self.image_files):
            self.image_files[index] = os.path.abspath(os.path.join(data_path, f))
            if not os.path.exists(self.image_files[index]):
                raise FileNotFoundError(self.image_files[index] + " does not exist.")
        with open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r') as f:
            self.labels = f.read().split('\n')

    def build_engine_onnx(self, model_file):
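        # flags=1 requests an explicit-batch network definition, which the ONNX parser requires,
        # and the builder may use up to 8 GiB (1 << 33 bytes) of workspace while building the engine.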
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(flags=1) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder_config = builder.create_builder_config()
            builder_config.max_workspace_size = 1 << 33
            with open(model_file, 'rb') as model:
                if not parser.parse(model.read()):
                    for error in range(parser.num_errors):
                        self.fail("ERROR: {}".format(parser.get_error(error)))
            return builder.build_engine(network, builder_config)

    def _test_model(self, model_name, input_shape=(3, 224, 224), normalization_hint=0):

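        # Load the pretrained torchvision model, export it to ONNX, build a TensorRT engine
        # from the exported file, and run inference on the bundled test images. A prediction
        # counts as correct when the predicted label appears in the image's file name.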
        model = getattr(models, model_name)(pretrained=True)

        shape = (1,) + input_shape
        dummy_input = (torch.randn(shape),)
        onnx_name = model_name + ".onnx"

        torch.onnx.export(model,
                          dummy_input,
                          onnx_name,
                          input_names=[],
                          output_names=[],
                          verbose=False,
                          export_params=True,
                          opset_version=9)

        with self.build_engine_onnx(onnx_name) as engine:
            h_input, d_input, h_output, d_output, stream = allocate_buffers(engine)
            with engine.create_execution_context() as context:
                err_count = 0
                for index, f in enumerate(self.image_files):
                    test_case = load_normalized_test_case(input_shape, f,
                                                          h_input, normalization_hint)
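                    # Copy the preprocessed image to the GPU, run inference asynchronously on
                    # the stream, copy the output back to the host, and wait for completion.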
                    cuda.memcpy_htod_async(d_input, h_input, stream)

                    context.execute_async_v2(bindings=[d_input, d_output],
                                             stream_handle=stream.handle)
                    cuda.memcpy_dtoh_async(h_output, d_output, stream)
                    stream.synchronize()

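                    # The top-scoring class label, with spaces replaced by underscores, must
                    # appear in the test image's file name; otherwise count it as an error.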
                    amax = np.argmax(h_output)
                    pred = self.labels[amax]
                    if "_".join(pred.split()) not in\
                            os.path.splitext(os.path.basename(test_case))[0]:
                        err_count = err_count + 1
                self.assertLessEqual(err_count, 1, "Too many recognition errors")

    def test_alexnet(self):
        self._test_model("alexnet", (3, 227, 227))

    def test_resnet18(self):
        self._test_model("resnet18")
    def test_resnet34(self):
        self._test_model("resnet34")
    def test_resnet50(self):
        self._test_model("resnet50")
    def test_resnet101(self):
        self._test_model("resnet101")
    @unittest.skip("Takes 2m")
    def test_resnet152(self):
        self._test_model("resnet152")

    def test_resnet50_2(self):
        self._test_model("wide_resnet50_2")
    @unittest.skip("Takes 2m")
    def test_resnet101_2(self):
        self._test_model("wide_resnet101_2")

    def test_squeezenet1_0(self):
        self._test_model("squeezenet1_0")
    def test_squeezenet1_1(self):
        self._test_model("squeezenet1_1")

    def test_googlenet(self):
        self._test_model("googlenet")
    def test_inception_v3(self):
        self._test_model("inception_v3")

    def test_mnasnet0_5(self):
        self._test_model("mnasnet0_5", normalization_hint=1)
    def test_mnasnet1_0(self):
        self._test_model("mnasnet1_0", normalization_hint=1)

    def test_mobilenet_v2(self):
        self._test_model("mobilenet_v2", normalization_hint=1)

    def test_shufflenet_v2_x0_5(self):
        self._test_model("shufflenet_v2_x0_5")
    def test_shufflenet_v2_x1_0(self):
        self._test_model("shufflenet_v2_x1_0")

    def test_vgg11(self):
        self._test_model("vgg11")
    def test_vgg11_bn(self):
        self._test_model("vgg11_bn")
    def test_vgg13(self):
        self._test_model("vgg13")
    def test_vgg13_bn(self):
        self._test_model("vgg13_bn")
    def test_vgg16(self):
        self._test_model("vgg16")
    def test_vgg16_bn(self):
        self._test_model("vgg16_bn")
    def test_vgg19(self):
        self._test_model("vgg19")
    def test_vgg19_bn(self):
        self._test_model("vgg19_bn")

    @unittest.skip("Takes 13m")
    def test_densenet121(self):
        self._test_model("densenet121")
    @unittest.skip("Takes 25m")
    def test_densenet161(self):
        self._test_model("densenet161")
    @unittest.skip("Takes 27m")
    def test_densenet169(self):
        self._test_model("densenet169")
    @unittest.skip("Takes 44m")
    def test_densenet201(self):
        self._test_model("densenet201")

if __name__ == '__main__':
    unittest.main()