1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
import time
import types
from torch.utils.data import communication, MapDataPipe
# Seconds that ``default_not_available_hook`` passes to ``time.sleep``.
DEFAULT_NON_BLOCKING_SLEEP = 1e-3

# Public API of this module (alphabetical).
__all__ = [
    "DataPipeBehindQueues",
    "EnsureNonBlockingMapDataPipe",
    "NonBlockingMap",
    "NotAvailable",
    "QueueWrapperForMap",
    "default_not_available_hook",
]
def default_not_available_hook():
    """Back off for ``DEFAULT_NON_BLOCKING_SLEEP`` seconds.

    This is the initial ``NonBlockingMap.not_available_hook``, run between
    retries while a non-blocking access raises ``NotAvailable``.
    """
    pause = DEFAULT_NON_BLOCKING_SLEEP
    time.sleep(pause)
class NotAvailable(Exception):
    """Raised by a non-blocking accessor when the requested data is not ready yet."""
class NonBlockingMap(MapDataPipe):
    """
    Map-style DataPipe whose items and length may be temporarily unavailable.

    Subclasses implement ``nonblocking_getitem`` and ``nonblocking_len``;
    either may raise ``NotAvailable`` to mean "not ready, try again". The
    blocking ``__getitem__`` / ``__len__`` keep retrying until a value is
    produced, calling ``NonBlockingMap.not_available_hook`` (when it is not
    ``None``) between attempts.
    """

    # Zero-argument callable run between retries; ``None`` means retry
    # immediately. It is always read from ``NonBlockingMap`` itself, so it is
    # effectively shared by every subclass. Change it through
    # ``register_not_available_hook``.
    not_available_hook = default_not_available_hook

    @staticmethod
    def _run_not_available_hook():
        # Looked up on the class (not on ``self``) so the stored function is
        # called plainly, without being bound as a method.
        hook = NonBlockingMap.not_available_hook
        if hook is not None:
            hook()

    def __getitem__(self, index):
        """Return ``nonblocking_getitem(index)``, retrying while it raises ``NotAvailable``."""
        while True:
            try:
                return self.nonblocking_getitem(index)
            except NotAvailable:
                NonBlockingMap._run_not_available_hook()

    def __len__(self):
        """Return ``nonblocking_len()``, retrying while it raises ``NotAvailable``."""
        # Retry exactly like ``__getitem__``: handling ``NotAvailable`` only
        # once and falling through would return ``None``, which ``len()``
        # rejects with ``TypeError``.
        while True:
            try:
                return self.nonblocking_len()
            except NotAvailable:
                NonBlockingMap._run_not_available_hook()

    def nonblocking_len(self):
        """Return the length, or raise ``NotAvailable``. Must be overridden."""
        raise NotImplementedError(
            f"nonblocking_len is not implemented for {self.__class__}")

    def nonblocking_getitem(self, index):
        """Return the item at ``index``, or raise ``NotAvailable``. Must be overridden."""
        raise NotImplementedError(
            f"nonblocking_getitem is not implemented for {self.__class__}")

    @staticmethod
    def register_not_available_hook(hook_function):
        """Install ``hook_function`` as the retry hook for all ``NonBlockingMap`` instances.

        Pass ``None`` to disable sleeping/hooking between retries.
        """
        NonBlockingMap.not_available_hook = hook_function
def EnsureNonBlockingMapDataPipe(validated_datapipe):
    """
    Make ``validated_datapipe`` expose the non-blocking map API.

    ``NonBlockingMap`` instances are returned untouched. For any other
    ``MapDataPipe``, each of ``nonblocking_len`` / ``nonblocking_getitem`` that
    the object lacks is bound onto the instance and simply delegates to
    ``__len__`` / ``__getitem__``. The object is modified in place and returned.

    Raises:
        Exception: if ``validated_datapipe`` is not a ``MapDataPipe``.
    """
    if not isinstance(validated_datapipe, MapDataPipe):
        raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}')
    if isinstance(validated_datapipe, NonBlockingMap):
        return validated_datapipe

    def nonblocking_len(self):
        return self.__len__()

    def nonblocking_getitem(self, index):
        return self.__getitem__(index)

    # Checked and patched in this order: length first, then item access.
    fallbacks = (
        ('nonblocking_len', nonblocking_len),
        ('nonblocking_getitem', nonblocking_getitem),
    )
    for attr_name, fallback in fallbacks:
        if hasattr(validated_datapipe, attr_name):
            continue
        setattr(validated_datapipe, attr_name, types.MethodType(fallback, validated_datapipe))
    return validated_datapipe
