File: protocol.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (205 lines) | stat: -rw-r--r-- 8,258 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import queue

from torch.utils.data import communication


class Protocol(object):
    """
        Common base for protocol endpoints: stores the request/response queue pair.

        Subclasses put requests on `request_queue` and replies on `response_queue`.
    """
    # Only these two attributes live on Protocol itself (no per-instance __dict__).
    __slots__ = ('request_queue', 'response_queue')

    def __init__(self, request_queue, response_queue):
        self.request_queue, self.response_queue = request_queue, response_queue


class ProtocolClient(Protocol):
    """
        ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue.

        At most one request may be outstanding at a time: `request_sent` records it
        and `request_served` clears it once its response has been read.
    """
    # Class-level default so the attribute resolves even on instances whose
    # __init__ has not run.
    _req_sent = None

    def __init__(self, request_queue, response_queue):
        # Let the base class store the queues instead of duplicating that logic.
        super().__init__(request_queue, response_queue)
        self._req_sent = None

    def can_take_request(self):
        """Return True when no request is currently outstanding."""
        return self._req_sent is None

    def waiting_for_response(self):
        """Return True while a sent request has not yet been served."""
        return self._req_sent is not None

    def request_sent(self, request=True):
        """Record `request` as the single outstanding request.

        NOTE: passing `None` leaves the client in the "no outstanding request"
        state, because outstanding-ness is tracked as `_req_sent is not None`.

        Raises:
            Exception: if another request is still outstanding.
        """
        if not self.can_take_request():
            raise Exception('Protocol only supports one request in the Queue')
        self._req_sent = request

    def request_served(self, result=None):
        """Clear the outstanding request after its response `result` arrived.

        Raises:
            Exception: if no request was outstanding; `result` is attached to the
                exception args to aid debugging.
        """
        if not self.waiting_for_response():
            raise Exception(
                'Expected no pending requests, but something got served', result)
        self._req_sent = None


class ProtocolServer(Protocol):
    """
        ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.

        At most one request may be received-but-unanswered at a time; answering it
        (e.g. via `response_terminate`) clears that state.
    """
    # Class-level default so the attribute resolves even on instances whose
    # __init__ has not run.
    _req_received = None

    def __init__(self, request_queue, response_queue):
        # Let the base class store the queues instead of duplicating that logic.
        super().__init__(request_queue, response_queue)
        self._req_received = None

    def have_pending_request(self):
        """Return True while a received request has not been answered."""
        return self._req_received is not None

    def get_new_request(self, block=False):
        """Read the next request from `request_queue` and mark it pending.

        Args:
            block: forwarded to `request_queue.get`.

        Returns:
            The request object read from the queue.

        Raises:
            Exception: if a previously received request is still unanswered.
            EmptyQueue: if `request_queue.get` raises (e.g. nothing is available).
        """
        if self.have_pending_request():
            raise Exception(
                'Trying to get next request, while having one unserved')
        try:
            request = self.request_queue.get(block=block)
        except Exception as e:
            # NOTE(review): intentionally broad - the queue implementations used
            # with this protocol may report "empty" with a generic Exception rather
            # than queue.Empty. Narrow once every queue type is confirmed.
            raise EmptyQueue('queue is empty') from e
        # TODO: Validate supported requests
        self._req_received = request
        return request

    def response_terminate(self):
        """Answer the pending TerminateRequest with a TerminateResponse.

        Raises:
            Exception: if no request is pending, or the pending request is not a
                TerminateRequest.
        """
        if not self.have_pending_request():
            raise Exception("Attempting to reply without pending request")
        if not isinstance(self._req_received, communication.messages.TerminateRequest):
            raise Exception(
                "Replying with terminate status to other type of message")
        self.response_queue.put(communication.messages.TerminateResponse())
        self._req_received = None


class MapDataPipeQueueProtocolServer(ProtocolServer):
    """Server endpoint for map-style requests: replies with items or lengths."""

    def _send_response(self, response):
        """Put `response` on the response queue and clear the pending request.

        Raises:
            Exception: if there is no pending request to answer.
        """
        if not self.have_pending_request():
            raise Exception("Attempting to reply without pending request")
        self.response_queue.put(response)
        self._req_received = None

    def response_item(self, key, value):
        """Reply to the pending request with a GetItemResponse(key, value)."""
        self._send_response(communication.messages.GetItemResponse(key, value))

    def response_len(self, size):
        """Reply to the pending request with a LenResponse(size)."""
        self._send_response(communication.messages.LenResponse(size))

    def response_index_out_of_bound(self):
        """Reply to the pending request with a StopIterationResponse."""
        # An out-of-range index is reported with the same message type that the
        # iterable protocol uses for exhaustion.
        self._send_response(communication.messages.StopIterationResponse())

