1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
import logging
from caffe2.python import schema
from caffe2.python.layers.layers import (
InstantiationContext,
ModelLayer,
)
# Module-level logger (not referenced by the code visible in this chunk).
logger = logging.getLogger(__name__)
class SelectRecordByContext(ModelLayer):
    """
    Let a model take a different path per instantiation context and join the
    paths back together at this layer.

    `input_record` is a `schema.Struct` whose fields are looked up by context
    name (the `InstantiationContext` values passed by the `add_*` methods).
    Every field must have the same schema. The blobs of the field selected for
    a context are bound to this layer's output blobs with the `Alias` operator
    (or `Copy` when `use_copy` is set). A fixed set of output blobs is used
    because schema sometimes clones fields internally, so the output needs
    static blob names.
    """

    def __init__(
        self,
        model,
        input_record,
        name='select_record_by_context',
        check_field_metas=True,
        use_copy=False,
        default_output_record_field=None,
        **kwargs
    ):
        """
        Args:
            model: model the layer belongs to; `model.net` is used to create
                the output record.
            input_record: `schema.Struct` with at least two fields, all of
                which must share the same schema.
            name: layer name, forwarded to `ModelLayer`.
            check_field_metas: forwarded to `schema.equal_schemas` when the
                fields are compared.
            use_copy: emit `Copy` instead of `Alias` when binding outputs.
            default_output_record_field: name of the field to use for any
                context that has no entry of its own; `None` means no fallback.
            **kwargs: forwarded to `ModelLayer.__init__`.
        """
        super(SelectRecordByContext, self).__init__(
            model, name, input_record, **kwargs)

        # Selection only makes sense between two or more alternatives.
        assert isinstance(input_record, schema.Struct)
        assert len(input_record) > 1

        self.use_copy = use_copy

        if default_output_record_field is None:
            self.default_output_record = None
        else:
            self.default_output_record = input_record[
                default_output_record_field]

        # Any alternative may be bound to the single shared output schema
        # built from the first one, so they must all match it.
        template = input_record[0]
        assert all(
            schema.equal_schemas(candidate, template,
                                 check_field_metas=check_field_metas)
            for candidate in input_record
        )
        self.output_schema = schema.NewRecord(model.net, template)

    def _set_output_blobs(self, net, context):
        """Add ops to `net` that bind the output blobs to `context`'s record.

        Uses `self.default_output_record` when `input_record` has no entry
        for `context`, and asserts if that is also `None`.
        """
        chosen = self.input_record.get(context, self.default_output_record)
        assert chosen is not None, (
            "{} context is not in input record without providing default"
            " output".format(context)
        )
        bind = net.Copy if self.use_copy else net.Alias
        sources = chosen.field_blobs()
        targets = self.output_schema.field_blobs()
        for src, dst in zip(sources, targets):
            bind(src, dst)

    def add_ops(self, net):
        """Bind outputs for `InstantiationContext.PREDICTION`."""
        self._set_output_blobs(net, InstantiationContext.PREDICTION)

    def add_eval_ops(self, net):
        """Bind outputs for `InstantiationContext.EVAL`."""
        self._set_output_blobs(net, InstantiationContext.EVAL)

    def add_train_ops(self, net):
        """Bind outputs for `InstantiationContext.TRAINING`."""
        self._set_output_blobs(net, InstantiationContext.TRAINING)

    def add_ops_to_accumulate_pred(self, net):
        """Bind outputs for `InstantiationContext.ACCUMULATE_PRED`."""
        self._set_output_blobs(net, InstantiationContext.ACCUMULATE_PRED)
|