# mypy: allow-untyped-defs
import copy
import warnings

from torch.utils.data.datapipes.datapipe import MapDataPipe


__all__ = ["SequenceWrapperMapDataPipe"]


class SequenceWrapperMapDataPipe(MapDataPipe):
    r"""
    Wraps a sequence object into a MapDataPipe.

    Items are looked up with ``sequence[index]`` and the length is reported
    with ``len(sequence)``, so any object supporting those two operations
    (for example a ``list``, ``range`` or ``dict``) can be wrapped.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object. If the deep copy
            raises ``TypeError``, a warning is emitted and the original
            object is wrapped directly instead.

    Attributes:
        sequence: The wrapped object; a deep copy of the input unless
            ``deepcopy`` is False or the copy failed.

    .. note::
      If ``deepcopy`` is set to False explicitly, users should ensure
      that data pipeline doesn't contain any in-place operations over
      the sequence instance, in order to prevent data inconsistency
      across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
        >>> dp['a']
        100
    """

    def __init__(self, sequence, deepcopy=True):
        if deepcopy:
            try:
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # Best-effort fallback: keep working with the caller's object.
                # stacklevel=2 attributes the warning to the code that
                # constructed the datapipe rather than to this module.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data",
                    stacklevel=2,
                )
                self.sequence = sequence
        else:
            self.sequence = sequence

    def __getitem__(self, index):
        """Return ``self.sequence[index]``."""
        return self.sequence[index]

    def __len__(self):
        """Return ``len(self.sequence)``."""
        return len(self.sequence)