1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
|
import time
import types
from torch.utils.data import IterDataPipe, communication
# Number of seconds ``default_not_available_hook`` sleeps between
# unsuccessful polls (passed straight to ``time.sleep``).
DEFAULT_NON_BLOCKING_SLEEP = 1e-3

# Public API of this module.
__all__ = [
    "DataPipeBehindQueues",
    "EnsureNonBlockingDataPipe",
    "InvalidStateResetRequired",
    "NonBlocking",
    "NotAvailable",
    "QueueWrapper",
    "default_not_available_hook",
]
def default_not_available_hook():
    """Back off briefly before the next poll attempt.

    This is the hook initially installed as ``NonBlocking.not_available_hook``.
    It sleeps for ``DEFAULT_NON_BLOCKING_SLEEP`` seconds and returns ``None``.
    """
    delay = DEFAULT_NON_BLOCKING_SLEEP
    time.sleep(delay)
class NotAvailable(Exception):
    """Raised by ``nonblocking_next`` when no value is ready yet.

    This is a "try again later" signal rather than a failure: the blocking
    ``NonBlocking.__next__`` catches it and keeps polling.
    """
class InvalidStateResetRequired(Exception):
    """
    Raised by a DataPipe that expects a reset request before it can continue,
    for example a RouterDataPipe waiting for all workers to request a reset.
    """
class NonBlocking(IterDataPipe):
    """Base class for DataPipes that can be polled without blocking.

    Subclasses implement ``nonblocking_next`` and ``reset_iterator``.
    ``nonblocking_next`` returns the next value, raises ``NotAvailable`` when
    no value is ready yet, or raises ``StopIteration`` at the end of the stream.
    The regular iterator protocol is layered on top: ``__iter__`` resets the
    iterator, and ``__next__`` polls until a value (or an exception other than
    ``NotAvailable``) is produced.
    """

    # Shared by all NonBlocking subclasses. When not None, this zero-argument
    # callable is invoked after every poll that reported NotAvailable.
    # Replace it with register_not_available_hook.
    not_available_hook = default_not_available_hook

    def __iter__(self):
        """Reset the underlying iterator and return ``self``."""
        self.reset_iterator()
        return self

    def __next__(self):
        """Block until ``nonblocking_next`` produces a value.

        ``StopIteration``, and any exception other than ``NotAvailable``,
        raised by ``nonblocking_next`` propagates to the caller unchanged.
        """
        while True:
            try:
                return self.nonblocking_next()
            except NotAvailable:
                # Read the hook from NonBlocking itself, not through `self`.
                # register_not_available_hook stores it there, and a plain
                # function reached through an instance would be bound as a method.
                if NonBlocking.not_available_hook is not None:
                    NonBlocking.not_available_hook()

    def nonblocking_next(self):
        """Return the next value or raise ``NotAvailable``; subclasses override."""
        raise NotImplementedError(
            "nonblocking_next is not implemented for %s" % self.__class__)

    def reset_iterator(self):
        """Restart iteration from the beginning; subclasses override."""
        raise NotImplementedError(
            "reset_iterator is not implemented for %s" % self.__class__)

    @staticmethod
    def register_not_available_hook(hook_function):
        """Install ``hook_function`` as the hook for every NonBlocking DataPipe.

        Pass ``None`` to disable the hook (``__next__`` then polls without pausing).
        """
        NonBlocking.not_available_hook = hook_function
def EnsureNonBlockingDataPipe(validated_datapipe):
    """Give ``validated_datapipe`` the non-blocking interface, in place.

    ``NonBlocking`` instances are returned as-is. For any other
    ``IterDataPipe``, each of the following is attached only if it is missing:

    * ``_as_iterator``: cached iterator, initially ``None``;
    * ``nonblocking_next()``: creates the cached iterator with ``iter(self)``
      on first use and returns ``next()`` of it, so the underlying
      iterator's ``StopIteration`` propagates;
    * ``reset_iterator()``: drops the cached iterator, so the following
      ``nonblocking_next`` call starts a new one.

    Returns:
        The same object that was passed in.

    Raises:
        Exception: if ``validated_datapipe`` is not an ``IterDataPipe``.
    """
    dp = validated_datapipe
    if not isinstance(dp, IterDataPipe):
        raise Exception('Not Iterable DataPipe ' +
                        str(dp.__class__))
    if isinstance(dp, NonBlocking):
        return dp

    if not hasattr(dp, '_as_iterator'):
        dp._as_iterator = None  # type: ignore[attr-defined]

    def nonblocking_next(self):
        # Start iterating the wrapped pipe lazily, on first request.
        if self._as_iterator is None:
            self._as_iterator = iter(self)
        return next(self._as_iterator)

    def reset_iterator(self):
        self._as_iterator = None

    # Attach the fallbacks in a fixed order, keeping any existing implementations.
    fallbacks = (
        ('nonblocking_next', nonblocking_next),
        ('reset_iterator', reset_iterator),
    )
    for attr_name, impl in fallbacks:
        if not hasattr(dp, attr_name):
            setattr(dp, attr_name, types.MethodType(impl, dp))
    return dp
