# mypy: allow-untyped-defs
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.graph_settings import apply_random_seed


# TODO: Caveats
#   1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG
#   2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently
#      lack the option to `set_seed`.
def _simple_graph_snapshot_restoration(
    datapipe: IterDataPipe, n_iterations: int, rng=None
) -> None:
    r"""
    Fast-forward the given DataPipe and its parents by ``n_iterations``, re-doing computations to restore a snapshot.

    For instance, applying this function to the final DataPipe of a graph will restore the snapshot
    (via fast-forward) every DataPipe within the graph.

    After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input
    to this function to forward the DataPipe.

    A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration
    attempts.

    Note:
        This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding
        methods (custom ones if necessary) are recommended.

    Args:
        datapipe: IterDataPipe to be fast-forwarded
        n_iterations: number of iterations to fast-forward
        rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
            should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.

    Raises:
        RuntimeError: if ``datapipe`` is already in the ``Restored`` snapshot state, or if the DataPipe
            is exhausted before ``n_iterations`` samples have been consumed.
    """
    if datapipe._snapshot_state == _SnapshotState.Restored:
        raise RuntimeError(
            "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph "
            "if your graph has not been restored."
        )

    # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to
    # simple fast-forwarding. The `Restored` state has already been rejected above, so a single
    # `reset` call is issued here.
    # NOTE(review): an earlier revision of this comment said `reset` had to be called twice to get
    # past a `Restored` state; only one call exists in the code - confirm that one call is sufficient.
    datapipe.reset()
    apply_random_seed(datapipe, rng)

    remainder = n_iterations
    it = iter(datapipe)  # This always resets the DataPipe if it hasn't already.
    while remainder > 0:
        # Only `next` can raise `StopIteration`, so it is the only statement guarded by the `try`.
        try:
            next(it)
        except StopIteration as e:
            raise RuntimeError(
                f"Fast-forward {datapipe} by {n_iterations} iterations "
                "exceeds the number of samples available."
            ) from e
        else:
            remainder -= 1

    # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.
    datapipe._fast_forward_iterator = it

    # This will prevent the DataPipe from resetting in the `iter()` call.
    # If another DataPipe is consuming it, it won't have to start over again.
    datapipe._snapshot_state = _SnapshotState.Restored