import random
import torch
from torch.utils.data import Sampler, SequentialSampler
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar
__all__ = [
    "SamplerIterDataPipe",
    "ShufflerIterDataPipe",
]

T_co = TypeVar('T_co', covariant=True)

class SamplerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Generates sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).

    Args:
        datapipe: IterDataPipe to sample from
        sampler: Sampler class to generate sample elements from input DataPipe.
            Default is :class:`SequentialSampler` for IterDataPipe
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 sampler: Type[Sampler] = SequentialSampler,
                 sampler_args: Optional[Tuple] = None,
                 sampler_kwargs: Optional[Dict] = None
                 ) -> None:
        assert isinstance(datapipe, Sized), \
            "Sampler class requires input datapipe implementing `__len__`"
        super().__init__()
        self.datapipe = datapipe
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
        # Keyword argument before *args is intentional;
        # https://github.com/python/mypy/pull/9629 will resolve the mypy error
        self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs)  # type: ignore[misc]

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        # Dataset has been tested as `Sized`
        if isinstance(self.sampler, Sized):
            return len(self.sampler)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
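
# A minimal usage sketch (an assumption, not part of this module): wiring a
# custom sampler class through `SamplerIterDataPipe`. Note that the built-in
# samplers such as `torch.utils.data.RandomSampler` yield *indices* into the
# data source, so this pipe yields whatever the sampler's iterator produces.
#
#     from torchdata.datapipes.iter import IterableWrapper  # assumed available
#
#     dp = IterableWrapper(range(5))  # sized, as the assert above requires
#     sampled = SamplerIterDataPipe(dp, sampler=torch.utils.data.RandomSampler)
#     print(list(sampled))  # a permutation of [0, 1, 2, 3, 4]
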
@functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Shuffles the input DataPipe with a buffer (functional name: ``shuffle``). The buffer
    with ``buffer_size`` is filled with elements from the datapipe first. Then,
    each item is yielded from the buffer by reservoir sampling via the iterator.

    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
    datapipe is not shuffled. In order to fully shuffle all elements from the datapipe,
    ``buffer_size`` is required to be greater than or equal to the size of the datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the method to
    set up the random seed differs based on :attr:`num_workers`.
    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (default to ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """
    datapipe: IterDataPipe[T_co]
    buffer_size: int
    _buffer: List[T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 buffer_size: int = 10000,
                 unbatch_level: int = 0
                 ) -> None:
        super().__init__()
        # TODO: Performance optimization
        #   buffer can be a fixed size and remove expensive `append()` and `len()` operations
        self._buffer: List[T_co] = []
        assert buffer_size > 0, "buffer_size should be larger than 0"
        if unbatch_level == 0:
            self.datapipe = datapipe
        else:
            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[T_co]:
        if not self._enabled:
            # Shuffling disabled: pass elements through unchanged
            for x in self.datapipe:
                yield x
        else:
            for x in self.datapipe:
                if len(self._buffer) == self.buffer_size:
                    # Buffer is full: evict a random element, replace it with
                    # the incoming one (reservoir-style sampling)
                    idx = self._rng.randint(0, len(self._buffer) - 1)
                    val, self._buffer[idx] = self._buffer[idx], x
                    yield val
                else:
                    self._buffer.append(x)
            # Source exhausted: drain the remaining buffer in random order
            while self._buffer:
                idx = self._rng.randint(0, len(self._buffer) - 1)
                yield self._buffer.pop(idx)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))

    def reset(self) -> None:
        self._buffer = []
        if self._enabled:
            if self._seed is None:
                # No user-provided seed: draw a fresh one from PyTorch's RNG
                # so each epoch reshuffles differently
                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self._rng.seed(self._seed)
            # Consume the seed; the next reset draws a new one unless
            # `set_seed` is called again
            self._seed = None

    def __getstate__(self):
        state = (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)

    def __del__(self):
        self._buffer.clear()
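
# A minimal usage sketch (an assumption, not part of this module): seeding the
# shuffle for a reproducible order, and disabling it via `set_shuffle(False)`.
# The seed is consumed on each `reset`, so it must be set again before a
# second iteration to reproduce the same order.
#
#     from torchdata.datapipes.iter import IterableWrapper  # assumed available
#
#     dp = IterableWrapper(range(10))
#     shuffle_dp = dp.shuffle(buffer_size=4).set_seed(42)
#     first = list(shuffle_dp)        # reproducible order for seed 42
#     shuffle_dp.set_seed(42)         # re-arm the seed for the next epoch
#     assert list(shuffle_dp) == first
#
#     plain_dp = dp.shuffle().set_shuffle(False)
#     assert list(plain_dp) == list(range(10))  # pass-through when disabled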