# mypy: allow-untyped-defs
import random
from typing import Iterator, List, Optional, TypeVar

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe


__all__ = ["ShufflerIterDataPipe"]


_T_co = TypeVar("_T_co", covariant=True)


# Not decorated with ``@functional_datapipe("shuffle")``: this class is an
# ``IterDataPipe``, yet its functional form ``.shuffle()`` is exposed on
# ``MapDataPipe``, so it is registered explicitly on ``MapDataPipe`` at the
# bottom of this module.
class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).

    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: MapDataPipe being shuffled
        indices: the keys of the MapDataPipe to iterate over. If not provided, we assume it
            uses 0-based indexing, i.e. ``range(len(datapipe))``. Any iterable of keys is
            accepted; it is copied into a private list, so later changes made by the caller
            do not affect this DataPipe (and vice versa).

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> shuffle_dp = dp.shuffle().set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
        >>> list(shuffle_dp)
        [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
        >>> # Reset seed for Shuffler
        >>> shuffle_dp = shuffle_dp.set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]

    Note:
        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
        ``IterDataPipe`` rather than a ``MapDataPipe``: a ``MapDataPipe`` should be insensitive to
        the order of its data for the sake of random reads, whereas an ``IterDataPipe`` depends on
        the order of data during data processing.
    """

    datapipe: MapDataPipe[_T_co]
    # Keys of ``datapipe`` to yield, in unshuffled order (a private list owned by this DataPipe).
    indices: List
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random
    # Pending shuffled order prepared by ``reset``; consumed from the end by ``__iter__``.
    _shuffled_indices: List

    def __init__(
        self,
        datapipe: MapDataPipe[_T_co],
        *,
        indices: Optional[List] = None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # Always keep a private list: protects against external mutation, and lets any
        # iterable of keys (e.g. ``dict.keys()``) work with ``random.Random.sample``,
        # which requires a sequence.
        self.indices = list(range(len(datapipe))) if indices is None else list(indices)
        self._enabled = True
        self._seed = None
        self._rng = random.Random()
        # A separate list, NOT an alias of ``self.indices``: ``__iter__`` consumes this
        # buffer with ``pop()``, so aliasing would empty ``self.indices`` if iteration
        # started without a preceding ``reset()``.
        self._shuffled_indices = list(self.indices)

    def set_shuffle(self, shuffle=True):
        """Enable or disable shuffling; returns ``self`` for chaining."""
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        """Use ``seed`` for the next :meth:`reset` only; returns ``self`` for chaining."""
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[_T_co]:
        """Yield ``datapipe[idx]`` for each index.

        With shuffling disabled, ``self.indices`` is walked in order. With shuffling
        enabled, the order prepared by the most recent :meth:`reset` is consumed from
        the end via ``pop()``, so that buffer is empty after one full pass. (``reset`` is
        presumably invoked by the ``IterDataPipe`` iteration hooks before each new
        iteration — not visible in this module.)
        """
        if not self._enabled:
            for idx in self.indices:
                yield self.datapipe[idx]
        else:
            while self._shuffled_indices:
                idx = self._shuffled_indices.pop()
                yield self.datapipe[idx]

    def reset(self) -> None:
        """Seed the private RNG and prepare a fresh shuffled order of ``self.indices``.

        A seed given via :meth:`set_seed` is used for exactly one reset and then cleared,
        so later resets draw a new seed. When shuffling is enabled and no seed is set,
        the seed is drawn from torch's default generator (presumably so that torch's
        global seeding governs the shuffle order). When shuffling is disabled and no seed
        is set, ``random.Random.seed(None)`` seeds from system entropy.
        """
        if self._enabled and self._seed is None:
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._rng.seed(self._seed)
        # One-shot seed: subsequent resets must not reuse it.
        self._seed = None
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

    def __len__(self) -> int:
        # ``__iter__`` yields exactly one element per index, so the length is the number
        # of indices. This equals ``len(self.datapipe)`` in the default 0-based case and
        # stays correct when only a subset of keys is passed as ``indices``.
        return len(self.indices)

    def __getstate__(self):
        """Capture inputs, flags, RNG state and iteration progress for pickling.

        The tuple may be transformed by ``IterDataPipe.getstate_hook`` when one is set.
        """
        state = (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            self._rng.getstate(),
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore the fields saved by ``__getstate__``, rebuilding the RNG from its state."""
        (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            rng_state,
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)


MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)