# mypy: allow-untyped-defs
from typing import Sized, Tuple, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe


__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]

_T_co = TypeVar("_T_co", covariant=True)


@functional_datapipe("concat")
class ConcaterMapDataPipe(MapDataPipe):
    r"""
    Concatenate multiple Map DataPipes (functional name: ``concat``).

    The new index is the cumulative sum of the lengths of the source DataPipes.
    For example, if there are 2 source DataPipes both with length 5,
    index 0 to 4 of the resulting `ConcaterMapDataPipe` would refer to
    elements of the first DataPipe, and 5 to 9 would refer to elements
    of the second DataPipe.

    Negative indices count back from the end of the concatenation, so
    index -1 refers to the last element of the last non-empty source DataPipe.

    Args:
        datapipes: Map DataPipes being concatenated

    Raises:
        ValueError: If no DataPipe is given.
        TypeError: If any input is not a `MapDataPipe` or is not `Sized`.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(3))
        >>> concat_dp = dp1.concat(dp2)
        >>> list(concat_dp)
        [0, 1, 2, 0, 1, 2]
    """

    # Variable-length tuple: any number (>= 1) of source DataPipes.
    datapipes: Tuple[MapDataPipe, ...]

    def __init__(self, *datapipes: MapDataPipe):
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> _T_co:  # type: ignore[type-var]
        """
        Return the element at ``index`` of the concatenated sequence.

        Raises:
            IndexError: If ``index`` falls outside ``[-len(self), len(self))``.
        """
        position = index
        if position < 0:
            # Without this, a negative index would be forwarded unchanged to
            # the first source DataPipe and address the wrong element.
            position += len(self)
            if position < 0:
                raise IndexError(f"Index {index} is out of range.")
        for dp in self.datapipes:
            size = len(dp)
            if position < size:
                return dp[position]
            # Not in this DataPipe: make the position relative to the next one.
            position -= size
        raise IndexError(f"Index {index} is out of range.")

    def __len__(self) -> int:
        """Return the total number of elements across all source DataPipes."""
        return sum(len(dp) for dp in self.datapipes)


@functional_datapipe("zip")
class ZipperMapDataPipe(MapDataPipe[Tuple[_T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    This MapDataPipe is out of bound as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Map DataPipes being aggregated

    Raises:
        ValueError: If no DataPipe is given.
        TypeError: If any input is not a `MapDataPipe` or is not `Sized`.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(10, 13))
        >>> zip_dp = dp1.zip(dp2)
        >>> list(zip_dp)
        [(0, 10), (1, 11), (2, 12)]
    """

    datapipes: Tuple[MapDataPipe[_T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> Tuple[_T_co, ...]:
        """
        Return a tuple holding the element at ``index`` of every input DataPipe.

        Raises:
            IndexError: If any input DataPipe raises ``IndexError`` for ``index``.
        """
        res = []
        for dp in self.datapipes:
            try:
                res.append(dp[index])
            except IndexError as e:
                # Re-raise with the offending DataPipe named, keeping the cause.
                raise IndexError(
                    f"Index {index} is out of range for one of the input MapDataPipes {dp}."
                ) from e
        return tuple(res)

    def __len__(self) -> int:
        """Return the length of the shortest input DataPipe."""
        return min(len(dp) for dp in self.datapipes)