• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from torch.utils.data.datapipes._hook_iterator import _SnapshotState
3from torch.utils.data.datapipes.datapipe import IterDataPipe
4from torch.utils.data.graph_settings import apply_random_seed
5
6
7# TODO: Caveats
8#   1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG
9#   2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently
10#      lack the option to `set_seed`.
11def _simple_graph_snapshot_restoration(
12    datapipe: IterDataPipe, n_iterations: int, rng=None
13) -> None:
14    r"""
15    Fast-forward the given DataPipe and its parents by ``n_iterations``, re-doing computations to restore a snapshot.
16
17    For instance, applying this function to the final DataPipe of a graph will restore the snapshot
18    (via fast-forward) every DataPipe within the graph.
19
20    After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input
21    to this function to forward the DataPipe.
22
23    A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration
24    attempts.
25
26    Note:
27        This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding
28        methods (custom ones if necessary) are recommended.
29
30    Args:
31        datapipe: IterDataPipe to be fast-forwarded
32        n_iterations: number of iterations to fast-forward
33        rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
34            should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.
35    """
36    if datapipe._snapshot_state == _SnapshotState.Restored:
37        raise RuntimeError(
38            "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph "
39            "if your graph has not been restored."
40        )
41
42    # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to
43    # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
44    # the first reset will not actually reset.
45    datapipe.reset()  # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
46    apply_random_seed(datapipe, rng)
47
48    remainder = n_iterations
49    it = iter(datapipe)  # This always reset the DataPipe if it hasn't already.
50    while remainder > 0:
51        try:
52            next(it)
53            remainder -= 1
54        except StopIteration as e:
55            raise RuntimeError(
56                f"Fast-forward {datapipe} by {n_iterations} iterations "
57                "exceeds the number of samples available."
58            ) from e
59    datapipe._fast_forward_iterator = it
60    # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.
61
62    # This will prevent the DataPipe from resetting in the `iter()` call
63    # If another DataPipe is consuming it, it won't have to start over again
64    datapipe._snapshot_state = _SnapshotState.Restored
65