class MapDataPipeQueueProtocolClient(ProtocolClient):
    """Client endpoint for map-style requests: asks for lengths and items."""

    def request_len(self):
        """Send a LenRequest and mark it outstanding.

        Raises:
            Exception: if a previous request is still awaiting its response.
        """
        if not self.can_take_request():
            raise Exception('Can not request len while we are still waiting response for previous request')
        request = communication.messages.LenRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def request_item(self, index):
        """Send a GetItemRequest for `index` and mark it outstanding.

        Raises:
            Exception: if a previous request is still awaiting its response.
        """
        if not self.can_take_request():
            raise Exception('Can not request item while we are still waiting response for previous request')
        request = communication.messages.GetItemRequest(index)
        self.request_queue.put(request)
        self.request_sent(request)

    def _get_response(self, block, timeout):
        """Read one response and mark the outstanding request as served.

        Raises:
            Exception: if no request is outstanding.
            EmptyQueue: if the queue reports nothing available / timed out.
        """
        if not self.waiting_for_response():
            raise Exception('Can not expect any response without submitted request')
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except (queue.Empty, TimeoutError) as e:
            # Standard queue.Queue / multiprocessing.Queue raise queue.Empty (not
            # TimeoutError) when nothing is available, so catching only
            # TimeoutError let that error escape instead of becoming EmptyQueue.
            raise EmptyQueue('queue is empty') from e
        self.request_served(response)
        return response

    def get_response_len(self, block=False, timeout=None):
        """Return the LenResponse for the outstanding length request.

        Raises:
            Exception: if no request is outstanding or the response is not a
                LenResponse (the request is marked served either way).
            EmptyQueue: if no response is available.
        """
        response = self._get_response(block, timeout)
        if not isinstance(response, communication.messages.LenResponse):
            raise Exception('Invalid response received')
        return response

    def get_response_item(self, block=False, timeout=None):
        """Return the response for the outstanding item request.

        The response type is not validated here. NOTE(review): presumably because
        the server may answer with either GetItemResponse or StopIterationResponse
        (see MapDataPipeQueueProtocolServer) - confirm before adding a check.

        Raises:
            Exception: if no request is outstanding.
            EmptyQueue: if no response is available.
        """
        return self._get_response(block, timeout)


class EmptyQueue(Exception):
    """Raised by protocol endpoints when reading from their queue yields nothing (e.g. it is empty)."""


class IterDataPipeQueueProtocolServer(ProtocolServer):
    """Server endpoint for iterable-style requests: reset, next item, stop, errors."""

    def _send_response(self, response):
        """Put `response` on the response queue and clear the pending request.

        Raises:
            Exception: if there is no pending request to answer.
        """
        if not self.have_pending_request():
            raise Exception("Attempting to reply without pending request")
        self.response_queue.put(response)
        self._req_received = None

    def response_reset_iterator(self):
        """Answer the pending ResetIteratorRequest with a ResetIteratorResponse.

        Raises:
            Exception: if no request is pending, or it is not a ResetIteratorRequest.
        """
        # Check for a pending request first so the "no request" case gets its own
        # message rather than the type-mismatch one.
        if not self.have_pending_request():
            raise Exception("Attempting to reply without pending request")
        if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
            raise Exception(
                "Replying with reset status to other type of message")
        self._send_response(communication.messages.ResetIteratorResponse())

    def response_next(self, value):
        """Reply to the pending request with GetNextResponse(value)."""
        self._send_response(communication.messages.GetNextResponse(value))

    def response_stop_iteration(self):
        """Reply to the pending request with a StopIterationResponse."""
        self._send_response(communication.messages.StopIterationResponse())

    def response_invalid_state(self):
        """Reply to the pending request with an InvalidStateResponse."""
        self._send_response(communication.messages.InvalidStateResponse())


class IterDataPipeQueueProtocolClient(ProtocolClient):
    """Client endpoint for iterable-style requests: reset and next item."""

    def request_reset_iterator(self):
        """Send a ResetIteratorRequest and mark it outstanding.

        Raises:
            Exception: if a previous request is still awaiting its response.
        """
        if not self.can_take_request():
            raise Exception('Can not reset while we are still waiting response for previous request')
        request = communication.messages.ResetIteratorRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def request_next(self):
        """Send a GetNextRequest and mark it outstanding.

        Raises:
            Exception: if a previous request is still awaiting its response.
        """
        if not self.can_take_request():
            raise Exception('Can not request next item while we are still waiting response for previous request')
        request = communication.messages.GetNextRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def _get_response(self, block, timeout):
        """Read one response and mark the outstanding request as served.

        The guard runs before touching the queue, so a stray call never consumes
        (and loses) a message that belongs to a later request.

        Raises:
            Exception: if no request is outstanding.
            EmptyQueue: if `response_queue.get` raises.
        """
        if not self.waiting_for_response():
            raise Exception(
                'Can not expect any response without submitted request')
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except Exception as e:
            # NOTE(review): intentionally broad - the queue implementations used
            # with this protocol may report "empty" with a generic Exception rather
            # than queue.Empty. Narrow once every queue type is confirmed.
            raise EmptyQueue('queue is empty') from e
        self.request_served(response)
        return response

    def get_response_reset_iterator(self, block=False, timeout=None):
        """Consume the ResetIteratorResponse for the outstanding reset request.

        Args:
            block: forwarded to `response_queue.get`.
            timeout: forwarded to `response_queue.get` (new, defaults to None).

        Raises:
            Exception: if no request is outstanding, or the response is not a
                ResetIteratorResponse (the request is marked served either way).
            EmptyQueue: if no response is available.
        """
        response = self._get_response(block, timeout)
        if not isinstance(response, communication.messages.ResetIteratorResponse):
            raise Exception('Invalid response received')

    def get_response_next(self, block=False, timeout=None):
        """Return the response for the outstanding next-item request.

        Raises:
            Exception: if no request is outstanding.
            EmptyQueue: if no response is available.
        """
        # TODO(VitalyFedyunin): Add possible response types validation here
        return self._get_response(block, timeout)