import inspect
import functools
from enum import Enum

import torch.autograd


class _SnapshotState(Enum):
    r"""
    These are the snapshotting-related states that IterDataPipes can be in.

    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
    `Iterating` - can restore, will reset if you create a new iterator
    """
    NotStarted = 0
    Restored = 1
    Iterating = 2
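
# Paraphrase of the docstring above as a state diagram (illustrative comment only; the actual
# transitions are driven by snapshot restoration and by the iterator wrappers defined below):
#
#     NotStarted --(create iterator, reset)--> Iterating
#     NotStarted --(restore snapshot)--> Restored --(create iterator, no reset)--> Iterating
#     Iterating  --(create another iterator, reset)--> Iterating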


def _simplify_obj_name(obj) -> str:
    """
    Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.
    """
    if inspect.isfunction(obj):
        return obj.__name__
    else:
        return repr(obj)


def _generate_input_args_string(obj):
    """
    Generate a string for the input arguments of an object.
    """
    signature = inspect.signature(obj.__class__)
    input_param_names = set(signature.parameters.keys())
    result = []
    for name, member in inspect.getmembers(obj):
        if name in input_param_names:
            result.append((name, _simplify_obj_name(member)))
    return ', '.join([f'{name}={value}' for name, value in result])
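
# For illustration (hypothetical names): for a datapipe constructed as `Mapper(source_dp, fn=my_fn)`,
# the members whose names match the constructor parameters (e.g. `datapipe`, `fn`) are rendered, so
# `_generate_iterdatapipe_msg` below would produce something like
# "MapperIterDataPipe(datapipe=..., fn=my_fn)".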


def _generate_iterdatapipe_msg(datapipe):
    return f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"


def _gen_invalid_iterdatapipe_msg(datapipe):
    return ("This iterator has been invalidated because another iterator has been created "
            f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
            "This may be caused by multiple references to the same IterDataPipe. We recommend "
            "using `.fork()` if that is necessary.")


_feedback_msg = ("\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
                 "to comment on this issue: https://github.com/pytorch/data/issues/45.")


def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
    r"""
    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raise an exception.

    In the case of ChildDataPipe, the ID also gets compared to the one stored in `main_datapipe`.
    """
    if next_method_exists:
        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
        # iterator (`0`). Otherwise, it means there are multiple iterators.
        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's `__next__` method"
            raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg)
    elif hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_check_valid_iterator_id"):
            if not datapipe._check_valid_iterator_id(iterator_id):
                raise RuntimeError("This iterator has been invalidated, because a new iterator has been created "
                                   f"from one of the ChildDataPipes of "
                                   f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + _feedback_msg)
        else:
            raise RuntimeError("ChildDataPipe must have method `_check_valid_iterator_id`.")
    elif datapipe._valid_iterator_id != iterator_id:
        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
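
# Illustrative failure mode that this check guards against: if `it1 = iter(dp)` has been advanced
# and `it2 = iter(dp)` is then created, `it1` holds a stale `iterator_id`, so its next `next()`
# call reaches the final `elif` above and raises the "invalidated" RuntimeError.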


def _set_datapipe_valid_iterator_id(datapipe):
    r"""
    Given a DataPipe, update its valid iterator ID and reset the DataPipe.
    """
    if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
            datapipe._set_main_datapipe_valid_iterator_id()  # reset() is called within this method when appropriate
        else:
            raise RuntimeError("ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`.")
    else:
        if datapipe._valid_iterator_id is None:
            datapipe._valid_iterator_id = 0
        else:
            datapipe._valid_iterator_id += 1
        datapipe.reset()
    return datapipe._valid_iterator_id
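
# Illustrative walk-through (non-child case): the first `iter(dp)` sets `_valid_iterator_id` to 0,
# and every subsequent `iter(dp)` increments it, invalidating any outstanding iterator; in both
# cases `dp.reset()` runs so iteration starts from the beginning again.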


def hook_iterator(namespace, profile_name):
    r"""
    Hook that is applied to the `__iter__` methods of all classes with metaclass `_DataPipeMeta`. This is
    done for the purposes of profiling and of checking whether an iterator is still valid.
    """

    def profiler_record_fn_context():
        return torch.autograd.profiler.record_function(profile_name)

    class IteratorDecorator:
        r"""
        Wrap the iterator and modify its `__next__` method. This decorator is applied to
        DataPipes whose `__iter__` method is NOT a generator function. Such `__iter__`
        methods commonly return `self`, but not necessarily.
        """
        def __init__(self, iterator, source_dp, iterator_id, has_next_method):
            self.iterator = iterator
            self.source_dp = source_dp
            self.iterator_id = iterator_id
            self._profiler_enabled = torch.autograd._profiler_enabled()
            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
            self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method

        def __iter__(self):
            return self

        def _get_next(self):
            r"""
            Return the next value, with logic related to iterator validity, the profiler, and
            incrementing the count of samples yielded.
            """
            _check_iterator_valid(self.source_dp, self.iterator_id)
            result = next(self.iterator)
            if not self.self_and_has_next_method:
                self.source_dp._number_of_samples_yielded += 1
            return result

        def __next__(self):
            # TODO: Add try-except to in-place reduce traceback from the Exception
            # See: https://github.com/pytorch/data/issues/284
            if self._profiler_enabled:
                with profiler_record_fn_context():
                    return self._get_next()
            else:  # Decided against using `contextlib.nullcontext` for performance reasons
                return self._get_next()

        def __getattr__(self, name):
            return getattr(self.iterator, name)
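
    # Illustrative note: when `__iter__` returns a plain iterator (not a generator), `wrap_iter`
    # below hands that iterator to `IteratorDecorator`, so every `next()` call first validates the
    # iterator ID and, unless `__next__` already counts it, increments `_number_of_samples_yielded`.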
    func = namespace['__iter__']

    # ``__iter__`` of IterDataPipe is a generator function
    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def wrap_generator(*args, **kwargs):
            gen = func(*args, **kwargs)
            datapipe = args[0]
            if datapipe._fast_forward_iterator:
                it = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                datapipe._snapshot_state = _SnapshotState.Iterating
                while True:
                    try:
                        yield next(it)
                    except StopIteration:
                        return
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            _profiler_enabled = torch.autograd._profiler_enabled()
            try:
                if _profiler_enabled:
                    with profiler_record_fn_context():
                        response = gen.send(None)
                else:
                    response = gen.send(None)
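
                # Illustrative note on the generator protocol used here: `request = yield response`
                # below round-trips with the caller; a plain `next()` resumes the wrapper with
                # `request = None`, while `gen.send(x)` resumes it with `request = x`, which is then
                # forwarded into the wrapped generator.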
                while True:
                    datapipe._number_of_samples_yielded += 1
                    request = yield response
                    # Pass through here every time `__next__` is called
                    if _profiler_enabled:
                        with profiler_record_fn_context():
                            _check_iterator_valid(datapipe, iterator_id)
                            response = gen.send(request)
                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
                        _check_iterator_valid(datapipe, iterator_id)
                        response = gen.send(request)
            except StopIteration:
                return
            except Exception as e:
                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
                # Part of https://github.com/pytorch/data/issues/284
                datapipe = args[0]
                msg = "thrown by __iter__ of"
                single_iterator_msg = "single iterator per IterDataPipe constraint"
                if hasattr(e.args, '__len__'):
                    full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
                    if len(e.args) == 0 or not isinstance(e.args[0], str):  # If an exception message doesn't exist
                        e.args = (f'\nThis exception is {full_msg}',)
                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
                        e.args = (e.args[0] + f'\nThis exception is {full_msg}',) + e.args[1:]
                raise

        namespace['__iter__'] = wrap_generator
    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
        # And ``__iter__`` may or may not return `self`
        if '__next__' in namespace:  # If `__next__` exists, put a wrapper around it
            next_func = namespace['__next__']

            @functools.wraps(next_func)
            def wrap_next(*args, **kwargs):
                if torch.autograd._profiler_enabled():
                    with profiler_record_fn_context():
                        result = next_func(*args, **kwargs)
                else:
                    result = next_func(*args, **kwargs)
                datapipe = args[0]
                datapipe._number_of_samples_yielded += 1
                return result

            namespace['__next__'] = wrap_next

            # Note that if `__next__` and `__iter__` do something completely unrelated, these wrappers
            # may misbehave, but then the user would be violating the iterator protocol. Potential issues:
            # 1. The valid iterator ID may not be updated or checked properly
            # 2. The number of samples yielded will be miscounted

        # Regardless of whether `__next__` exists, `__iter__` needs a wrapper to track the number of valid iterators
        @functools.wraps(func)
        def wrap_iter(*args, **kwargs):
            iter_ret = func(*args, **kwargs)
            datapipe = args[0]
            datapipe._snapshot_state = _SnapshotState.Iterating
            if datapipe._fast_forward_iterator:
                iter_ret = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                return iter_ret
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)

        namespace['__iter__'] = wrap_iter
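

# A minimal, hedged demonstration of the single-iterator-per-IterDataPipe constraint that
# `hook_iterator` enforces. `_ToyRange` is a hypothetical example class, not part of this module,
# and the demo is guarded under `__main__` so that importing this module stays side-effect free.
if __name__ == "__main__":
    from torch.utils.data import IterDataPipe

    class _ToyRange(IterDataPipe):
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            yield from range(self.n)

    dp = _ToyRange(3)
    it1 = iter(dp)
    print(next(it1))   # 0; `it1` is currently the valid iterator
    it2 = iter(dp)
    print(next(it2))   # 0 again; `dp` was reset and `it2` is now the valid iterator
    try:
        next(it1)      # raises RuntimeError via `_check_iterator_valid`
    except RuntimeError as err:
        print(err)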