def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
    """
    Serve values of ``source_datapipe`` through ``protocol``, one request at a time.

    This is a generator. Every ``yield True`` hands control back to whoever
    drives it: while no request is pending, while the source has no value
    ready, and after each answered request. Keep calling ``next()`` on it to
    make progress.

    Request handling:

    * ``ResetIteratorRequest``: calls ``source_datapipe.reset_iterator()``,
      then ``protocol.response_reset_iterator()``.
    * ``TerminateRequest``: calls ``protocol.response_terminate()`` and ends
      the generator.
    * ``GetNextRequest``: polls ``source_datapipe.nonblocking_next()``.
      A value is sent with ``protocol.response_next``. Exhaustion is
      answered with ``protocol.response_stop_iteration()``.
      ``InvalidStateResetRequired`` is answered with
      ``protocol.response_invalid_state()``.

    Args:
        source_datapipe: the ``IterDataPipe`` to serve. It is passed through
            ``EnsureNonBlockingDataPipe``.
        protocol: an ``IterDataPipeQueueProtocolServer``.
        full_stop: if True, the generator ends after answering a
            stop-iteration or invalid-state response. If False, it keeps
            serving requests.
        blocking_request_get: forwarded as ``block`` to
            ``protocol.get_new_request``.

    Raises:
        Exception: if ``protocol`` is not an ``IterDataPipeQueueProtocolServer``,
            or if a request of an unrecognized type is received.
    """
    if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer):
        raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol)
    source_datapipe = EnsureNonBlockingDataPipe(source_datapipe)
    forever = True
    while forever:
        try:
            # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround
            request = protocol.get_new_request(block=blocking_request_get)
        except communication.protocol.EmptyQueue:
            # No request yet: return control to the driver and poll again later.
            yield True
            continue

        if isinstance(request, communication.messages.ResetIteratorRequest):
            source_datapipe.reset_iterator()
            protocol.response_reset_iterator()

        elif isinstance(request, communication.messages.TerminateRequest):
            forever = False
            protocol.response_terminate()

        elif isinstance(request, communication.messages.GetNextRequest):
            # Poll the source until this one request has been answered.
            # `forever` is always True on entry, so every exit is a `break`.
            # Clearing `forever` before breaking also ends the outer loop.
            while forever:
                try:
                    value = source_datapipe.nonblocking_next()
                except NotAvailable:
                    # The source has nothing yet; yield and retry the same request.
                    yield True
                    continue
                except StopIteration:
                    protocol.response_stop_iteration()
                    if full_stop:
                        forever = False
                    else:
                        yield True
                    break
                except InvalidStateResetRequired:
                    protocol.response_invalid_state()
                    if full_stop:
                        forever = False
                    else:
                        yield True
                    break
                protocol.response_next(value)
                yield True  # Returns control
                break
        else:
            raise Exception('Unrecognized type of request received', request)
class QueueWrapper(NonBlocking):
    """
    NonBlocking DataPipe that gets its values by sending requests over an
    ``IterDataPipeQueueProtocolClient`` and reading the responses.
    """

    def __init__(self, protocol, response_wait_time=0.00001):
        """
        Args:
            protocol: an ``IterDataPipeQueueProtocolClient``.
            response_wait_time: timeout passed to ``protocol.get_response_next``
                on every ``nonblocking_next`` poll.

        Raises:
            Exception: if ``protocol`` is not an ``IterDataPipeQueueProtocolClient``.
        """
        if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient):
            raise Exception('Expecting IterDataPipeQueueProtocolClient, got', protocol)
        self.protocol = protocol
        # Zeroed here and in reset_iterator; not otherwise used by this class.
        self.counter = 0
        # Set once a StopIterationResponse arrives; cleared by reset_iterator.
        self._stop_iteration = False
        self._response_wait_time = response_wait_time

    def reset_iterator(self):
        """Ask the other side to reset and wait until the reset is acknowledged.

        Between unsuccessful polls for the acknowledgement, the shared
        ``NonBlocking.not_available_hook`` (if any) is invoked.
        """
        self._stop_iteration = False
        self.counter = 0
        self.protocol.request_reset_iterator()
        while True:
            try:
                self.protocol.get_response_reset_iterator()
                break
            except communication.protocol.EmptyQueue:
                if NonBlocking.not_available_hook is not None:
                    NonBlocking.not_available_hook()

    def nonblocking_next(self):
        """Poll once for the next value.

        A new next-request is sent only when ``protocol.can_take_request()``
        allows it. The method then waits up to ``response_wait_time`` for a
        response.

        Returns:
            The ``value`` attribute of the received response.

        Raises:
            NotAvailable: no response arrived within the timeout, or the
                response was an ``InvalidStateResponse``.
            StopIteration: a ``StopIterationResponse`` was received.
            Exception: called again after ``StopIteration`` without an
                intervening ``reset_iterator``.
        """
        if self._stop_iteration:
            raise Exception(
                '`next` or `nonblocking_next` called after receiving StopIteration')
        if self.protocol.can_take_request():
            self.protocol.request_next()
        try:
            response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            # A timeout is the normal "not ready yet" outcome of a poll.
            # Suppress the queue exception so tracebacks don't look like a
            # failure occurred while handling another error.
            raise NotAvailable from None
        if isinstance(response, communication.messages.StopIterationResponse):
            self._stop_iteration = True
            raise StopIteration
        if isinstance(response, communication.messages.InvalidStateResponse):
            raise NotAvailable
        return response.value
|