File: predictor_exporter.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (265 lines) | stat: -rw-r--r-- 10,029 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
## @package predictor_exporter
# Module caffe2.python.predictor.predictor_exporter





from caffe2.proto import caffe2_pb2
from caffe2.proto import metanet_pb2
from caffe2.python import workspace, core, scope
from caffe2.python.predictor_constants import predictor_constants
import caffe2.python.predictor.serde as serde
import caffe2.python.predictor.predictor_py_utils as utils
from builtins import bytes
import collections


def get_predictor_exporter_helper(submodelNetName):
    """Construct a stub PredictorExportMeta for *submodelNetName*.

    The returned meta carries an empty net and no blobs; it is only
    useful for deriving component names (e.g. ``predict_net_name()``).

    Args:
        submodelNetName: name of the model.
    """
    placeholder_net = core.Net(submodelNetName)
    return PredictorExportMeta(
        predict_net=placeholder_net,
        parameters=[],
        inputs=[],
        outputs=[],
        shapes=None,
        name=submodelNetName,
        extra_init_net=None,
    )


# pyre-fixme[13]: Pyre can't detect the attribute initialization via cls.super() here
class PredictorExportMeta(collections.namedtuple(
    'PredictorExportMeta',
        'predict_net, parameters, inputs, outputs, shapes, name, '
        'extra_init_net, global_init_net, net_type, num_workers, trainer_prefix')):
    """
    Metadata describing how to serialize a net.

    parameters, inputs and outputs may each be BlobReferences or plain
    blob name strings; they are normalized to strings on construction.

    predict_net may be a core.Net, core.Plan, NetDef or PlanDef; Net/Plan
    instances are converted to their protos.

    This overrides the named tuple to make most fields optional.
    name distinguishes between multiple prediction nets.

    net_type maps to the ``type`` field of the caffe2 NetDef
    ('simple', 'dag', ...), and num_workers sets the thread count for
    the 'dag' net type.

    trainer_prefix specifies the type of trainer and, when set, prefixes
    the train plan names.

    extra_init_net gets appended to pred_init_net, useful for thread
    local init.

    global_init_net gets appended to global_init_net, useful for global
    init on a parameter workspace shared across threads
    (in a case of multi-threaded inference).
    """
    def __new__(
        cls,
        predict_net,
        parameters,
        inputs,
        outputs,
        shapes=None,
        name="",
        extra_init_net=None,
        global_init_net=None,
        net_type=None,
        num_workers=None,
        trainer_prefix=None,
    ):
        # Normalize everything blob-like to plain strings.
        input_names = [str(i) for i in inputs]
        output_names = [str(o) for o in outputs]
        assert len(set(input_names)) == len(input_names), (
            "All inputs to the predictor should be unique")
        param_names = [str(p) for p in parameters]
        # Parameters must not double as inputs or outputs.
        assert set(param_names).isdisjoint(input_names), (
            "Parameters and inputs are required to be disjoint. "
            "Intersection: {}".format(set(param_names).intersection(input_names)))
        assert set(param_names).isdisjoint(output_names), (
            "Parameters and outputs are required to be disjoint. "
            "Intersection: {}".format(set(param_names).intersection(output_names)))
        shapes = shapes or {}

        if predict_net is not None:
            # Unwrap Net/Plan objects down to their protobuf form.
            if isinstance(predict_net, (core.Net, core.Plan)):
                predict_net = predict_net.Proto()
            assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
        return super(PredictorExportMeta, cls).__new__(
            cls, predict_net, param_names, input_names, output_names, shapes,
            name, extra_init_net, global_init_net, net_type, num_workers,
            trainer_prefix)

    def inputs_name(self):
        """Component name under which the input blob list is stored."""
        return utils.get_comp_name(
            predictor_constants.INPUTS_BLOB_TYPE, self.name)

    def outputs_name(self):
        """Component name under which the output blob list is stored."""
        return utils.get_comp_name(
            predictor_constants.OUTPUTS_BLOB_TYPE, self.name)

    def parameters_name(self):
        """Component name under which the parameter blob list is stored."""
        return utils.get_comp_name(
            predictor_constants.PARAMETERS_BLOB_TYPE, self.name)

    def global_init_name(self):
        """Component name of the global init net."""
        return utils.get_comp_name(
            predictor_constants.GLOBAL_INIT_NET_TYPE, self.name)

    def predict_init_name(self):
        """Component name of the per-thread predict init net."""
        return utils.get_comp_name(
            predictor_constants.PREDICT_INIT_NET_TYPE, self.name)

    def predict_net_name(self):
        """Component name of the prediction net itself."""
        return utils.get_comp_name(
            predictor_constants.PREDICT_NET_TYPE, self.name)

    def train_init_plan_name(self):
        """Name of the train-init plan, prefixed by trainer_prefix if set."""
        base = utils.get_comp_name(
            predictor_constants.TRAIN_INIT_PLAN_TYPE, self.name)
        if self.trainer_prefix:
            return self.trainer_prefix + '_' + base
        return base

    def train_plan_name(self):
        """Name of the train plan, prefixed by trainer_prefix if set."""
        base = utils.get_comp_name(
            predictor_constants.TRAIN_PLAN_TYPE, self.name)
        if self.trainer_prefix:
            return self.trainer_prefix + '_' + base
        return base


