File: _patch_torch.py

Source: Debian package pytorch 1.13.1+dfsg-4 (suite: bookworm).
"""Importing this patches torch._C classes to add ONNX conveniences."""
import numbers
import re
from typing import Any, Iterable, Tuple, Union

import torch
from torch import _C
from torch._C import _onnx as _C_onnx

# Import utils to get _params_dict, a global that is accessed by C++ code.
from torch.onnx import _deprecation, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype

_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")


# TODO(#78694): Remove this file after PyTorch 1.14.
# All functions in this file are deprecated and should not be used.


@_deprecation.deprecated(
    "1.13",
    "1.14",
    "note 'g.op()' is to be removed from torch.Graph. Please open a"
    " GitHub issue if you need this functionality.",
)
@_beartype.beartype
def _graph_op(
    g: _C.Graph,
    opname: str,
    *raw_args: Union[torch.Tensor, _C.Value],
    outputs: int = 1,
    **kwargs,
) -> Union[_C.Value, Tuple[_C.Value, ...]]:
    r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This function is monkey-patched onto Graph.

    Args:
        g: The Torch graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`.
        raw_args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Value`s, one for each output of the ONNX operator,
            in positional order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`. The valid type specifiers are
            `f` (float), `i` (int), `s` (string) and `t` (Tensor). An
            attribute with a list type uses the same suffix as its element
            type (e.g., `dims_i` for a `dims` attribute that takes a list of
            integers).

    Returns:
        The `Value` representing the single output of this operator, or a
        tuple of `Value`s when `outputs` is greater than one.
    """
    # Filter out None attributes; this is convenient for callers, who can
    # pass None attributes through and have them silently dropped.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    args = [_const_if_tensor(g, arg) for arg in raw_args]

    if "::" in opname:
        namespace, op = opname.split("::")
    else:
        namespace = "onnx"
        op = opname

    n = g.insertNode(_new_node(g, namespace, op, outputs, *args, **kwargs))

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_node_shape_type_inference(
            n, utils._params_dict, GLOBALS.export_onnx_opset_version
        )

    if outputs == 1:
        return n.output()
    return tuple(n.outputs())
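

# Illustrative sketch (not part of the upstream module): a hypothetical
# symbolic function showing how the `g.op` surface built on `_graph_op`
# is typically used. `x` and `k` are `_C.Value`s supplied by the exporter;
# the `_i` suffix follows the attribute convention documented above.
def _example_symbolic_topk(
    g: _C.Graph, x: _C.Value, k: _C.Value
) -> Tuple[_C.Value, _C.Value]:
    # `outputs=2` makes the call return a tuple of `Value`s, one per ONNX
    # output (TopK produces values and indices).
    values, indices = _graph_op(g, "TopK", x, k, axis_i=-1, outputs=2)
    return values, indices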


@_beartype.beartype
def _const_if_tensor(g: _C.Graph, arg):
    if arg is None:
        return arg
    if isinstance(arg, _C.Value):
        return arg
    return _graph_op(g, "Constant", value_z=arg)


@_deprecation.deprecated(
    "1.13",
    "1.14",
    "note 'g.at()' is to be removed from torch.Graph. Please open a"
    " GitHub issue if you need this functionality.",
)
# Generate an ONNX ATen op node.
@_beartype.beartype
def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
    return _graph_op(
        g,
        "aten::ATen",
        *args,
        operator_s=operator,
        overload_name_s=overload_name,
        **kwargs,
    )
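

# Illustrative sketch (not part of the upstream module): `_aten_op` (the
# patched `g.at(...)`) emits an `aten::ATen` fallback node whose
# `operator_s` string attribute names the ATen operator to invoke; "relu"
# here is just an example operator name.
def _example_aten_fallback(g: _C.Graph, x: _C.Value) -> _C.Value:
    # Produces a node of kind `aten::ATen` with operator_s="relu".
    return _aten_op(g, "relu", x)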


@_deprecation.deprecated(
    "1.13",
    "1.14",
    "note 'b.op()' is to be removed from torch.Block. Please open a"
    " GitHub issue if you need this functionality.",
)
@_beartype.beartype
def _block_op(block: _C.Block, opname: str, *args: _C.Value, **kwargs):
    if "::" in opname:
        namespace, op = opname.split("::")
    else:
        namespace = "onnx"
        op = opname

    n = block.addNode(f"{namespace}::{op}", args)
    aten = namespace == "aten"
    skip_attrs = {"inplace", "aten"}
    for k, v in sorted(kwargs.items()):
        if k in skip_attrs:
            continue
        _add_attribute(n, k, v, aten=aten)
    outputs = tuple(n.outputs())
    if len(outputs) == 1:
        return n.output()
    return outputs


@_beartype.beartype
def _new_node(
    g: _C.Graph, namespace: str, op: str, outputs: int, *args: _C.Value, **kwargs
) -> _C.Node:
    """Creates a new node in the graph.

    Args:
        g: The graph to create the operator on.
        namespace: The namespace of the operator. E.g., "aten", "onnx".
        op: The name of the operator to create.
        outputs: The number of outputs the node produces.

    Returns:
        The new node.
    """
    aten = namespace == "aten"
    node = g.create(f"{namespace}::{op}", args, outputs)
    skip_attrs = {"inplace", "aten"}
    for k, v in sorted(kwargs.items()):
        if k in skip_attrs:
            continue
        _add_attribute(node, k, v, aten=aten)
    return node
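

# Illustrative sketch (not part of the upstream module): unlike `_graph_op`,
# `_new_node` only creates the node; the caller must insert it into the
# graph explicitly.
def _example_new_node(g: _C.Graph, x: _C.Value, y: _C.Value) -> _C.Node:
    node = _new_node(g, "onnx", "Add", 1, x, y)
    return g.insertNode(node)  # insertion point is managed by the graph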


@_beartype.beartype
def _is_onnx_list(value):
    return (
        not isinstance(value, torch._six.string_classes)
        and not isinstance(value, torch.Tensor)
        and isinstance(value, Iterable)
    )


@_beartype.beartype
def _scalar(x: torch.Tensor):
    """Convert a scalar tensor into a Python value."""
    assert x.numel() == 1
    # `item()` returns a Python number for both 0-d and one-element tensors,
    # matching the docstring; plain indexing would return a tensor.
    return x.item()


@_beartype.beartype
def _is_caffe2_aten_fallback() -> bool:
    return (
        GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        and _C_onnx._CAFFE2_ATEN_FALLBACK
    )


@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
    r"""Initializes the right attribute based on type of value."""
    m = _ATTR_PATTERN.match(key)
    if m is None:
        raise ValueError(
            f"Invalid attribute specifier '{key}' names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
        )
    name, kind = m.group(1), m.group(2)
    if _is_onnx_list(value):
        kind += "s"

    if aten and _is_caffe2_aten_fallback():
        if isinstance(value, torch.Tensor):
            # The Caffe2 proto does not support tensor attributes.
            if value.numel() > 1:
                raise ValueError(
                    "Cannot pass a tensor attribute with more than one element"
                )
            value = _scalar(value)
            if isinstance(value, float):
                kind = "f"
            else:
                kind = "i"
    return getattr(node, f"{kind}_")(name, value)
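

# Illustrative sketch (not part of the upstream module): how the suffix
# parsing in `_add_attribute` dispatches to the typed setters on `_C.Node`.
# A list value appends "s" to the kind, so `dims_i` with a list resolves to
# `node.is_(...)`.
def _example_set_attributes(node: _C.Node) -> None:
    _add_attribute(node, "alpha_f", 2.0, aten=False)    # node.f_("alpha", 2.0)
    _add_attribute(node, "dims_i", [0, 1], aten=False)  # node.is_("dims", [0, 1])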


# TODO(#76254): Remove the deprecated function.
@_deprecation.deprecated(
    "1.13", "1.14", "Use 'g.op()' to create a constant node instead."
)
@_beartype.beartype
def _graph_constant(
    g,
    value,
    dims,
    type_: str,
    *args,
    **kwargs,
):
    """This helper function can create either constant tensor or constant scalar.

    If dims is None or 0 or [0], generate a 0-d tensor (scalar).
    """
    assert isinstance(value, numbers.Number)
    assert type_ is not None
    isscalar = False
    if dims is None or dims == 0 or set(dims) == {0}:
        dims = [1]
        isscalar = True
    type_ = type_.lower()
    tensor: Union[
        torch.CharTensor,
        torch.ShortTensor,
        torch.IntTensor,
        torch.LongTensor,
        torch.HalfTensor,
        torch.FloatTensor,
        torch.DoubleTensor,
    ]
    if type_ == "char":
        tensor = torch.CharTensor(*dims)
    elif type_ == "short":
        tensor = torch.ShortTensor(*dims)
    elif type_ == "int":
        tensor = torch.IntTensor(*dims)
    elif type_ == "long":
        tensor = torch.LongTensor(*dims)
    elif type_ == "half":
        tensor = torch.HalfTensor(*dims)
    elif type_ == "float":
        tensor = torch.FloatTensor(*dims)
    elif type_ == "double":
        tensor = torch.DoubleTensor(*dims)
    else:
        raise ValueError(
            "Unknown type, type should be one of the following strings: "
            "char, short, int, long, half, float, double"
        )
    tensor.fill_(value)  # type: ignore[call-overload]
    if isscalar:
        return g.op("Constant", *args, value_z=tensor, **kwargs)
    return g.op("Constant", *args, value_t=tensor, **kwargs)


# TODO(#76254): Remove the deprecated function.
@_deprecation.deprecated(
    "1.13",
    "1.14",
    "Internally use '_node_get' in symbolic_helper instead.",
)
def _node_getitem(self, k):
    """Gets attributes of a node which is polymorphic over return type.

    This is monkey-patched onto Node.
    """
    sel = self.kindOf(k)
    return getattr(self, sel)(k)
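

# Illustrative sketch (not part of the upstream module): with `__getitem__`
# patched below, reading an attribute dispatches on its stored kind. For an
# int attribute, `node.kindOf("axis")` returns "i", so `node["axis"]`
# resolves to `node.i("axis")`.
def _example_read_attribute(node: _C.Node):
    return _node_getitem(node, "axis")  # same as node["axis"] after patching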


torch._C.Graph.op = _graph_op  # type: ignore[attr-defined]
torch._C.Graph.at = _aten_op  # type: ignore[attr-defined]
torch._C.Block.op = _block_op  # type: ignore[attr-defined]
torch._C.Graph.constant = _graph_constant  # type: ignore[attr-defined]
torch._C.Node.__getitem__ = _node_getitem  # type: ignore[attr-defined, misc, assignment]
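
# Illustrative sketch (not part of the upstream module): importing this
# module applies the patches above as a side effect, after which the
# conveniences are available on the C++ classes, e.g.:
#
#     import torch.onnx._patch_torch  # noqa: F401  (patches on import)
#     graph = torch._C.Graph()
#     one = graph.op("Constant", value_t=torch.tensor(1.0))
#     two = graph.op("Add", one, one)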