###################################################################################################
# ATTENTION! This test will most likely fail if you install only TensorRT 6.0.1,
# because that release ships with an older version of the ONNX parser that does not
# support some required features. To make it work, please build the newer parser from
# https://github.com/onnx/onnx-tensorrt -- just clone it and do something like this:
#
# ~/pt/third_party/onnx-tensorrt$ mkdir build/
# ~/pt/third_party/onnx-tensorrt$ cd build/
# ~/pt/third_party/onnx-tensorrt/build$ cmake ..
# ~/pt/third_party/onnx-tensorrt/build$ make
# ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu
#
# This note is valid for 6.0.1 release only. September 18th, 2019.
###################################################################################################
import os
import tempfile
import unittest

from PIL import Image
import numpy as np
import torch
import torchvision.models as models
import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit
import tensorrt as trt
# Shared TensorRT logger (minimum severity WARNING); passed to both the builder
# and the ONNX parser in Test_PT_ONNX_TRT.build_engine_onnx.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def allocate_buffers(engine):
    """Allocate host and device buffers for a two-binding (input, output) engine.

    Binding 0 is treated as the input and binding 1 as the output; both are
    sized from the engine's binding shapes and allocated as float32.

    Returns:
        Tuple ``(h_input, d_input, h_output, d_output, stream)``: page-locked
        host arrays, matching device allocations, and a fresh CUDA stream.
    """
    float_dtype = trt.nptype(trt.float32)
    # Page-locked (pinned) host memory is required for async host<->device copies.
    host_in, host_out = (
        cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(binding)), dtype=float_dtype)
        for binding in (0, 1)
    )
    dev_in = cuda.mem_alloc(host_in.nbytes)
    dev_out = cuda.mem_alloc(host_out.nbytes)
    return host_in, dev_in, host_out, dev_out, cuda.Stream()
def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
    """Load ``test_image``, preprocess it and copy it into ``pagelocked_buffer``.

    The image is converted to RGB, resized to ``w x h`` with a Lanczos filter,
    transposed from HWC to CHW, cast to float32, flattened and normalized.

    Args:
        input_shape: ``(c, h, w)`` network input shape; only ``h`` and ``w`` are used.
        test_image: path of the image file to load.
        pagelocked_buffer: flat host buffer receiving the data; it must hold
            ``3 * h * w`` elements (the network input volume).
        normalization_hint: ``0`` -> ``(x / 255 - 0.45) / 0.225`` (approximately the
            usual ImageNet mean/std scaling); ``1`` -> ``x / 256 - 0.5``.

    Returns:
        ``test_image`` unchanged, so callers can compare predictions with the file name.

    Raises:
        ValueError: if ``normalization_hint`` is neither 0 nor 1.
    """
    if normalization_hint not in (0, 1):
        # Previously this fell through to np.copyto(buffer, None) and failed obscurely.
        raise ValueError("Unsupported normalization_hint: {}".format(normalization_hint))
    _, h, w = input_shape
    # Context manager closes the underlying file handle; resize() returns a new,
    # fully loaded image, so it remains valid after the file is closed.
    with Image.open(test_image) as image:
        # convert('RGB') guarantees three channels for the HWC -> CHW transpose below;
        # LANCZOS is the filter that the removed (Pillow 10) ANTIALIAS constant aliased.
        resized = image.convert('RGB').resize((w, h), Image.LANCZOS)
    image_arr = np.asarray(resized).transpose([2, 0, 1]) \
        .astype(trt.nptype(trt.float32)).ravel()
    if normalization_hint == 0:
        normalized = (image_arr / 255.0 - 0.45) / 0.225
    else:
        normalized = image_arr / 256.0 - 0.5
    np.copyto(pagelocked_buffer, normalized)
    return test_image