def prepare_prediction_net(filename, db_type, device_option=None):
    """Load every blob the predictor needs from the db at *filename*
    and return a prediction net that is ready to be run.
    """
    meta_net_def = load_from_db(filename, db_type, device_option)

    # Run each init net once: global init loads the parameters,
    # predict init sets up per-thread state.
    for init_net_type in (predictor_constants.GLOBAL_INIT_NET_TYPE,
                          predictor_constants.PREDICT_INIT_NET_TYPE):
        workspace.RunNetOnce(utils.GetNet(meta_net_def, init_net_type))

    predict_net = core.Net(
        utils.GetNet(meta_net_def, predictor_constants.PREDICT_NET_TYPE))
    workspace.CreateNet(predict_net)
    return predict_net


def _global_init_net(predictor_export_meta, db_type):
    """Build the global init net proto: loads parameters from the db
    reader (unless the db type provides them itself) and appends any
    user-supplied global init net."""
    init_net = core.Net("global-init")
    if db_type != "manifold_db":
        # manifold_db does not need DBReader
        init_net.Load(
            [predictor_constants.PREDICTOR_DBREADER],
            predictor_export_meta.parameters)
        proto = init_net.Proto()
        proto.external_input.extend(
            [predictor_constants.PREDICTOR_DBREADER])
        proto.external_output.extend(predictor_export_meta.parameters)

    if predictor_export_meta.global_init_net:
        init_net.AppendNet(predictor_export_meta.global_init_net)

    # Propagate the model_id from the predict_net into this init net.
    utils.AddModelIdArg(predictor_export_meta, init_net.Proto())
    return init_net.Proto()


def get_meta_net_def(predictor_export_meta, ws=None, db_type=None):
    """Assemble a MetaNetDef from *predictor_export_meta*: the three nets
    (predict init, global init, predict) plus the parameter/input/output
    blob name lists."""
    ws = ws or workspace.C.Workspace.current
    meta_net_def = metanet_pb2.MetaNetDef()

    # Register the nets the predictor runtime expects.
    net_entries = (
        (predictor_export_meta.predict_init_name(),
         utils.create_predict_init_net(ws, predictor_export_meta)),
        (predictor_export_meta.global_init_name(),
         _global_init_net(predictor_export_meta, db_type)),
        (predictor_export_meta.predict_net_name(),
         utils.create_predict_net(predictor_export_meta)),
    )
    for comp_name, net in net_entries:
        utils.AddNet(meta_net_def, comp_name, net)

    # Register the blob name lists.
    blob_entries = (
        (predictor_export_meta.parameters_name(),
         predictor_export_meta.parameters),
        (predictor_export_meta.inputs_name(),
         predictor_export_meta.inputs),
        (predictor_export_meta.outputs_name(),
         predictor_export_meta.outputs),
    )
    for comp_name, blobs in blob_entries:
        utils.AddBlobs(meta_net_def, comp_name, blobs)

    return meta_net_def


def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Fill the modelInfo submessage of *meta_net_def* in place."""
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    info = meta_net_def.modelInfo
    info.project = project_str
    info.modelClass = model_class_str
    info.version = version


def save_to_db(db_type, db_destination, predictor_export_meta, use_ideep=False,
               *args, **kwargs):
    """Serialize the MetaNetDef plus every parameter blob into a db.

    Extra keyword args are forwarded to the Save operator.
    """
    meta_net_def = get_meta_net_def(predictor_export_meta, db_type=db_type)
    # The serialized MetaNetDef is always fed on CPU, regardless of the
    # device used for the Save op itself.
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
        workspace.FeedBlob(
            predictor_constants.META_NET_DEF,
            serde.serialize_protobuf_struct(meta_net_def)
        )

    blobs_to_save = [predictor_constants.META_NET_DEF]
    blobs_to_save.extend(predictor_export_meta.parameters)

    device_type = caffe2_pb2.IDEEP if use_ideep else caffe2_pb2.CPU
    save_op = core.CreateOperator(
        "Save",
        blobs_to_save, [],
        device_option=core.DeviceOption(device_type),
        absolute_path=True,
        db=db_destination, db_type=db_type,
        **kwargs
    )
    workspace.RunOperatorOnce(save_op)


def load_from_db(filename, db_type, device_option=None, *args, **kwargs):
    """Open the db at *filename*, deserialize its MetaNetDef, and stamp
    *device_option* (or the current device scope) onto every stored op."""
    # global_init_net in meta_net_def will load parameters from
    # predictor_constants.PREDICTOR_DBREADER
    dbreader = core.BlobReference(predictor_constants.PREDICTOR_DBREADER)
    create_db = core.CreateOperator(
        'CreateDB', [], [dbreader],
        db=filename, db_type=db_type)
    assert workspace.RunOperatorOnce(create_db), (
        'Failed to create db {}'.format(filename))

    # predictor_constants.META_NET_DEF is always stored before the parameters
    load_op = core.CreateOperator(
        'Load',
        [dbreader],
        [core.BlobReference(predictor_constants.META_NET_DEF)])
    assert workspace.RunOperatorOnce(load_op)

    raw = workspace.FetchBlob(predictor_constants.META_NET_DEF)
    if not isinstance(raw, bytes):
        raw = str(raw).encode('utf-8')
    meta_net_def = serde.deserialize_protobuf_struct(
        raw, metanet_pb2.MetaNetDef)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option is not None:
        # Set the device options of all loaded blobs
        for net_entry in meta_net_def.nets:
            for op in net_entry.value.op:
                op.device_option.CopyFrom(device_option)

    return meta_net_def