File: backend_rep.py

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (65 lines) | stat: -rw-r--r-- 2,830 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# @package onnx
# Module caffe2.python.onnx.backend_rep





from caffe2.python import core
from caffe2.proto import caffe2_pb2
from onnx.backend.base import BackendRep, namedtupledict

class Caffe2Rep(BackendRep):
    """ONNX ``BackendRep`` that executes a model prepared for Caffe2.

    Holds an init net and a predict net together with the workspace they
    run in. Both nets are created lazily on the first ``run`` call and the
    init net is executed only once; later calls only feed inputs, run the
    predict net and fetch its outputs.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """
        Args:
            init_net: net run once, before the predict net, on the first
                ``run`` call that gets that far.
            predict_net: net run on every ``run`` call; its
                ``external_output`` names define the returned outputs and
                its ``device_option`` selects the device / name scope.
            workspace: object used to create and run the nets and to
                feed / fetch blobs.
            uninitialized: sequence of names of the external inputs that
                are not initialized in ``workspace``; positional inputs
                given to ``run`` are paired with these names in order.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace, we need this to
        # pair the name with given sequence inputs.
        self.uninitialized = uninitialized
        # Flags so the nets are created, and the init net run, only once.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Name scope applied when inputs are fed by name (dict inputs).

        Returns ``'gpu_<device_id>'`` when the predict net targets CUDA and
        the empty string (no scope) otherwise.
        """
        device_option = self.predict_net.device_option
        if device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(device_option.device_id)
        return ''

    def run(self, inputs, **kwargs):
        """Feed ``inputs``, run the predict net and return its outputs.

        Args:
            inputs: one of
                * a dict mapping blob names to values; the names are fed
                  inside ``self._name_scope``;
                * a list or tuple of values, paired positionally with
                  ``self.uninitialized``;
                * any other object, treated as the single value for a graph
                  with exactly one uninitialized input.
            **kwargs: forwarded to ``BackendRep.run``.

        Returns:
            A ``namedtupledict`` instance named ``Outputs`` with one value
            per name in ``predict_net.external_output``, in that order.

        Raises:
            RuntimeError: if the number of positional inputs (the list/tuple
                length, or 1 for a single value) differs from the number of
                uninitialized graph inputs.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with core.DeviceScope(self.predict_net.device_option):
            self._feed_inputs(inputs)
            if not self.nets_created:
                self.workspace.CreateNet(self.init_net)
                self.workspace.CreateNet(self.predict_net)
                self.nets_created = True
            if not self.ran_init_net:
                self.workspace.RunNet(self.init_net.name)
                self.ran_init_net = True
            self.workspace.RunNet(self.predict_net.name)
        output_names = self.predict_net.external_output
        output_values = [self._fetch_output(name) for name in output_names]
        return namedtupledict('Outputs', output_names)(*output_values)

    def _feed_inputs(self, inputs):
        """Feed dict, list/tuple, or single-value ``inputs`` into the workspace."""
        if isinstance(inputs, dict):
            with core.NameScope(self._name_scope):
                for key, value in inputs.items():
                    self.workspace.FeedBlob(key, value)
            return

        if isinstance(inputs, (list, tuple)):
            values = inputs
        else:
            # A single (non-sequence) value is treated as a one-element
            # sequence so it gets the same count validation as list/tuple
            # inputs, instead of silently filling only the first slot (or
            # failing with IndexError when there is no uninitialized input).
            values = [inputs]

        if len(self.uninitialized) != len(values):
            raise RuntimeError('Expected {} values for uninitialized '
                               'graph inputs ({}), but got {}.'.format(
                                   len(self.uninitialized),
                                   ', '.join(self.uninitialized),
                                   len(values)))
        # namescope already baked into protobuf
        for name, value in zip(self.uninitialized, values):
            self.workspace.FeedBlob(name, value)

    def _fetch_output(self, name):
        """Fetch output blob ``name``, falling back to ``FetchInt8Blob``."""
        try:
            return self.workspace.FetchBlob(name)
        except Exception:
            # FetchBlob could not fetch this blob; retry via FetchInt8Blob
            # (presumably for quantized int8 outputs — confirm). If this also
            # fails, its exception propagates with the FetchBlob error
            # attached as implicit context.
            return self.workspace.FetchInt8Blob(name)