# mypy: allow-untyped-defs
from typing import Sized, Tuple, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe

__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]

_T_co = TypeVar("_T_co", covariant=True)
@functional_datapipe("concat")
class ConcaterMapDataPipe(MapDataPipe):
r"""
Concatenate multiple Map DataPipes (functional name: ``concat``).
The new index of is the cumulative sum of source DataPipes.
For example, if there are 2 source DataPipes both with length 5,
index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to
elements of the first DataPipe, and 5 to 9 would refer to elements
of the second DataPipe.
Args:
datapipes: Map DataPipes being concatenated
Example:
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.map import SequenceWrapper
>>> dp1 = SequenceWrapper(range(3))
>>> dp2 = SequenceWrapper(range(3))
>>> concat_dp = dp1.concat(dp2)
>>> list(concat_dp)
[0, 1, 2, 0, 1, 2]
"""
datapipes: Tuple[MapDataPipe]
def __init__(self, *datapipes: MapDataPipe):
if len(datapipes) == 0:
raise ValueError("Expected at least one DataPipe, but got nothing")
if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
raise TypeError("Expected all inputs to be `MapDataPipe`")
if not all(isinstance(dp, Sized) for dp in datapipes):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes # type: ignore[assignment]
def __getitem__(self, index) -> _T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
if index - offset < len(dp):
return dp[index - offset]
else:
offset += len(dp)
raise IndexError(f"Index {index} is out of range.")
def __len__(self) -> int:
return sum(len(dp) for dp in self.datapipes)
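

# Minimal sketch of how ``concat`` resolves a global index: ``__getitem__``
# walks the source pipes, subtracting a running offset until the index falls
# inside one of them. The lengths 3 and 4 below are hypothetical values chosen
# for illustration, not taken from this module.
#
#     combined = dp_a.concat(dp_b)   # len(dp_a) == 3, len(dp_b) == 4
#     combined[1]  ->  dp_a[1]       # offset 0; 1 - 0 < 3
#     combined[3]  ->  dp_b[0]       # offset 3; 3 - 3 < 4
#     combined[6]  ->  dp_b[3]       # last valid index; combined[7] raises IndexError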
@functional_datapipe("zip")
class ZipperMapDataPipe(MapDataPipe[Tuple[_T_co, ...]]):
r"""
Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted.
Args:
*datapipes: Map DataPipes being aggregated
Example:
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.map import SequenceWrapper
>>> dp1 = SequenceWrapper(range(3))
>>> dp2 = SequenceWrapper(range(10, 13))
>>> zip_dp = dp1.zip(dp2)
>>> list(zip_dp)
[(0, 10), (1, 11), (2, 12)]
"""
datapipes: Tuple[MapDataPipe[_T_co], ...]
def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
if len(datapipes) == 0:
raise ValueError("Expected at least one DataPipe, but got nothing")
if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
raise TypeError("Expected all inputs to be `MapDataPipe`")
if not all(isinstance(dp, Sized) for dp in datapipes):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes
def __getitem__(self, index) -> Tuple[_T_co, ...]:
res = []
for dp in self.datapipes:
try:
res.append(dp[index])
except IndexError as e:
raise IndexError(
f"Index {index} is out of range for one of the input MapDataPipes {dp}."
) from e
return tuple(res)
def __len__(self) -> int:
return min(len(dp) for dp in self.datapipes)
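

# Combined usage sketch (comments only; run it in a separate script or REPL,
# since importing torch already registers these functional DataPipes). It
# assumes ``SequenceWrapper`` from ``torchdata.datapipes.map``, as in the
# docstring examples above.
#
#     >>> from torchdata.datapipes.map import SequenceWrapper
#     >>> dp1 = SequenceWrapper(range(3))        # [0, 1, 2]
#     >>> dp2 = SequenceWrapper(range(10, 13))   # [10, 11, 12]
#     >>> dp1.concat(dp2)[4]                     # global index 4 -> dp2[1]
#     11
#     >>> dp1.zip(dp2)[2]                        # element-wise pairing at index 2
#     (2, 12)
#     >>> len(dp1.zip(dp2))                      # bounded by the shortest input
#     3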