from typing import Optional

import torch
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.graph_settings import apply_random_seed


# TODO: Caveats
# 1. The caller (either the ReadingService or the DataLoader) must pass in the initial RNG.
# 2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently
#    lack the option to `set_seed`.
def _simple_graph_snapshot_restoration(
    datapipe: IterDataPipe, n_iterations: int, rng: Optional[torch.Generator] = None
) -> None:
r"""
This function will restore a snapshot by fast-forwarding the given DataPipe by ``n_iterations``,
and in the process, fast-forward its parent DataPipes as well at the cost of re-doing every computation.
For instance, applying this function to the final DataPipe of a graph will restore the snapshot
(via fast-forward) every DataPipe within the graph.
After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input
to this function to forward the DataPipe.
A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration
attempts.
Note:
This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding
methods (custom ones if necessary) are recommended.
Args:
datapipe: IterDataPipe to be fast-forwarded
n_iterations: number of iterations to fast-forward
rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.
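
    Example:
        A minimal sketch, assuming ``IterableWrapper`` from ``torch.utils.data.datapipes.iter``
        and a pipeline that has just been re-created (e.g. deserialized)::

            >>> import torch
            >>> from torch.utils.data.datapipes.iter import IterableWrapper
            >>> dp = IterableWrapper(range(10)).shuffle()
            >>> # Suppose a previous run had already yielded 3 samples from this graph
            >>> _simple_graph_snapshot_restoration(dp, n_iterations=3, rng=torch.Generator().manual_seed(0))
            >>> it = iter(dp)  # resumes from the 4th sample of the seeded shuffle order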
"""
    if datapipe._snapshot_state == _SnapshotState.Restored:
        raise RuntimeError(
            "Snapshot restoration cannot be applied. You can only restore a simple snapshot of the graph "
            "if the graph has not already been restored."
        )
    # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to
    # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
    # the first reset will not actually reset.
    datapipe.reset()  # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
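    # (The second reset happens implicitly inside the `iter` call below, which resets any DataPipe
    # that has not already been reset.)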
    apply_random_seed(datapipe, rng)

    remainder = n_iterations
    it = iter(datapipe)  # This always resets the DataPipe if it hasn't already been reset.
    while remainder > 0:
        try:
            next(it)
            remainder -= 1
        except StopIteration:
            raise RuntimeError(
                f"Fast-forwarding {datapipe} by {n_iterations} iterations "
                "exceeds the number of samples available."
            )
    datapipe._fast_forward_iterator = it
    # While the DataPipe has a `_fast_forward_iterator`, `next()` will draw its results from there
    # instead of running the usual iteration logic. This also prevents the DataPipe from resetting
    # in the `iter()` call, so another DataPipe consuming it will not have to start over again.
    datapipe._snapshot_state = _SnapshotState.Restored
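

if __name__ == "__main__":
    # A minimal, self-contained sketch of the restoration flow described above. It assumes only
    # `IterableWrapper` and `torch.Generator`; the RNG passed here stands in for the initial RNG
    # that a ReadingService or DataLoader would normally provide.
    from torch.utils.data.datapipes.iter import IterableWrapper

    pipe = IterableWrapper(range(10)).shuffle()
    _simple_graph_snapshot_restoration(pipe, n_iterations=3, rng=torch.Generator().manual_seed(0))

    # `iter(pipe)` now consumes the fast-forward iterator instead of starting over,
    # so only the 7 samples that were never yielded remain.
    remaining = list(pipe)
    assert len(remaining) == 7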