• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import random
3from typing import Iterator, List, Optional, TypeVar
4
5import torch
6from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
7
8
# Public API of this module.
__all__ = ["ShufflerIterDataPipe"]
10
11
_T_co = TypeVar("_T_co", covariant=True)


# The decorator below is left commented out; the functional name ``shuffle`` is
# registered explicitly through ``MapDataPipe.register_datapipe_as_function``.
# @functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).

    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: MapDataPipe being shuffled
        indices: the indices (keys) of the MapDataPipe to draw from. If not provided,
            we assume it uses 0-based indexing over ``len(datapipe)`` elements. The
            values are copied, so later changes to the caller's list have no effect.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> shuffle_dp = dp.shuffle().set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
        >>> list(shuffle_dp)
        [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
        >>> # Reset seed for Shuffler
        >>> shuffle_dp = shuffle_dp.set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]

    Note:
        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
        ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be insensitive
        to the order of data for the sake of random reads, but ``IterDataPipe`` depends on the order
        of data during data-processing.
    """

    datapipe: MapDataPipe[_T_co]
    indices: List
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random
    _shuffled_indices: List

    def __init__(
        self,
        datapipe: MapDataPipe[_T_co],
        *,
        indices: Optional[List] = None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # Own a private copy of the indices: accepts any iterable and decouples
        # this datapipe from later mutation of the caller's list.
        self.indices = list(range(len(datapipe))) if indices is None else list(indices)
        self._enabled = True
        self._seed = None
        self._rng = random.Random()
        # ``__iter__`` pops from this list, so it must never alias ``self.indices``.
        self._shuffled_indices = list(self.indices)

    def set_shuffle(self, shuffle: bool = True) -> "ShufflerIterDataPipe[_T_co]":
        """Enable or disable shuffling and return ``self`` for chaining."""
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int) -> "ShufflerIterDataPipe[_T_co]":
        """Store ``seed`` for the next :meth:`reset` and return ``self`` for chaining.

        The seed is one-shot: :meth:`reset` clears it after seeding the RNG.
        """
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[_T_co]:
        """Yield ``datapipe[idx]`` for every index, shuffled unless shuffling is disabled."""
        if not self._enabled:
            for idx in self.indices:
                yield self.datapipe[idx]
        else:
            # Consume the permutation in place: the list serialized by
            # ``__getstate__`` then only holds the indices not yet yielded.
            while self._shuffled_indices:
                idx = self._shuffled_indices.pop()
                yield self.datapipe[idx]

    def reset(self) -> None:
        """Re-seed the private RNG and draw a fresh permutation of ``self.indices``.

        When shuffling is enabled and no seed was supplied through
        :meth:`set_seed`, a seed is drawn from torch's default random generator.
        """
        if self._enabled and self._seed is None:
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        # With shuffling disabled and no explicit seed this is ``seed(None)``.
        self._rng.seed(self._seed)
        # The seed is single-use; the next reset draws or expects a new one.
        self._seed = None
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

    def __len__(self) -> int:
        """Return the number of elements yielded per pass (one per index)."""
        # ``__iter__`` walks ``self.indices``, which can differ in size from the
        # wrapped datapipe when the caller supplied explicit indices.
        return len(self.indices)

    def __getstate__(self):
        """Return the picklable state tuple, routed through ``getstate_hook`` if one is set."""
        state = (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            # ``random.Random`` state is stored explicitly and rebuilt in ``__setstate__``.
            self._rng.getstate(),
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore the fields written by :meth:`__getstate__`, including the RNG state."""
        (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            rng_state,
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)
128
# Register the class on ``MapDataPipe`` under the functional name ``shuffle``
# (the ``dp.shuffle()`` form used in the class docstring).
MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)
130