def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
    """
    Serve requests arriving through ``protocol`` with data from ``source_datapipe``.

    This is a generator. It repeatedly takes a request from the server-side
    ``protocol`` and answers it from ``source_datapipe``. It yields ``True``
    whenever it hands control back to the caller: no request was waiting, the
    source was not ready yet (``NotAvailable``), or a request was answered.
    It returns after answering a ``TerminateRequest``.

    Args:
        source_datapipe: a ``MapDataPipe``; passed through
            ``EnsureNonBlockingMapDataPipe`` so it has ``nonblocking_len`` and
            ``nonblocking_getitem``.
        protocol: a ``MapDataPipeQueueProtocolServer`` used to receive requests
            and send responses.
        full_stop: if true, stop serving (the generator returns) right after
            answering an out-of-bound ``GetItemRequest`` (source raised
            ``IndexError``).
        blocking_request_get: forwarded as ``block`` to
            ``protocol.get_new_request``.

    Raises:
        Exception: if ``protocol`` is not a server protocol, or if a request of
            an unrecognized type is received.
    """
    if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer):
        raise Exception('Expecting MapDataPipeQueueProtocolServer, got', protocol)
    source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe)
    while True:
        try:
            # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround
            request = protocol.get_new_request(block=blocking_request_get)
        except communication.protocol.EmptyQueue:
            yield True
            continue

        if isinstance(request, communication.messages.TerminateRequest):
            protocol.response_terminate()
            return
        elif isinstance(request, communication.messages.LenRequest):
            # A NonBlockingMap source (e.g. another queue wrapper) is passed
            # through unwrapped and may raise NotAvailable here; retry and give
            # control back meanwhile, exactly as for GetItemRequest, instead of
            # letting the exception terminate the server.
            while True:
                try:
                    size = source_datapipe.nonblocking_len()
                except NotAvailable:
                    yield True
                    continue
                break
            protocol.response_len(size)
        elif isinstance(request, communication.messages.GetItemRequest):
            while True:
                try:
                    value = source_datapipe.nonblocking_getitem(request.key)
                except NotAvailable:
                    yield True
                    continue
                except IndexError:
                    # Alternatively, we can just allow the underlying DataPipe to throw an exception?
                    protocol.response_index_out_of_bound()
                    if full_stop:
                        return
                    yield True
                    break
                protocol.response_item(request.key, value)
                yield True  # Returns control
                break
        else:
            raise Exception('Unrecognized type of request received', request)
class QueueWrapperForMap(NonBlockingMap):
    """
    Creates map.DataPipe which reads data from the DataLoader.Queue

    Each non-blocking access sends a request through the client ``protocol``
    (only when ``protocol.can_take_request()`` allows it) and then waits for
    the matching response. The wait passes ``response_wait_time`` as the
    ``timeout`` to the protocol's ``get_response_*`` call. An empty queue is
    reported as ``NotAvailable``, so the blocking ``__getitem__`` / ``__len__``
    inherited from ``NonBlockingMap`` can retry.

    NOTE(review): presumably the server end of ``protocol`` is driven by
    ``DataPipeBehindQueues``; confirm against callers.
    """

    def __init__(self, protocol, response_wait_time=0.00001):
        if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient):
            raise Exception('Got', protocol)
        self.protocol = protocol
        self._response_wait_time = response_wait_time
        # Set once a StopIterationResponse arrives; afterwards every access fails.
        self._stop_iteration = False
        self.counter = 0

    def _wait_for(self, fetch):
        # Wait for one response via ``fetch``; an empty queue means "not ready yet".
        try:
            return fetch(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            raise NotAvailable

    def nonblocking_getitem(self, index):
        """Return ``(key, value)`` for ``index``, or raise ``NotAvailable`` / ``IndexError``."""
        if self._stop_iteration:
            raise Exception(
                '`getitem` or `nonblocking_getitem` called after receiving StopIteration')
        # When no new request can be taken, presumably one is already in
        # flight (e.g. after a NotAvailable); just keep waiting for its answer.
        if self.protocol.can_take_request():
            self.protocol.request_item(index)
        reply = self._wait_for(self.protocol.get_response_item)
        if not isinstance(reply, communication.messages.StopIterationResponse):
            return reply.key, reply.value
        self._stop_iteration = True
        raise IndexError(f"Index {index} is out of bound.")

    def nonblocking_len(self):
        """Return the remote length, or raise ``NotAvailable``."""
        if self._stop_iteration:
            raise Exception(
                '`len` or `nonblocking_len` called after receiving StopIteration')
        if self.protocol.can_take_request():
            self.protocol.request_len()
        reply = self._wait_for(self.protocol.get_response_len)
        return reply.len
|