class Test_PT_ONNX_TRT(unittest.TestCase):
    """Export torchvision classifiers to ONNX, run them with TensorRT, check predictions.

    Each ``test_*`` method exports one pretrained torchvision model with
    ``torch.onnx.export`` (opset 9), builds a TensorRT engine from the ONNX file,
    classifies the images under ``data/`` and tolerates at most one misprediction.
    """

    def __enter__(self):
        # Kept for backward compatibility only: nothing in this file relies on it, and
        # without a matching __exit__ the class cannot actually be used in a `with`.
        return self

    def setUp(self):
        """Resolve the test image paths under ``data/`` and load the class labels.

        Raises:
            FileNotFoundError: if one of the expected test images is missing.
        """
        data_path = os.path.join(os.path.dirname(__file__), 'data')
        self.image_files = ["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"]
        for index, f in enumerate(self.image_files):
            self.image_files[index] = os.path.abspath(os.path.join(data_path, f))
            if not os.path.exists(self.image_files[index]):
                raise FileNotFoundError(self.image_files[index] + " does not exist.")
        with open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r') as f:
            # Line i is looked up by the argmax index of the network output.
            self.labels = f.read().split('\n')

    def build_engine_onnx(self, model_file):
        """Parse ``model_file`` with the TensorRT ONNX parser and build an engine.

        Fails the test, instead of returning ``None``, if parsing or building fails.
        """
        # Same value as the previous magic `flags = 1`: only the EXPLICIT_BATCH bit set.
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder_config = builder.create_builder_config()
            builder_config.max_workspace_size = 1 << 33  # 8 GiB
            with open(model_file, 'rb') as model:
                if not parser.parse(model.read()):
                    # Report every parser error at once, not just the first one.
                    errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
                    self.fail("ERROR: {}".format("\n".join(errors)))
            engine = builder.build_engine(network, builder_config)
            if engine is None:
                self.fail("ERROR: TensorRT failed to build an engine from {}".format(model_file))
            return engine

    @staticmethod
    def _export_onnx(model, shape, onnx_name):
        """Export ``model`` traced on a random input of ``shape`` to ``onnx_name`` (opset 9)."""
        dummy_input = (torch.randn(shape),)
        torch.onnx.export(model,
                          dummy_input,
                          onnx_name,
                          input_names=[],
                          output_names=[],
                          verbose=False,
                          export_params=True,
                          opset_version=9)

    def _count_recognition_errors(self, engine, input_shape, normalization_hint):
        """Classify every test image with ``engine`` and return the number of mispredictions.

        A prediction is correct when its label (whitespace replaced by ``_``) is a
        non-empty substring of the image file name without its extension.
        """
        h_input, d_input, h_output, d_output, stream = allocate_buffers(engine)
        err_count = 0
        with engine.create_execution_context() as context:
            for f in self.image_files:
                test_case = load_normalized_test_case(input_shape, f, h_input, normalization_hint)
                cuda.memcpy_htod_async(d_input, h_input, stream)
                context.execute_async_v2(bindings=[d_input, d_output],
                                         stream_handle=stream.handle)
                cuda.memcpy_dtoh_async(h_output, d_output, stream)
                stream.synchronize()
                pred_key = "_".join(self.labels[np.argmax(h_output)].split())
                image_stem = os.path.splitext(os.path.basename(test_case))[0]
                # An empty label is a substring of every name; never count it as a hit.
                if not pred_key or pred_key not in image_stem:
                    err_count += 1
        return err_count

    def _test_model(self, model_name, input_shape=(3, 224, 224), normalization_hint=0):
        """Export ``torchvision.models.<model_name>``, run it with TensorRT, check accuracy.

        Args:
            model_name: name of a torchvision model constructor, e.g. ``"resnet18"``.
            input_shape: ``(c, h, w)`` of the exported network input (batch size 1).
            normalization_hint: preprocessing variant for ``load_normalized_test_case``.
        """
        model = getattr(models, model_name)(pretrained=True)
        model.eval()  # make the inference-mode export explicit
        # Export into a temporary directory so runs do not litter (or race on) the CWD.
        with tempfile.TemporaryDirectory() as tmp_dir:
            onnx_name = os.path.join(tmp_dir, model_name + ".onnx")
            self._export_onnx(model, (1,) + input_shape, onnx_name)
            with self.build_engine_onnx(onnx_name) as engine:
                err_count = self._count_recognition_errors(engine, input_shape,
                                                           normalization_hint)
        self.assertLessEqual(err_count, 1, "Too many recognition errors")

    def test_alexnet(self):
        self._test_model("alexnet", (3, 227, 227))

    def test_resnet18(self):
        self._test_model("resnet18")

    def test_resnet34(self):
        self._test_model("resnet34")

    def test_resnet50(self):
        self._test_model("resnet50")

    def test_resnet101(self):
        self._test_model("resnet101")

    @unittest.skip("Takes 2m")
    def test_resnet152(self):
        self._test_model("resnet152")

    def test_resnet50_2(self):
        self._test_model("wide_resnet50_2")

    @unittest.skip("Takes 2m")
    def test_resnet101_2(self):
        self._test_model("wide_resnet101_2")

    def test_squeezenet1_0(self):
        self._test_model("squeezenet1_0")

    def test_squeezenet1_1(self):
        self._test_model("squeezenet1_1")

    def test_googlenet(self):
        self._test_model("googlenet")

    def test_inception_v3(self):
        self._test_model("inception_v3")

    def test_mnasnet0_5(self):
        self._test_model("mnasnet0_5", normalization_hint=1)

    def test_mnasnet1_0(self):
        self._test_model("mnasnet1_0", normalization_hint=1)

    def test_mobilenet_v2(self):
        self._test_model("mobilenet_v2", normalization_hint=1)

    def test_shufflenet_v2_x0_5(self):
        self._test_model("shufflenet_v2_x0_5")

    def test_shufflenet_v2_x1_0(self):
        self._test_model("shufflenet_v2_x1_0")

    def test_vgg11(self):
        self._test_model("vgg11")

    def test_vgg11_bn(self):
        self._test_model("vgg11_bn")

    def test_vgg13(self):
        self._test_model("vgg13")

    def test_vgg13_bn(self):
        self._test_model("vgg13_bn")

    def test_vgg16(self):
        self._test_model("vgg16")

    def test_vgg16_bn(self):
        self._test_model("vgg16_bn")

    def test_vgg19(self):
        self._test_model("vgg19")

    def test_vgg19_bn(self):
        self._test_model("vgg19_bn")

    @unittest.skip("Takes 13m")
    def test_densenet121(self):
        self._test_model("densenet121")

    @unittest.skip("Takes 25m")
    def test_densenet161(self):
        self._test_model("densenet161")

    @unittest.skip("Takes 27m")
    def test_densenet169(self):
        self._test_model("densenet169")

    @unittest.skip("Takes 44m")
    def test_densenet201(self):
        self._test_model("densenet201")
# Allow running this module directly as a script: executes all test_* methods.
if __name__ == '__main__':
    unittest.main()