File: pytorch_helper.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (92 lines) | stat: -rw-r--r-- 3,381 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import io

import onnx

import torch.onnx
from caffe2.python.core import BlobReference, Net
from caffe2.python.onnx.backend import Caffe2Backend

# Module-level counter read and incremented by PyTorchModule to build a fresh
# "pytorch_import_N/" prefix whenever the caller passes no prefix_name.
_next_idx = 0
# Net.Clone takes its blob-rename table as a dict rather than as a function
# (a function would be the more flexible interface). _FakeDict hides a
# function behind the dict-style ``get(key, default)`` method so it can be
# passed where Clone expects that dict.


class _FakeDict:
    """Dict look-alike whose ``get`` resolves every key through a callable."""

    def __init__(self, fn):
        # Callable mapping a key (a blob name) to its replacement value.
        self.fn = fn

    def get(self, name, _):
        # Every key counts as present, so the fallback argument that a real
        # dict.get would use is accepted but never consulted.
        lookup = self.fn
        return lookup(name)


def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    The model is exported to ONNX and converted into a Caffe2 init/predict
    net pair. The predict net is appended to ``helper.net`` and the init net
    to ``helper.param_init_net``. The exported graph's non-initializer inputs
    are renamed to the blobs in ``caffe2_inputs``. Every other blob of the
    imported network gets ``prefix_name`` prepended to its name.

    Args:
        helper (caffe2.python.model_helper.ModelHelper): the model helper where
            this imported network should be inserted
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network, in the same
            order as the exported model's inputs. Must contain exactly one
            entry per non-initializer input of the exported ONNX graph.
        prefix_name (str, optional): prefix prepended to the name of every
            blob of the imported network other than its inputs. If None, a
            fresh prefix ``pytorch_import_N/`` is used, where N is taken from
            a module-level counter that is incremented on each such call.

    Returns:
        A tuple of caffe2.python.core.BlobReference objects (bound to
        ``helper.net``), one per ONNX graph output, in graph-output order.
        A tuple is returned even when the model has a single output.

    Raises:
        ValueError: if ``len(caffe2_inputs)`` differs from the number of
            non-initializer inputs of the exported graph.
    """
    global _next_idx
    if prefix_name is None:
        prefix_name = "pytorch_import_" + str(_next_idx) + "/"
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    buffer = io.BytesIO()
    torch.onnx.export(model, sample_arguments, buffer, export_params=True)
    onnx_model = onnx.load(io.BytesIO(buffer.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)

    # Graph inputs backed by an initializer are model parameters, not values
    # the caller feeds in. Number only the remaining inputs, consecutively,
    # so that index i lines up with caffe2_inputs[i]. This holds even when
    # initializers are also listed among (and possibly before) graph inputs.
    initialized = {x.name for x in onnx_model.graph.initializer}
    fed_input_names = [
        x.name for x in onnx_model.graph.input if x.name not in initialized
    ]
    uninitialized_inputs = {name: i for i, name in enumerate(fed_input_names)}

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError(
            "Expected {} inputs but found {}".format(
                len(uninitialized_inputs), len(caffe2_inputs)
            )
        )

    def remap_blob_name(name):
        # Model inputs are wired to the caller's blobs; every other blob is
        # placed under prefix_name.
        idx = uninitialized_inputs.get(name)
        if idx is not None:
            return str(caffe2_inputs[idx])
        return prefix_name + name

    # Net.Clone takes its renaming as a dict; _FakeDict adapts the function.
    blob_remap = _FakeDict(remap_blob_name)
    helper.net.AppendNet(Net(predict_net).Clone("anon", blob_remap))
    helper.param_init_net.AppendNet(Net(init_net).Clone("anon", blob_remap))

    return tuple(
        BlobReference(remap_blob_name(output.name), helper.net)
        for output in onnx_model.graph.output
    )