import copy
import warnings

from torch.utils.data.datapipes.datapipe import IterDataPipe

__all__ = ["IterableWrapperIterDataPipe", ]

class IterableWrapperIterDataPipe(IterDataPipe):
    r"""
    Wraps an iterable object to create an IterDataPipe.

    Args:
        iterable: Iterable object to be wrapped into an IterDataPipe
        deepcopy: Option to deepcopy input iterable object for each
            iterator. The copy is made when the first element is read in ``iter()``.

    .. note::
        If ``deepcopy`` is explicitly set to ``False``, users should ensure
        that the data pipeline doesn't contain any in-place operations over
        the iterable instance to prevent data inconsistency across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, iterable, deepcopy=True):
        self.iterable = iterable
        self.deepcopy = deepcopy

    def __iter__(self):
        source_data = self.iterable
        if self.deepcopy:
            try:
                source_data = copy.deepcopy(self.iterable)
            # If the data cannot be deep-copied, in-place operations will
            # affect the original iterable, and a second iteration of this
            # DataPipe will yield the modified items.
            except TypeError:
                warnings.warn(
                    "The input iterable can not be deepcopied, "
                    "please be aware that in-place modifications will affect the source data."
                )
        for data in source_data:
            yield data

    def __len__(self):
        return len(self.iterable)
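

# Illustrative sketch, not part of the original module: a minimal demonstration
# of why the ``deepcopy`` option matters. The sample data below is hypothetical
# and only illustrates the behavior described in the docstring note above.
if __name__ == "__main__":
    source = [[0], [1], [2]]

    # With deepcopy=True (the default), every iteration reads a fresh copy,
    # so in-place edits made while iterating do not leak into later passes.
    dp_copy = IterableWrapperIterDataPipe(source, deepcopy=True)
    for item in dp_copy:
        item.append("modified")
    print(list(dp_copy))  # [[0], [1], [2]] -- the source stays untouched

    # With deepcopy=False, the pipe yields the original objects, so the same
    # in-place edits are visible on the next pass (the pitfall the note warns about).
    dp_ref = IterableWrapperIterDataPipe(source, deepcopy=False)
    for item in dp_ref:
        item.append("modified")
    print(list(dp_ref))  # [[0, 'modified'], [1, 'modified'], [2, 'modified']]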