1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
|
import unittest
from model import Model
from onnx import helper
from onnx import TensorProto
import numpy as np
class ModelTest(unittest.TestCase):
def setUp(self):
pass
def test_empty_model(self):
model = Model()
with self.assertRaises(Exception):
model.GenerateSchedule()
with self.assertRaises(Exception):
model.PrintLoopNest()
with self.assertRaises(Exception):
model.PrintLoweredStatement()
def test_small_model(self):
# Create one input
X = helper.make_tensor_value_info('IN', TensorProto.FLOAT, [2, 3])
# Create one output
Y = helper.make_tensor_value_info('OUT', TensorProto.FLOAT, [2, 3])
# Create a node
node_def = helper.make_node('Abs', ['IN'], ['OUT'])
# Create the model
graph_def = helper.make_graph([node_def], "test-model", [X], [Y])
onnx_model = helper.make_model(graph_def,
producer_name='onnx-example')
model = Model()
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'.*Func OUT = pipeline.get_func\(1\);.+'
self.assertRegex(schedule, expected_schedule)
input = (np.random.rand(2, 3) - 0.5).astype('float32')
outputs = model.run([input])
self.assertEqual(1, len(outputs))
output = outputs[0]
expected = np.abs(input)
np.testing.assert_allclose(expected, output)
def test_scalars(self):
# Create 2 inputs
X = helper.make_tensor_value_info('A', TensorProto.INT32, [])
Y = helper.make_tensor_value_info('B', TensorProto.INT32, [])
# Create one output
Z = helper.make_tensor_value_info('C', TensorProto.INT32, [])
# Create a node
node_def = helper.make_node('Add', ['A', 'B'], ['C'])
# Create the model
graph_def = helper.make_graph([node_def], "scalar-model", [X, Y], [Z])
onnx_model = helper.make_model(graph_def,
producer_name='onnx-example')
model = Model()
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'.*Func C = pipeline.get_func\(2\);.+'
self.assertRegex(schedule, expected_schedule)
input1 = np.random.randint(-10, 10, size=()).astype('int32')
input2 = np.random.randint(-10, 10, size=()).astype('int32')
outputs = model.run([input1, input2])
self.assertEqual(1, len(outputs))
output = outputs[0]
expected = input1 + input2
np.testing.assert_allclose(expected, output)
def test_model_with_initializer(self):
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 1])
Z2 = helper.make_tensor_value_info('Z2', TensorProto.FLOAT, [2, 3, 6])
expand_node_def = helper.make_node('Expand', ['X', 'Y'], ['Z1'])
cast_node_def = helper.make_node('Scale', ['Z1'], ['Z2'])
graph_def = helper.make_graph([expand_node_def, cast_node_def],
"test-node",
[X],
[Z2],
initializer=[
helper.make_tensor('Y', TensorProto.INT64, (3,), (2, 1, 6))])
onnx_model = helper.make_model(graph_def,
producer_name='onnx-example')
model = Model()
model.BuildFromOnnxModel(onnx_model)
input_data = np.random.rand(3, 1).astype(np.float32)
outputs = model.run([input_data])
expected = input_data * np.ones([2, 1, 6], dtype=np.float32)
np.testing.assert_allclose(expected, outputs[0])
def test_tensors_rank_zero(self):
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])
S1 = helper.make_tensor_value_info('S1', TensorProto.INT64, [])
S2 = helper.make_tensor_value_info('S2', TensorProto.FLOAT, [])
size_node = helper.make_node('Size', ['X'], ['S1'])
graph_def = helper.make_graph([size_node],
"rank_zero_test",
[X],
[S1, S2],
initializer=[
helper.make_tensor('S2', TensorProto.FLOAT, (), (3.14,))])
onnx_model = helper.make_model(graph_def,
producer_name='onnx-example')
model = Model()
model.BuildFromOnnxModel(onnx_model)
input_data = np.random.rand(3, 2).astype(np.float32)
outputs = model.run([input_data])
self.assertEqual(6, outputs[0])
self.assertAlmostEqual(3.14, outputs[1])
|