File: queue_util.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (136 lines) | stat: -rw-r--r-- 4,459 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
## @package queue_util
# Module caffe2.python.queue_util





from caffe2.python import core, dataio
from caffe2.python.task import TaskGroup

import logging


# Module logger; enqueue() uses it to warn when a repeated blob must be copied.
logger = logging.getLogger(__name__)


class _QueueReader(dataio.Reader):
    """Reader that dequeues records from the queue held by a QueueWrapper.

    Args:
        wrapper: the QueueWrapper whose queue blob is read from. Its schema
            must already be set when the reader is created.
        num_dequeue_records: forwarded as `num_records` to the dequeue op
            built in read_ex().
    """

    def __init__(self, wrapper, num_dequeue_records=1):
        # `wrapper.schema` is a method. Comparing the bound method itself
        # against None is always true, so a missing schema was never caught
        # here. Call it and check the returned schema instead.
        schema = wrapper.schema()
        assert schema is not None, (
            'Queue needs a schema in order to be read from.')
        dataio.Reader.__init__(self, schema)
        self._wrapper = wrapper
        self._num_dequeue_records = num_dequeue_records

    def setup_ex(self, init_net, exit_net):
        """Append a CloseBlobsQueue op for the wrapped queue to `exit_net`."""
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def read_ex(self, local_init_net, local_finish_net):
        """Build a standalone net that dequeues from the wrapped queue.

        The reader is first registered with the wrapper via `_new_reader`,
        using `local_init_net`. `local_finish_net` is unused.

        Returns:
            A tuple ([dequeue_net], status_blob, fields), where `fields`
            holds one output blob per schema field name.
        """
        self._wrapper._new_reader(local_init_net)
        dequeue_net = core.Net('dequeue')
        field_names = self.schema().field_names()
        fields, status_blob = dequeue(
            dequeue_net,
            self._wrapper.queue(),
            len(field_names),
            field_names=field_names,
            num_records=self._num_dequeue_records)
        return [dequeue_net], status_blob, fields

    def read(self, net):
        """Return ([dequeue_net], fields), as produced by read_ex(net, None).

        NOTE(review): the first element is a list holding a separate dequeue
        net, not a status blob, and no ops are added to `net` itself. This
        appears to differ from a (should_stop, fields) reader contract.
        Kept unchanged for existing callers; confirm before changing.
        """
        nets, _, fields = self.read_ex(net, None)
        return nets, fields


