• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import copy
3import warnings
4
5from torch.utils.data.datapipes.datapipe import MapDataPipe
6
7
8__all__ = ["SequenceWrapperMapDataPipe"]
9
10
class SequenceWrapperMapDataPipe(MapDataPipe):
    r"""
    Wraps a sequence object into a MapDataPipe.

    Indexing and ``len()`` are delegated directly to the wrapped object, so
    any object supporting ``__getitem__`` and ``__len__`` (for example a
    ``list``, ``range`` or ``dict``) can be wrapped.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
      If ``deepcopy`` is True but ``copy.deepcopy`` raises ``TypeError`` for
      the input, a warning is issued and the input object is stored as-is
      (without copying).

    .. note::
      If ``deepcopy`` is set to False explicitly, users should ensure
      that the data pipeline doesn't contain any in-place operations over
      the sequence instance, in order to prevent data inconsistency
      across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
        >>> dp['a']
        100
    """

    def __init__(self, sequence, deepcopy=True):
        if deepcopy:
            try:
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # Best-effort copy: fall back to sharing the caller's object
                # instead of failing construction. stacklevel=2 attributes the
                # warning to the code that constructed this DataPipe rather
                # than to this module.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data",
                    stacklevel=2,
                )
                self.sequence = sequence
        else:
            self.sequence = sequence

    def __getitem__(self, index):
        """Return ``self.sequence[index]``; lookup errors propagate unchanged."""
        return self.sequence[index]

    def __len__(self):
        """Return ``len(self.sequence)``."""
        return len(self.sequence)
54