1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
## @package predictor_py_utils
# Module caffe2.python.predictor.predictor_py_utils
from caffe2.python import core, scope
def create_predict_net(predictor_export_meta):
"""
Return the input prediction net.
"""
# Construct a new net to clear the existing settings.
net = core.Net(predictor_export_meta.predict_net.name or "predict")
net.Proto().op.extend(predictor_export_meta.predict_net.op)
net.Proto().partition_info.extend(predictor_export_meta.predict_net.partition_info)
net.Proto().external_input.extend(
predictor_export_meta.inputs + predictor_export_meta.parameters
)
net.Proto().external_output.extend(predictor_export_meta.outputs)
net.Proto().arg.extend(predictor_export_meta.predict_net.arg)
if predictor_export_meta.net_type is not None:
net.Proto().type = predictor_export_meta.net_type
if predictor_export_meta.num_workers is not None:
net.Proto().num_workers = predictor_export_meta.num_workers
return net.Proto()
def create_predict_init_net(ws, predictor_export_meta):
"""
Return an initialization net that zero-fill all the input and
output blobs, using the shapes from the provided workspace. This is
necessary as there is no shape inference functionality in Caffe2.
"""
net = core.Net("predict-init")
def zero_fill(blob):
shape = predictor_export_meta.shapes.get(blob)
if shape is None:
if blob not in ws.blobs:
raise Exception(
"{} not in workspace but needed for shape: {}".format(
blob, ws.blobs
)
)
shape = ws.blobs[blob].fetch().shape
# Explicitly null-out the scope so users (e.g. PredictorGPU)
# can control (at a Net-global level) the DeviceOption of
# these filling operators.
with scope.EmptyDeviceScope():
net.ConstantFill([], blob, shape=shape, value=0.0)
external_blobs = predictor_export_meta.inputs + predictor_export_meta.outputs
for blob in external_blobs:
zero_fill(blob)
net.Proto().external_input.extend(external_blobs)
if predictor_export_meta.extra_init_net:
net.AppendNet(predictor_export_meta.extra_init_net)
# Add the model_id in the predict_net to the init_net
AddModelIdArg(predictor_export_meta, net.Proto())
return net.Proto()
def get_comp_name(string, name):
if name:
return string + "_" + name
return string
def to_first_match_dict(kv_list):
"""
Construct dict from kv_list
"""
d = {}
for item in kv_list:
if item.key not in d:
d[item.key] = item.value
return d
def _ProtoMapGet(field, key):
"""
Given the key, get the value of the repeated field.
Helper function used by protobuf since it doesn't have map construct
"""
for v in field:
if v.key == key:
return v.value
return None
def GetPlan(meta_net_def, key):
return _ProtoMapGet(meta_net_def.plans, key)
def GetPlanOriginal(meta_net_def, key):
return _ProtoMapGet(meta_net_def.plans, key)
def GetBlobs(meta_net_def, key):
blobs = _ProtoMapGet(meta_net_def.blobs, key)
if blobs is None:
return []
return blobs
def GetBlobsByTypePrefix(meta_net_def, blob_type_prefix):
blob_map = {}
for b in meta_net_def.blobs:
if b.key.startswith(blob_type_prefix):
for blob in b.value:
if blob not in blob_map:
blob_map[blob] = len(blob_map)
return sorted(blob_map, key=lambda blob: blob_map[blob])
def GetNet(meta_net_def, key):
return _ProtoMapGet(meta_net_def.nets, key)
def GetNetOriginal(meta_net_def, key):
return _ProtoMapGet(meta_net_def.nets, key)
def GetApplicationSpecificInfo(meta_net_def, key):
return _ProtoMapGet(meta_net_def.applicationSpecificInfo, key)
def GetApplicationSpecificInfoDict(meta_net_def):
return to_first_match_dict(meta_net_def.applicationSpecificInfo)
def AddBlobs(meta_net_def, blob_name, blob_def):
blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
if blobs is None:
blobs = meta_net_def.blobs.add()
blobs.key = blob_name
blobs = blobs.value
for blob in blob_def:
blobs.append(blob)
def ReplaceBlobs(meta_net_def, blob_name, blob_def):
blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
assert blobs is not None, "The blob_name:{} does not exist".format(blob_name)
del blobs[:]
for blob in blob_def:
blobs.append(blob)
def AddPlan(meta_net_def, plan_name, plan_def):
meta_net_def.plans.add(key=plan_name, value=plan_def)
def AddNet(meta_net_def, net_name, net_def):
meta_net_def.nets.add(key=net_name, value=net_def)
def SetBlobsOrder(meta_net_def, blobs_order):
for blob in blobs_order:
meta_net_def.blobsOrder.append(blob)
def SetPreLoadBlobs(meta_net_def, pre_load_blobs):
for blob in pre_load_blobs:
meta_net_def.preLoadBlobs.append(blob)
def SetRequestOnlyEmbeddings(meta_net_def, request_only_embeddings):
for blob in request_only_embeddings:
meta_net_def.requestOnlyEmbeddings.append(blob)
def GetBlobsOrder(meta_net_def):
return meta_net_def.blobsOrder
def SetTensorBoundShapes(meta_net_def, tensor_bound_shapes):
meta_net_def.tensorBoundShapes.CopyFrom(tensor_bound_shapes)
def SetAOTConfig(meta_net_def, aot_config):
meta_net_def.aotConfig.CopyFrom(aot_config)
def GetArgumentByName(net_def, arg_name):
for arg in net_def.arg:
if arg.name == arg_name:
return arg
return None
def AddModelIdArg(meta_net_def, net_def):
"""Takes the model_id from the predict_net of meta_net_def (if it is
populated) and adds it to the net_def passed in. This is intended to be
called on init_nets, as their model_id is not populated by default, but
should be the same as that of the predict_net
"""
# Get model_id from the predict_net, assuming it's an integer
model_id = GetArgumentByName(meta_net_def.predict_net, "model_id")
if model_id is None:
return
model_id = model_id.i
# If there's another model_id on the net, replace it with the new one
old_id = GetArgumentByName(net_def, "model_id")
if old_id is not None:
old_id.i = model_id
return
# Add as an integer argument, this is also assumed above
arg = net_def.arg.add()
arg.name = "model_id"
arg.i = model_id
|