1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# mypy: disable-error-code=misc
"""Unit tests for the onnx_types module."""
from __future__ import annotations
import unittest
from parameterized import parameterized
from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry
class TestOnnxTypes(unittest.TestCase):
    """Tests for the ``TensorType`` hierarchy exported by ``onnxscript.onnx_types``."""

    def test_instantiation(self):
        """Tensor types must not be instantiable, shaped or not."""
        with self.assertRaises(NotImplementedError):
            TensorType()
        with self.assertRaises(NotImplementedError):
            FLOAT()
        # A shaped subtype must be just as non-instantiable as its unshaped base.
        with self.assertRaises(NotImplementedError):
            FLOAT[...]()

    @parameterized.expand(tensor_type_registry.items())
    def test_type_properties(self, dtype: int, tensor_type: type[TensorType]):
        """Each registered type carries its dtype; subscripting keeps it and sets the shape."""
        self.assertEqual(tensor_type.dtype, dtype)
        # The unsubscripted type carries no shape information.
        self.assertIsNone(tensor_type.shape)
        self.assertEqual(tensor_type[...].shape, ...)  # type: ignore[index]
        self.assertEqual(tensor_type[...].dtype, dtype)  # type: ignore[index]
        self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))  # type: ignore[index]
        self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)  # type: ignore[index]

    @parameterized.expand([(dtype,) for dtype in tensor_type_registry])
    def test_dtype_bound_to_subclass(self, dtype: int):
        """Subclassing TensorType with an already-registered dtype raises ValueError."""
        with self.assertRaises(ValueError):
            # Use the three-argument ``type`` form so ``dtype`` is forwarded as a class keyword.
            type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)

    def test_shaped_doesnt_reshape(self):
        """A type that already has a shape cannot be subscripted with another one."""
        with self.assertRaises(ValueError):
            FLOAT[1][...]  # pylint: disable=pointless-statement

    @parameterized.expand(
        [
            (FLOAT, FLOAT),
            (FLOAT[None], FLOAT[None]),
            (FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
            (FLOAT[1], FLOAT[1]),
            (FLOAT[...], FLOAT[Ellipsis]),
            (FLOAT["M"], FLOAT["M"]),
            (FLOAT["M", "N"], FLOAT["M", "N"]),
            (FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
        ]
    )
    def test_shapes_are_same_type(self, a: type[TensorType], b: type[TensorType]):
        """Subscripting with equal shapes yields the identical type object."""
        # Identity (not mere equality) is required, so ``is`` comparisons on types work.
        self.assertIs(a, b)

    @parameterized.expand(
        [
            (FLOAT[0], FLOAT[None]),
            (FLOAT[1, 2], FLOAT[3, 4]),
            (FLOAT[2, 1], FLOAT[1, 2]),
            (FLOAT["M", "N"], FLOAT["N", "M"]),
            (FLOAT, DOUBLE),
            (FLOAT[1], DOUBLE[1]),
            (FLOAT["X"], DOUBLE["X"]),
        ]
    )
    def test_shapes_are_not_same_type(self, a: type[TensorType], b: type[TensorType]):
        """Types differing in dtype, dimension values, or dimension order are distinct."""
        self.assertIsNot(a, b)
if __name__ == "__main__":
    # Run this file directly; ``module="__main__"`` is unittest's default, spelled out.
    unittest.main(module="__main__")
|