import pickle
import threading

import torch
from torch.utils.data import IterDataPipe, communication, MapDataPipe

try:
    import dill
    # XXX: By default, dill writes the Pickler dispatch table to inject its
    # own logic there. This globally affects the behavior of the standard library
    # pickler for any user who transitively depends on this module!
    # Undo this extension to avoid altering the behavior of the pickler globally.
    dill.extend(use_dill=False)
    HAS_DILL = True
except ImportError:
    HAS_DILL = False

__all__ = [
    "DataPipeToQueuesLoop",
    "SpawnProcessForDataPipeline",
    "SpawnThreadForDataPipeline",
]


def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
    r"""
    Runs the server side of the queue protocol for ``source_datapipe``: requests are read
    from ``req_queue`` and responses are written to ``res_queue`` until the loop terminates.
    """
    if isinstance(source_datapipe, IterDataPipe):
        pipe_type = communication.iter
        protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
    elif isinstance(source_datapipe, MapDataPipe):
        pipe_type = communication.map  # type: ignore[misc]
        protocol_type = communication.protocol.MapDataPipeQueueProtocolServer  # type: ignore[assignment]
    else:
        raise Exception('Only supports IterDataPipe or MapDataPipe, got', source_datapipe)

    # Limit the number of intra-op threads used by this worker.
    torch.set_num_threads(1)
    for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue),
                                            blocking_request_get=True):
        pass

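
# A minimal client-side sketch (not part of the original module): DataPipeToQueuesLoop is
# the *server* half of the queue protocol, so a consumer on the other end of the queues is
# assumed to use the matching client classes. The names below
# (IterDataPipeQueueProtocolClient, QueueWrapper) come from the surrounding communication
# prototype and are assumptions, not guaranteed API.
def _example_iter_client(req_queue, res_queue):
    # Requests such as "reset iterator" / "get next item" are written to req_queue;
    # the worker running DataPipeToQueuesLoop answers them on res_queue.
    protocol_client = communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
    # QueueWrapper exposes the remote IterDataPipe as a local IterDataPipe.
    return communication.iter.QueueWrapper(protocol_client)
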

def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
    r"""
    Given a multiprocessing context and a DataPipe, creates request/response queues and a
    Process (not yet started) with DataPipeToQueuesLoop as target, and returns the
    process, req_queue, and res_queue.
    """
    req_queue = multiprocessing_ctx.Queue()
    res_queue = multiprocessing_ctx.Queue()
    process = multiprocessing_ctx.Process(
        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue))
    return process, req_queue, res_queue

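
# Illustrative sketch (assumption, not part of the original module): end-to-end lifecycle
# of a worker process serving an IterDataPipe. IterableWrapper is the standard in-memory
# IterDataPipe; the client side reuses the hypothetical _example_iter_client above.
def _example_process_usage():
    import multiprocessing as mp

    from torch.utils.data.datapipes.iter import IterableWrapper

    ctx = mp.get_context('spawn')
    process, req_queue, res_queue = SpawnProcessForDataPipeline(ctx, IterableWrapper(range(10)))
    process.start()
    try:
        # Iterating the client wrapper pulls items produced in the child process.
        return list(_example_iter_client(req_queue, res_queue))
    finally:
        process.terminate()
        process.join()
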

def SpawnThreadForDataPipeline(datapipe):
    r"""
    Given a DataPipe, creates a copy of the DataPipe and a new Thread (not yet started)
    with DataPipeToQueuesLoop as target, and returns the thread, req_queue, res_queue,
    and the thread-local copy of the DataPipe.
    """
    req_queue = communication.queue.ThreadingQueue()
    res_queue = communication.queue.ThreadingQueue()

    # Copy the DataPipe so the thread serves its own instance and the caller's
    # original DataPipe is left untouched.
    try:
        new_datapipe = pickle.loads(pickle.dumps(datapipe))
    except Exception as pe:
        if HAS_DILL:
            try:
                new_datapipe = dill.loads(dill.dumps(datapipe))
            except Exception as de:
                raise Exception('Unable to dill DataPipe to make thread local copy', de)
        else:
            raise Exception('Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe)

    thread = threading.Thread(target=DataPipeToQueuesLoop, args=(
        new_datapipe, req_queue, res_queue), daemon=True)
    return thread, req_queue, res_queue, new_datapipe

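
# Illustrative sketch (assumption, not part of the original module): thread-based
# counterpart of the process example above. Note that the helper returns the copy of the
# DataPipe that the thread actually serves.
def _example_thread_usage():
    from torch.utils.data.datapipes.iter import IterableWrapper

    thread, req_queue, res_queue, thread_local_dp = SpawnThreadForDataPipeline(
        IterableWrapper(range(5)))
    thread.start()
    # Same client-side protocol as in the process case; only the primitives differ
    # (communication.queue.ThreadingQueue + threading.Thread instead of
    # multiprocessing Queue + Process).
    remote_dp = _example_iter_client(req_queue, res_queue)
    return list(remote_dp), thread_local_dp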