import time

from typing import Any, List

import torch.utils.data.backward_compatibility
import torch.utils.data.graph_settings
from torch.utils.data import DataLoader, IterDataPipe, communication
from torch.utils.data.datapipes.iter import IterableWrapper

__all__ = [
    "DataLoader2",
]

class _ThreadingDataLoader2:
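    """Experimental thread-based loader over an ``IterDataPipe``.

    One worker thread is spawned per ``num_workers``; each thread owns its own copy of
    the DataPipe graph (sharded via ``apply_sharding``) and is driven through a pair of
    request/response queues. ``__iter__`` round-robins over the per-thread pipes.
    """
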
    def __init__(self, datapipe, num_workers=0, collate_fn=None):
        self.threads = []
        self.datapipes = []
        self.collate_fn = collate_fn
        for worker_id in range(num_workers):
            # Each worker thread gets its own copy of the datapipe plus a request/response queue pair.
            (thread, req_queue, res_queue, thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
            torch.utils.data.graph_settings.apply_sharding(thread_local_datapipe, num_workers, worker_id)
            thread.start()
            self.threads.append((thread, req_queue, res_queue))  # These queues are independent
            local_datapipe = communication.iter.QueueWrapper(
                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
            self.datapipes.append(local_datapipe)

    def __iter__(self):
        # Poll the per-worker pipes round-robin, dropping each one as it is exhausted.
        exclude_datapipes: List[Any] = []
        while len(exclude_datapipes) < len(self.datapipes):
            not_available = False
            for dp in self.datapipes:
                if dp not in exclude_datapipes:
                    try:
                        value = dp.nonblocking_next()
                        yield value
                    except StopIteration:
                        exclude_datapipes.append(dp)
                    except communication.iter.NotAvailable:
                        not_available = True
            if not_available:
                # Nothing was ready from at least one worker on this pass; back off briefly.
                time.sleep(0.001)

    def __del__(self):
        self._cleanup_all_threads()

    def _cleanup_all_threads(self):
        def clean_me(thread, req_queue, res_queue):
            # Ask the worker thread to terminate, wait for its acknowledgement, then join it.
            req_queue.put(communication.messages.TerminateRequest())
            _ = res_queue.get()
            thread.join()

        for thread, req_queue, res_queue in self.threads:
            clean_me(thread, req_queue, res_queue)


class DataLoader2:
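    """Prototype loader that dispatches on the dataset type.

    ``IterDataPipe`` datasets are wrapped into a DataPipe-aware loader using either
    multiprocessing (``parallelism_mode='mp'``, backed by the classic ``DataLoader``)
    or threads (``parallelism_mode='thread'``, backed by ``_ThreadingDataLoader2``).
    With ``batch_outside_worker=True``, batching and collation are applied to the
    merged worker output instead of inside each worker. Any other dataset falls
    through to the classic ``DataLoader`` unchanged. Note that ``__new__`` returns
    the constructed loader directly, so the result is never a ``DataLoader2`` instance.
    """
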
    def __new__(cls,
                dataset,
                batch_size=1,
                shuffle=None,
                sampler=None,
                batch_sampler=None,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                *,
                prefetch_factor=2,
                persistent_workers=False,
                batch_outside_worker=False,
                parallelism_mode='mp'):
        if isinstance(dataset, IterDataPipe):
            data_loader: Any = None
            if batch_sampler is not None:
                raise Exception(
                    'batch_sampler is not yet supported by DataPipes')
            if sampler is not None:
                raise Exception(
                    'sampler is not yet supported by DataPipes')

            datapipe = dataset
            datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle)  # type: ignore[assignment]

            if batch_outside_worker and pin_memory:
                raise Exception(
                    'pin_memory is not yet compatible with batch_outside_worker')

            # Multiprocessing (or in-process when num_workers == 0): reuse the classic
            # DataLoader on the (optionally batched) DataPipe.
            if parallelism_mode == 'mp' or num_workers == 0:
                if not batch_outside_worker:
                    if batch_size is not None:
                        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
                if collate_fn is None:
                    collate_fn = torch.utils.data._utils.collate.default_collate

                # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does nothing
                # for Iterable, but required to set Pipes correctly.
                data_loader = DataLoader(datapipe,
                                         batch_size=None,  # Replaced by .batch DataPipe
                                         shuffle=shuffle,
                                         sampler=None,
                                         batch_sampler=None,
                                         num_workers=num_workers,
                                         collate_fn=collate_fn,
                                         pin_memory=pin_memory,
                                         drop_last=False,  # Replaced by .batch DataPipe
                                         timeout=timeout,
                                         worker_init_fn=worker_init_fn,
                                         prefetch_factor=prefetch_factor,
                                         persistent_workers=persistent_workers)
            elif parallelism_mode == 'thread':
                if collate_fn is not None and not batch_outside_worker:
                    datapipe = datapipe.map(collate_fn)
                if pin_memory:
                    raise Exception(
                        'pin_memory is not yet supported by DataPipes with Threading')
                if worker_init_fn is not None:
                    raise Exception(
                        'worker_init_fn is not yet supported by DataPipes with Threading')
                data_loader = _ThreadingDataLoader2(datapipe,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn)
            else:
                raise Exception('Unsupported parallelism mode', parallelism_mode)

            if not batch_outside_worker:
                return data_loader
            else:
                # Batch and collate downstream of the worker loader, in the calling thread.
                if collate_fn is None:
                    collate_fn = torch.utils.data._utils.collate.default_collate
                datapipe = IterableWrapper(data_loader).batch(
                    batch_size, drop_last=drop_last).map(collate_fn)
                return datapipe
        else:
            if parallelism_mode == 'thread':
                raise Exception(
                    'thread parallelism mode is not supported for old DataSets')

            # Non-DataPipe datasets fall back to the classic DataLoader unchanged.
            return DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=sampler,
                              batch_sampler=batch_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn,
                              pin_memory=pin_memory,
                              drop_last=drop_last,
                              timeout=timeout,
                              worker_init_fn=worker_init_fn,
                              prefetch_factor=prefetch_factor,
                              persistent_workers=persistent_workers)
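

# Example (illustrative sketch, assuming a DataPipe-enabled torch build):
#
#     dp = IterableWrapper(range(10))
#     dl = DataLoader2(dp, batch_size=4, num_workers=0)
#     for batch in dl:
#         print(batch)  # collated batches, e.g. tensor([0, 1, 2, 3])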