# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import collections
import copy
import queue

import torch
from torch._utils import ExceptionWrapper

from . import MP_STATUS_CHECK_INTERVAL


def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
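    """Pin-memory worker loop.

    Repeatedly takes ``(idx, data)`` items from ``in_queue``, copies the
    tensors in ``data`` into pinned (page-locked) memory, and puts the result
    on ``out_queue`` until ``done_event`` is set.
    """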
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.multiprocessing._set_thread_name("pt_data_pin")

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
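        """Fetch one item, pin its tensors, and forward it to ``out_queue``."""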
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                data = pin_memory(data, device)
            except Exception:
                data = ExceptionWrapper(
                    where=f"in pin memory thread for device {device_id}"
                )
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()


def pin_memory(data, device=None):
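    """Recursively copy the tensors in ``data`` into pinned (page-locked) memory.

    Containers (mappings, namedtuples, sequences) are traversed and rebuilt with
    their tensor elements pinned; non-tensor leaves are returned unchanged, and
    objects exposing their own ``pin_memory`` method are pinned via that hook.
    """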
    if isinstance(data, torch.Tensor):
        return data.pin_memory(device)
    elif isinstance(data, (str, bytes)):
        return data
    elif isinstance(data, collections.abc.Mapping):
        try:
            if isinstance(data, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(data)
                clone.update(
                    {k: pin_memory(sample, device) for k, sample in data.items()}
                )
                return clone
            else:
                return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return {k: pin_memory(sample, device) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(pin_memory(sample, device) for sample in data))
    elif isinstance(data, tuple):
        return [
            pin_memory(sample, device) for sample in data
        ]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence):
        try:
            if isinstance(data, collections.abc.MutableSequence):
                # The sequence type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new sequence.
                # Create a clone and update it if the sequence type is mutable.
                clone = copy.copy(data)  # type: ignore[arg-type]
                for i, item in enumerate(data):
                    clone[i] = pin_memory(item, device)
                return clone
            return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
        except TypeError:
            # The sequence type may not support `copy()` / `__setitem__(index, item)`
            # or `__init__(iterable)` (e.g., `range`).
            return [pin_memory(sample, device) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data
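

# Example (illustrative sketch only, not part of this module's API): pin_memory()
# preserves the container structure of a batch while moving every tensor inside
# it into page-locked memory. The names below are hypothetical and assume a
# build where pinned memory is available (e.g. CUDA):
#
#     batch = {"inputs": torch.randn(4, 3), "labels": [torch.tensor(0), torch.tensor(1)]}
#     pinned = pin_memory(batch)
#     assert pinned["inputs"].is_pinned()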