<!--
Copyright (c) ONNX Project Contributors

SPDX-License-Identifier: Apache-2.0
-->

# ONNX Types

## Optional Type

An optional type represents a reference to either an element (a Tensor, Sequence, Map, or Sparse Tensor) or a null value. Optional types can appear as model inputs, model outputs, and intermediate values.

### Use-cases

The optional type lets ONNX express more dynamic typing scenarios. It works like the `Optional[X]` type hint in Python's `typing` module, which is equivalent to `Union[X, None]`: an ONNX optional references either a single element or null.
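To make the Python analogy concrete, here is a minimal sketch that uses only the standard `typing` module and no ONNX APIs. It checks that `Optional[X]` and `Union[X, None]` denote the same type, and it shows the null check that any consumer of such a value must perform. The `describe` function is a hypothetical helper for illustration.

```python
from typing import Optional, Union, get_args

# Optional[X] is shorthand for Union[X, None]; typing treats both
# spellings as the same type.
assert Optional[float] == Union[float, None]

# The members of the union are the element type plus NoneType.
print(get_args(Optional[float]))  # (<class 'float'>, <class 'NoneType'>)


def describe(value: Optional[float]) -> str:
    """Hypothetical helper: branch on whether the optional holds an element."""
    if value is None:
        return "empty"
    return f"element {value}"


print(describe(None))  # empty
print(describe(2.5))   # element 2.5
```

The ONNX operators `OptionalHasElement` and `OptionalGetElement` play the roles of the `is None` test and the unwrapped value in this snippet.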

### Examples in PyTorch

In PyTorch, optional types appear only in TorchScript graphs produced by the JIT script compiler (`torch.jit.script`). Scripting a model captures dynamic types, so a value declared as optional can be assigned either `None` or an actual value.

- Example 1

        from typing import Optional

        import torch
        from torch import Tensor

        class Model(torch.nn.Module):
            def forward(self, x, y: Optional[Tensor] = None):
                if y is not None:
                    return x + y
                return x

    Corresponding TorchScript graph:

        Graph(
            %self : __torch__.Model,
            %x.1 : Tensor,
            %y.1 : Tensor?
        ):
            %11 : int = prim::Constant[value=1]()
            %4 : None = prim::Constant()
            %5 : bool = aten::__isnot__(%y.1, %4)
            %6 : Tensor = prim::If(%5)
                block0():
                    %y.4 : Tensor = prim::unchecked_cast(%y.1)
                    %12 : Tensor = aten::add(%x.1, %y.4, %11)
                -> (%12)
                block1():
                -> (%x.1)
            return (%6)

    ONNX graph:

        Graph(
            %x.1 : Float(2, 3),
            %y.1 : Float(2, 3)
        ):
            %2 : Bool(1) = onnx::OptionalHasElement(%y.1)
            %5 : Float(2, 3) = onnx::If(%2)
                block0():
                    %3 : Float(2, 3) = onnx::OptionalGetElement(%y.1)
                    %4 : Float(2, 3) = onnx::Add(%x.1, %3)
                -> (%4)
                block1():
                    %x.2 : Float(2, 3) = onnx::Identity(%x.1)
                -> (%x.2)
            return (%5)

- Example 2

        from typing import Optional

        import torch
        from torch import Tensor

        class Model(torch.nn.Module):
            def forward(
                    self,
                    src_tokens,
                    return_all_hiddens=torch.tensor([False]),
            ):
                encoder_states: Optional[Tensor] = None
                if return_all_hiddens:
                    encoder_states = src_tokens

                return src_tokens, encoder_states

    Corresponding TorchScript graph:

        Graph(
            %src_tokens.1 : Float(3, 2, 4),
            %return_all_hiddens.1 : Bool(1)
        ):
            %3 : None = prim::Constant()
            %encoder_states : Tensor? = prim::If(%return_all_hiddens.1)
                block0():
                -> (%src_tokens.1)
                block1():
                -> (%3)
            return (%src_tokens.1, %encoder_states)

    ONNX graph:

        Graph(
            %src_tokens.1 : Float(3, 2, 4),
            %return_all_hiddens.1 : Bool(1)
        ):
            %2 : Float(3, 2, 4) = onnx::Optional[type=tensor(float)]()
            %3 : Float(3, 2, 4) = onnx::If(%return_all_hiddens.1)
                block0():
                -> (%src_tokens.1)
                block1():
                -> (%2)
            return (%3)
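The control flow in the two ONNX graphs above can be mimicked in plain Python. This is a hypothetical, stdlib-only model and not the onnx or onnxruntime API. The names `OptionalValue`, `optional_has_element`, `optional_get_element`, `example1`, and `example2` are illustrative, and a flat list of floats stands in for a tensor. The two helpers follow the documented behavior of `OptionalHasElement` and `OptionalGetElement`. `example1` mirrors the `If` node in Example 1, and `example2` mirrors the `If` node that produces `%3` in Example 2.

```python
from dataclasses import dataclass
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")
Tensor = List[float]  # hypothetical: a flat list of floats stands in for a tensor


@dataclass
class OptionalValue(Generic[T]):
    """Hypothetical stand-in for an ONNX optional: holds an element or is empty."""
    element: Optional[T] = None


def optional_has_element(opt: OptionalValue[T]) -> bool:
    # Mirrors onnx::OptionalHasElement: True iff the optional holds an element.
    return opt.element is not None


def optional_get_element(opt: OptionalValue[T]) -> T:
    # Mirrors onnx::OptionalGetElement. The ONNX spec calls reading an empty
    # optional an error; raising ValueError here is this sketch's choice.
    if opt.element is None:
        raise ValueError("OptionalGetElement called on an empty optional")
    return opt.element


def example1(x: Tensor, y: OptionalValue[Tensor]) -> Tensor:
    # Example 1: If(OptionalHasElement(y)) selects one of two branches.
    if optional_has_element(y):
        y_val = optional_get_element(y)
        return [a + b for a, b in zip(x, y_val)]  # block0: onnx::Add
    return list(x)  # block1: onnx::Identity


def example2(src_tokens: Tensor, return_all_hiddens: bool) -> OptionalValue[Tensor]:
    # Example 2: onnx::Optional[type=tensor(float)]() builds an empty optional (%2).
    empty: OptionalValue[Tensor] = OptionalValue()
    # The If node returns either the tokens or the empty optional.
    return OptionalValue(src_tokens) if return_all_hiddens else empty


print(example1([1.0, 2.0], OptionalValue([0.5, 0.5])))  # [1.5, 2.5]
print(example1([1.0, 2.0], OptionalValue()))            # [1.0, 2.0]
print(optional_has_element(example2([1.0], False)))     # False
```

Keeping the "has element" test separate from the unwrapping step matches the ONNX graphs, where `OptionalGetElement` runs only inside the branch guarded by `OptionalHasElement`.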