1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
## @package queue_util
# Module caffe2.python.queue_util
from caffe2.python import core, dataio
from caffe2.python.task import TaskGroup
import logging
logger = logging.getLogger(__name__)
class _QueueReader(dataio.Reader):
    """Reader that dequeues records from the queue held by a QueueWrapper.

    The reader's schema is the wrapper's schema. Each dequeue requests
    ``num_dequeue_records`` records (forwarded as ``num_records`` to
    ``dequeue``).
    """

    def __init__(self, wrapper, num_dequeue_records=1):
        # ``wrapper.schema`` is a method, so the check must be made on its
        # return value. Asserting on the bound method itself was always true
        # and let schema-less wrappers through, failing later in read_ex.
        schema = wrapper.schema()
        assert schema is not None, (
            'Queue needs a schema in order to be read from.')
        dataio.Reader.__init__(self, schema)
        self._wrapper = wrapper
        self._num_dequeue_records = num_dequeue_records

    def setup_ex(self, init_net, exit_net):
        """Append a CloseBlobsQueue op for the wrapped queue to exit_net."""
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def read_ex(self, local_init_net, local_finish_net):
        """Build a fresh 'dequeue' net that reads from the wrapped queue.

        Registers this reader with the wrapper through local_init_net.
        local_finish_net is not used.

        Returns:
            A tuple ``([dequeue_net], status_blob, fields)``. ``fields`` are
            the dequeued blobs, one per schema field name.
        """
        self._wrapper._new_reader(local_init_net)
        dequeue_net = core.Net('dequeue')
        field_names = self.schema().field_names()
        fields, status_blob = dequeue(
            dequeue_net,
            self._wrapper.queue(),
            len(field_names),
            field_names=field_names,
            num_records=self._num_dequeue_records)
        return [dequeue_net], status_blob, fields

    def read(self, net):
        """Return ``(nets, fields)`` as produced by read_ex(net, None).

        NOTE(review): ``net`` is only used as read_ex's local_init_net. The
        dequeue ops are placed in a separate fresh net, and the first
        returned element is the list of nets rather than a status blob.
        Confirm this matches what callers of dataio.Reader.read expect.
        """
        nets, _, fields = self.read_ex(net, None)
        return nets, fields
