1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
## @package last_n_window_collector
# Module caffe2.python.layers.last_n_window_collector
from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
class LastNWindowCollector(ModelLayer):
    """
    Collect the last N samples from the input record.

    Only a ``schema.Scalar`` input is accepted, so complex data should be
    packed with PackRecords before it reaches this layer.

    This layer has historically been documented as not thread safe.
    NOTE(review): a mutex blob is nevertheless created here and fed to the
    operator -- confirm whether the operator actually locks it before
    relying on either statement.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='last_n_window_collector', **kwargs):
        """
        Create the state blobs for the window and describe the output.

        Args:
            model: the layer model; its ``NoOptim`` is passed as the
                optimizer for every state parameter created below.
            input_record: ``schema.Scalar`` holding the samples to collect.
            num_to_collect: window size N; asserted to be > 0.
            name: layer name, forwarded to ``ModelLayer``.
            **kwargs: any other ``ModelLayer`` keyword arguments.

        Raises:
            AssertionError: if ``num_to_collect`` is not > 0 or
                ``input_record`` is not a ``schema.Scalar``.
        """
        super(LastNWindowCollector, self).__init__(
            model, name, input_record, **kwargs)

        assert num_to_collect > 0
        self.num_to_collect = num_to_collect

        assert isinstance(input_record, schema.Scalar), \
            "Got {!r}".format(input_record)

        def state_param(param_name, shape, initializer):
            # Every param of this layer is created with model.NoOptim
            # (presumably excluding it from optimization -- confirm).
            return self.create_param(
                param_name=param_name,
                shape=shape,
                initializer=initializer,
                optimizer=model.NoOptim,
            )

        # Buffer for the collected samples; created with shape [0].
        self.last_n = state_param('last_n', [0], ('ConstantFill', {}))

        # Int32 scalar initialised to 0; read and rewritten by the operator
        # (presumably the next write position -- confirm).
        self.next_blob = state_param(
            'next',
            [],
            ('ConstantFill', {'value': 0, 'dtype': core.DataType.INT32}),
        )

        # Mutex blob handed to the operator and exposed in the output schema.
        self.mutex = state_param('mutex', [], ('CreateMutex',))

        # Int64 scalar initialised to 0; rewritten by the operator.
        self.num_visited_blob = state_param(
            'num_visited',
            [],
            ('ConstantFill', {'value': 0, 'dtype': core.DataType.INT64}),
        )

        # 'last_n' mirrors the structure of input_record, backed by last_n.
        last_n_field = schema.from_blob_list(input_record, [self.last_n])
        self.output_schema = schema.Struct(
            ('last_n', last_n_field),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """
        Append one LastNWindowCollector operator to ``net``.

        The operator reads the buffer, the ``next`` counter, the input
        record's blob(s), the mutex and ``num_visited``. It writes the
        buffer, ``next`` and ``num_visited`` back in place.
        """
        op_inputs = [
            self.last_n,
            self.next_blob,
            self.input_record(),
            self.mutex,
            self.num_visited_blob,
        ]
        op_outputs = [self.last_n, self.next_blob, self.num_visited_blob]
        net.LastNWindowCollector(
            op_inputs,
            op_outputs,
            num_to_collect=self.num_to_collect,
        )
|