class _QueueWriter(dataio.Writer):
    """Writer that enqueues records into the queue held by a QueueWrapper."""

    def __init__(self, wrapper):
        # The wrapper supplies the queue blob and is told about each writer.
        self._wrapper = wrapper

    def setup_ex(self, init_net, exit_net):
        """Append a CloseBlobsQueue op for the wrapped queue to `exit_net`."""
        target_queue = self._wrapper.queue()
        exit_net.CloseBlobsQueue([target_queue], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        """Return a one-element list holding a new net that enqueues `fields`.

        The writer is registered with the wrapper first (passing this
        writer's schema and `local_init_net`). `local_finish_net` is unused.
        """
        wrapper = self._wrapper
        wrapper._new_writer(self.schema(), local_init_net)
        net_for_enqueue = core.Net('enqueue')
        enqueue(net_for_enqueue, wrapper.queue(), fields, status)
        return [net_for_enqueue]


class QueueWrapper(dataio.Pipe):
    """Pipe backed by an existing queue blob (`handler`).

    Readers and writers obtained from this object dequeue from and enqueue
    into that blob.
    """

    def __init__(self, handler, schema=None, num_dequeue_records=1):
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
        self._queue = handler
        # Handed to every reader created by reader().
        self._num_dequeue_records = num_dequeue_records

    def reader(self):
        """Return a new _QueueReader over this wrapper."""
        dequeue_count = self._num_dequeue_records
        return _QueueReader(self, num_dequeue_records=dequeue_count)

    def writer(self):
        """Return a new _QueueWriter over this wrapper."""
        return _QueueWriter(self)

    def queue(self):
        """Return the wrapped queue blob."""
        return self._queue


class Queue(QueueWrapper):
    """QueueWrapper that names its own queue blob and creates it in setup().

    Args:
        capacity: passed as `capacity` to CreateBlobsQueue in setup().
        schema: record schema; setup() requires it.
        name: name of the scratch net used to generate the blob name.
        num_dequeue_records: forwarded to readers via QueueWrapper.
    """

    def __init__(self, capacity, schema=None, name='queue',
                 num_dequeue_records=1):
        # A scratch net is created only to find a unique blob name for the
        # queue.
        scratch_net = core.Net(name)
        handler_blob = scratch_net.AddExternalInput(
            scratch_net.NextName('handler'))
        QueueWrapper.__init__(
            self, handler_blob, schema,
            num_dequeue_records=num_dequeue_records)
        self.capacity = capacity
        self._setup_done = False

    def setup(self, global_init_net):
        """Append a CreateBlobsQueue op for this queue to `global_init_net`."""
        assert self._schema, 'This queue does not have a schema.'
        self._setup_done = True
        names = self._schema.field_names()
        global_init_net.CreateBlobsQueue(
            [],
            [self._queue],
            capacity=self.capacity,
            num_blobs=len(names),
            field_names=names)


def enqueue(net, queue, data_blobs, status=None):
    """Add a SafeEnqueueBlobs op that pushes `data_blobs` into `queue`.

    Enqueueing moves the data into the queue. A blob listed more than once
    would therefore be enqueued twice and corrupt the data, so every repeat
    is replaced by a copy made with `net.Copy`, and a warning is logged.

    Args:
        net: net the ops are added to; also used to generate a status name.
        queue: the queue blob to enqueue into.
        data_blobs: blobs (or blob names) to enqueue.
        status: name for the status output; generated when None.

    Returns:
        The last output of the SafeEnqueueBlobs op, i.e. the status blob.
    """
    if status is None:
        status = net.NextName('status')
    queue_blobs = []
    # Names already scheduled for enqueueing, including inserted copies, so
    # membership is O(1) instead of a linear scan of `queue_blobs`.
    # Duplicates are detected by str(blob); for blob references this is
    # assumed to be the blob name. NOTE(review): confirm for BlobReference.
    seen_names = set()
    for blob in data_blobs:
        if str(blob) in seen_names:
            # Lazy %-formatting: the message is only built if emitted.
            logger.warning("Need to copy blob %s to enqueue", blob)
            blob = net.Copy(blob)
        queue_blobs.append(blob)
        seen_names.add(str(blob))
    results = net.SafeEnqueueBlobs(
        [queue] + queue_blobs, queue_blobs + [status])
    return results[-1]


def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
    """Add a SafeDequeueBlobs op on `queue` to `net`.

    Args:
        net: net the op is added to; also used to generate output names.
        queue: the queue blob to dequeue from.
        num_blobs: number of data outputs; must equal len(field_names) when
            field_names is given.
        status: name for the status output; generated when None.
        field_names: optional prefixes for the generated data output names.
        num_records: forwarded to SafeDequeueBlobs as `num_records`.

    Returns:
        (data_blobs, status_blob): a list of the data outputs, followed by
        the status output.
    """
    if field_names is None:
        data_names = [net.NextName('data', idx) for idx in range(num_blobs)]
    else:
        assert len(field_names) == num_blobs
        data_names = list(map(net.NextName, field_names))
    if status is None:
        status = net.NextName('status')
    *data_blobs, status_blob = net.SafeDequeueBlobs(
        queue, data_names + [status], num_records=num_records)
    return data_blobs, status_blob


def close_queue(step, *queues):
    """Bundle `step` with a step that closes the given queues.

    A net named "close_queue_net" receives one CloseBlobsQueue op per queue.
    The returned execution step has two substeps: `step`, then a step
    running that net.
    """
    closer = core.Net("close_queue_net")
    for q in queues:
        closer.CloseBlobsQueue([q], 0)
    net_label = str(closer)
    closing_step = core.execution_step("%s_step" % net_label, closer)
    return core.execution_step(
        "%s_wraper_step" % net_label,
        [step, closing_step])