File: pytorch_export_helpers.py

package info (click to toggle)
onnxruntime 1.23.2+dfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 340,756 kB
  • sloc: cpp: 3,222,136; python: 188,267; ansic: 114,318; asm: 37,927; cs: 36,849; java: 10,962; javascript: 6,811; pascal: 4,126; sh: 2,996; xml: 705; objc: 281; makefile: 67
file content (131 lines) | stat: -rw-r--r-- 5,840 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import inspect
from collections import abc

import torch


def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
    """Build the ordered list of input names for one call of a module's ``forward``.

    Adapted from ORTModule's ``_io.py``:
    https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433

    Every parameter of ``forward`` is matched against the positional ``inputs`` and the
    keyword ``kwargs``:

    - ``None`` values are dropped.
    - Sequences (other than str/bytes) expand recursively into ``<name>_<index>`` entries.
    - Mappings expand recursively into ``<name>_<key>`` entries.
    - ``*args`` values are named ``<args parameter name>_<i>``.
    - Extra ``**kwargs`` entries are named by their own key.

    :param all_input_parameters: ``inspect.Parameter`` objects of ``forward``, in declaration order.
    :param inputs: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: List of input names, in order.
    """
    input_names = []

    def _add_input(name, value):
        """Append the name(s) for ``value`` to ``input_names``, expanding containers."""
        if value is None:
            # None inputs are not part of the exported graph.
            return

        # A str is a Sequence whose elements are themselves str, so expanding it would
        # recurse forever; text/binary values are treated as single leaf inputs.
        if isinstance(value, abc.Sequence) and not isinstance(value, (str, bytes)):
            # The sequence itself is not a valid input: each element becomes its own input,
            # named with its index appended to the parent name.
            for i, item in enumerate(value):
                _add_input(f"{name}_{i}", item)
            return

        if isinstance(value, abc.Mapping):
            # Same for mappings: each value becomes its own input, named with its key.
            for key, item in value.items():
                _add_input(f"{name}_{key}", item)
            return

        # The name is recorded irrespective of whether it ends up in the ONNX graph.
        input_names.append(name)

    var_positional_count = 0
    # Names of named parameters whose value came from ``kwargs``; the **kwargs pass must not
    # add them again (a list/dict value would otherwise be expanded twice).
    consumed_kwargs = set()

    for param_idx, param in enumerate(all_input_parameters):
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            # *args absorbs every positional value from its own index onward.
            for value in inputs[param_idx:]:
                _add_input(f"{param.name}_{var_positional_count}", value)
                var_positional_count += 1
        elif param.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            name = param.name
            # Parameters declared after *args sit past the values that *args consumed.
            positional_idx = param_idx + var_positional_count
            value = None
            if positional_idx < len(inputs) and inputs[positional_idx] is not None:
                value = inputs[positional_idx]
            elif kwargs.get(name) is not None:
                value = kwargs[name]
                consumed_kwargs.add(name)
            _add_input(name, value)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs is always the last parameter of forward(), so every named
            # parameter has already been handled at this point.
            for name, value in kwargs.items():
                if name not in consumed_kwargs and name not in input_names:
                    _add_input(name, value)

    return input_names


def _flatten_module_input(names, args, kwargs):
    """Flatten args and kwargs into a single tuple of sample inputs for ``torch.onnx.export``.

    Adapted from ORTModule's ``_io.py``:
    https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110

    Values whose exact type is ``int``, ``bool`` or ``float`` are converted with ``torch.tensor``.
    Every other value (tensors, lists, dicts, ...) is passed through unchanged. Positional values
    come first, in order, followed by the keyword values whose key appears in ``names``, ordered
    as in ``names``.

    NOTE(review): keyword values that are sequences/mappings are recorded in ``names`` under
    expanded names (``<key>_<i>`` / ``<key>_<subkey>``), so ``key in kwargs`` does not match
    them and they are omitted here.

    :param names: Ordered input names (used to order and select keyword values).
    :param args: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: Tuple of sample inputs. It ends with an empty dict when ``kwargs`` is empty or when
        the last value would otherwise be a dict.
    """

    def to_sample(value):
        # Exact type check on purpose: subclasses (e.g. enums) are left untouched.
        return torch.tensor(value) if type(value) in {int, bool, float} else value

    flattened = [to_sample(arg) for arg in args]
    flattened += [to_sample(kwargs[name]) for name in names if name in kwargs]

    # torch.onnx.export interprets a dict in the last position of ``args`` as named arguments.
    # Appending an empty dict makes it treat every preceding value, including a trailing dict
    # input, as positional. Needed both when kwargs is empty (dict inputs vs. kwargs confusion)
    # and when the last flattened value is itself a dict.
    if not kwargs or (flattened and isinstance(flattened[-1], dict)):
        flattened.append({})

    return tuple(flattened)


def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """
    Infer the input names and order from the arguments used to execute a PyTorch module for usage exporting
    the model via torch.onnx.export.
    Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.

    Example usage:
    input_names, inputs_as_tuple = infer_input_info(module, ...)
    torch.onnx.export(module, inputs_as_tuple, 'model.onnx', input_names=input_names, output_names=[...], ...)

    :param module: Module whose `forward` signature is used to name the inputs.
    :param inputs: Positional inputs, as they would be passed to `module(...)`.
    :param kwargs: Keyword argument inputs, as they would be passed to `module(...)`.
    :return: Tuple of ordered input names (list) and input values (tuple). These can be used directly with
            torch.onnx.export as the `input_names` and `args` arguments.
    """
    # `module.forward` is a bound method, so its signature does not include `self`.
    module_parameters = inspect.signature(module.forward).parameters.values()
    input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs)
    inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)

    return input_names, inputs_as_tuple