• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from typing import Sized, Tuple, TypeVar
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.datapipe import MapDataPipe
6
7
# Names exported by ``from ... import *``: the two combining MapDataPipes of this module.
__all__ = [
    "ConcaterMapDataPipe",
    "ZipperMapDataPipe",
]

# Covariant element type variable used by the generic MapDataPipe annotations in this module.
_T_co = TypeVar("_T_co", covariant=True)
11
12
@functional_datapipe("concat")
class ConcaterMapDataPipe(MapDataPipe):
    r"""
    Concatenate multiple Map DataPipes (functional name: ``concat``).

    Each index of the result is offset by the cumulative length of the source
    DataPipes that precede it. For example, if there are 2 source DataPipes both
    with length 5, index 0 to 4 of the resulting `ConcaterMapDataPipe` would
    refer to elements of the first DataPipe, and 5 to 9 would refer to elements
    of the second DataPipe. Negative indices count back from the end of the
    whole concatenation, as for a Python sequence (``-1`` is the last element of
    the last DataPipe).

    Args:
        datapipes: Map DataPipes being concatenated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if any input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(3))
        >>> concat_dp = dp1.concat(dp2)
        >>> list(concat_dp)
        [0, 1, 2, 0, 1, 2]
    """

    # Variable-length tuple of the source DataPipes, in concatenation order.
    datapipes: Tuple[MapDataPipe, ...]

    def __init__(self, *datapipes: MapDataPipe) -> None:
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        # Source lengths are required to map a global index onto a source DataPipe.
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> _T_co:  # type: ignore[type-var]
        """Return the element at global ``index`` of the concatenation.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        requested = index
        # Resolve negative indices against the total length; otherwise a
        # negative value would satisfy the "fits in this source" test below for
        # the very first DataPipe and index into it instead.
        if index < 0:
            index += len(self)
        if index >= 0:
            for dp in self.datapipes:
                dp_len = len(dp)
                if index < dp_len:
                    return dp[index]
                # Skip past this source: make the index local to the next one.
                index -= dp_len
        raise IndexError(f"Index {requested} is out of range.")

    def __len__(self) -> int:
        # Total number of elements across all source DataPipes.
        return sum(len(dp) for dp in self.datapipes)
59
60
@functional_datapipe("zip")
class ZipperMapDataPipe(MapDataPipe[Tuple[_T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    This MapDataPipe is out of bounds as soon as the shortest input DataPipe is
    exhausted, so its length is the minimum of the input lengths. Negative
    integer indices count back from that length, so every element of a returned
    tuple comes from the same position in its input.

    Args:
        *datapipes: Map DataPipes being aggregated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if any input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(10, 13))
        >>> zip_dp = dp1.zip(dp2)
        >>> list(zip_dp)
        [(0, 10), (1, 11), (2, 12)]
    """

    # Variable-length tuple of the input DataPipes, in output-tuple order.
    datapipes: Tuple[MapDataPipe[_T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        # Input lengths are required to compute the zipped length.
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> Tuple[_T_co, ...]:
        """Return a tuple holding ``dp[index]`` for every input DataPipe.

        Raises:
            IndexError: if ``index`` is a negative integer below ``-len(self)``,
                or if any input raises ``IndexError`` for ``index``.
        """
        # Resolve negative integer indices against the zipped length (the
        # shortest input). Inputs that interpret negative indices relative to
        # their own length would otherwise pair elements from different
        # positions when the inputs differ in length. Non-integer keys are
        # forwarded unchanged.
        if isinstance(index, int) and index < 0:
            resolved = index + len(self)
            if resolved < 0:
                raise IndexError(f"Index {index} is out of range.")
            index = resolved
        res = []
        for dp in self.datapipes:
            try:
                res.append(dp[index])
            except IndexError as e:
                raise IndexError(
                    f"Index {index} is out of range for one of the input MapDataPipes {dp}."
                ) from e
        return tuple(res)

    def __len__(self) -> int:
        # Zipping stops at the shortest input.
        return min(len(dp) for dp in self.datapipes)
105