File: onnx_types_test.py

package info (click to toggle)
onnxscript 0.2.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (77 lines) | stat: -rw-r--r-- 2,736 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# mypy: disable-error-code=misc

"""Unit tests for the onnx_types module."""

from __future__ import annotations

import unittest

from parameterized import parameterized

from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry


class TestOnnxTypes(unittest.TestCase):
    """Tests for the tensor type classes exported by ``onnxscript.onnx_types``."""

    def test_instantiation(self):
        """Tensor type classes must not be instantiable, shaped or unshaped.

        Instantiating the ``TensorType`` base, a dtype-bound subclass
        (``FLOAT``) or a shaped subclass (``FLOAT[...]``) raises
        ``NotImplementedError``.
        """
        with self.assertRaises(NotImplementedError):
            TensorType()
        with self.assertRaises(NotImplementedError):
            FLOAT()
        with self.assertRaises(NotImplementedError):
            FLOAT[...]()

    @parameterized.expand(tensor_type_registry.items())
    def test_type_properties(self, dtype: int, tensor_type: type[TensorType]):
        """Each registered type reports its dtype; subscripting only sets the shape.

        ``tensor_type`` is the class registered under ``dtype`` in
        ``tensor_type_registry``. The unshaped class has ``shape`` ``None``,
        ``[...]`` yields shape ``...``, and an explicit index tuple yields
        that tuple. ``dtype`` is preserved in every case.
        """
        self.assertEqual(tensor_type.dtype, dtype)
        self.assertIsNone(tensor_type.shape)
        self.assertEqual(tensor_type[...].shape, ...)  # type: ignore[index]
        self.assertEqual(tensor_type[...].dtype, dtype)  # type: ignore[index]
        self.assertEqual(tensor_type[1, 2, 3].shape, (1, 2, 3))  # type: ignore[index]
        self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)  # type: ignore[index]

    @parameterized.expand([(dtype,) for dtype in tensor_type_registry])
    def test_dtype_bound_to_subclass(self, dtype: int):
        """Defining a new TensorType subclass for an already-registered dtype fails.

        Every dtype taken from ``tensor_type_registry`` is already bound to a
        class, so passing it again via the ``dtype=`` class keyword must raise
        ``ValueError``.
        """
        with self.assertRaises(ValueError):
            type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)

    def test_shaped_doesnt_reshape(self):
        """Subscripting an already-shaped type a second time raises ValueError."""
        with self.assertRaises(ValueError):
            FLOAT[1][...]  # pylint: disable=pointless-statement

    @parameterized.expand(
        [
            (FLOAT, FLOAT),
            (FLOAT[None], FLOAT[None]),
            (FLOAT[1, 2, 3], FLOAT[1, 2, 3]),
            (FLOAT[1], FLOAT[1]),
            # `...` and `Ellipsis` are the same object, so these must match.
            (FLOAT[...], FLOAT[Ellipsis]),
            (FLOAT["M"], FLOAT["M"]),
            (FLOAT["M", "N"], FLOAT["M", "N"]),
            (FLOAT["M", 3, 4], FLOAT["M", 3, 4]),
        ]
    )
    def test_shapes_are_same_type(self, a: type[TensorType], b: type[TensorType]):
        """Equal dtype and shape must yield the identical class object.

        ``a`` and ``b`` are tensor type classes (not instances), built
        independently with the same subscript; identity (``is``) is
        asserted rather than mere equality.
        """
        self.assertIs(a, b)

    @parameterized.expand(
        [
            # A concrete 0 dimension must not collapse into None.
            (FLOAT[0], FLOAT[None]),
            (FLOAT[1, 2], FLOAT[3, 4]),
            # Dimension order is significant.
            (FLOAT[2, 1], FLOAT[1, 2]),
            (FLOAT["M", "N"], FLOAT["N", "M"]),
            # Same shape but different dtype must also differ.
            (FLOAT, DOUBLE),
            (FLOAT[1], DOUBLE[1]),
            (FLOAT["X"], DOUBLE["X"]),
        ]
    )
    def test_shapes_are_not_same_type(self, a: type[TensorType], b: type[TensorType]):
        """A differing shape or dtype must yield distinct class objects.

        ``a`` and ``b`` are tensor type classes (not instances).
        """
        self.assertIsNot(a, b)


if __name__ == "__main__":
    # Executed directly (not imported): discover and run the TestCase
    # classes defined in this module. `module="__main__"` is unittest's
    # default, spelled out here for clarity.
    unittest.main(module="__main__")