class _QueueWriter(dataio.Writer):
    """Writer that enqueues field blobs into the queue of a QueueWrapper."""

    def __init__(self, wrapper):
        # Wrapper that supplies the queue blob through queue().
        self._wrapper = wrapper

    def setup_ex(self, init_net, exit_net):
        """Append a CloseBlobsQueue op for the wrapped queue to exit_net."""
        queue_blob = self._wrapper.queue()
        exit_net.CloseBlobsQueue([queue_blob], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        """Return a one-element list holding a fresh 'enqueue' net.

        Registers this writer (with its schema) on the wrapper through
        local_init_net first. ``status`` is forwarded to ``enqueue`` as the
        status output. local_finish_net is not used.
        """
        self._wrapper._new_writer(self.schema(), local_init_net)
        net = core.Net('enqueue')
        enqueue(net, self._wrapper.queue(), fields, status)
        return [net]
class QueueWrapper(dataio.Pipe):
    """dataio.Pipe backed by an existing queue blob.

    Args:
        handler: the queue blob; returned unchanged by queue().
        schema: optional schema, forwarded to dataio.Pipe.
        num_dequeue_records: records per dequeue for readers built by
            reader().
    """

    def __init__(self, handler, schema=None, num_dequeue_records=1):
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
        self._queue = handler
        self._num_dequeue_records = num_dequeue_records

    def reader(self):
        """Return a new _QueueReader over this wrapper's queue."""
        records_per_read = self._num_dequeue_records
        return _QueueReader(self, num_dequeue_records=records_per_read)

    def writer(self):
        """Return a new _QueueWriter over this wrapper's queue."""
        return _QueueWriter(self)

    def queue(self):
        """Return the queue blob this wrapper was constructed with."""
        return self._queue
class Queue(QueueWrapper):
    """QueueWrapper whose queue blob is created by this object.

    The queue blob name is reserved at construction. The queue itself is
    only created once setup() adds a CreateBlobsQueue op to an init net.
    """

    def __init__(self, capacity, schema=None, name='queue',
                 num_dequeue_records=1):
        # A throwaway Net is used only to obtain a unique blob name for the
        # queue handle.
        naming_net = core.Net(name)
        handle = naming_net.AddExternalInput(naming_net.NextName('handler'))
        QueueWrapper.__init__(
            self, handle, schema, num_dequeue_records=num_dequeue_records)
        self.capacity = capacity
        self._setup_done = False

    def setup(self, global_init_net):
        """Add a CreateBlobsQueue op for this queue to global_init_net.

        The queue has ``self.capacity`` slots and one blob per schema field.
        The schema must be set (asserted).
        """
        assert self._schema, 'This queue does not have a schema.'
        self._setup_done = True
        names = self._schema.field_names()
        global_init_net.CreateBlobsQueue(
            [],
            [self._queue],
            capacity=self.capacity,
            num_blobs=len(names),
            field_names=names)
def enqueue(net, queue, data_blobs, status=None):
    """Append a SafeEnqueueBlobs op pushing data_blobs into queue.

    Args:
        net: net that receives the op (and any Copy ops).
        queue: the queue blob.
        data_blobs: blobs to enqueue, in order.
        status: status output name. When None, a fresh name from
            ``net.NextName('status')`` is used.

    Returns:
        The last output of SafeEnqueueBlobs, i.e. the status blob.
    """
    status_out = net.NextName('status') if status is None else status
    # Enqueueing moves data into the queue, so handing the same blob over
    # twice would corrupt it. Each repeated blob is replaced by a Copy.
    inputs = []
    for candidate in data_blobs:
        if candidate in inputs:
            logger.warning("Need to copy blob {} to enqueue".format(candidate))
            candidate = net.Copy(candidate)
        inputs.append(candidate)
    outputs = net.SafeEnqueueBlobs([queue] + inputs, inputs + [status_out])
    return outputs[-1]
def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
    """Append a SafeDequeueBlobs op that reads num_blobs blobs from queue.

    Args:
        net: net that receives the op.
        queue: the queue blob.
        num_blobs: number of data outputs. It must equal
            ``len(field_names)`` when field_names is given (asserted).
        status: status output name. When None, a fresh name from
            ``net.NextName('status')`` is used.
        field_names: optional prefixes for the data output names. The
            default is ``net.NextName('data', i)``.
        num_records: forwarded to SafeDequeueBlobs as ``num_records``.

    Returns:
        ``(data_blobs, status_blob)``, where data_blobs is a list of all
        outputs except the last, and status_blob is the last output.
    """
    if field_names is None:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    else:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(field) for field in field_names]
    status_out = net.NextName('status') if status is None else status
    outputs = list(net.SafeDequeueBlobs(
        queue, data_names + [status_out], num_records=num_records))
    return outputs[:-1], outputs[-1]
def close_queue(step, *queues):
    """Return an execution step wrapping ``step`` plus queue closing.

    Builds a net named "close_queue_net" containing one CloseBlobsQueue op
    per queue, and a step that runs it. It then returns an execution step
    whose substeps are ``[step, close_step]``.
    """
    close_net = core.Net("close_queue_net")
    for q in queues:
        close_net.CloseBlobsQueue([q], 0)
    net_label = str(close_net)
    close_step = core.execution_step("%s_step" % net_label, close_net)
    # NOTE(review): "wraper" is misspelled, but it is part of the emitted
    # step name and is kept as-is in case anything matches on it.
    return core.execution_step(
        "%s_wraper_step" % net_label, [step, close_step])
|