• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from typing import Callable, TypeVar
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.datapipe import MapDataPipe
6from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
7
8
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    "MapperMapDataPipe",
    "default_fn",
]


# Covariant type variable for the items a MapDataPipe yields.
_T_co = TypeVar("_T_co", covariant=True)
13
14
# Identity function used as the default ``fn`` for ``MapperMapDataPipe``.
# It is a named module-level function rather than ``lambda x: x`` so that
# DataPipes built with the default stay picklable (pickle cannot serialize
# lambdas).
def default_fn(data):
    """Return ``data`` unchanged (the identity mapping)."""
    return data
20
21
@functional_datapipe("map")
class MapperMapDataPipe(MapDataPipe[_T_co]):
    r"""
    Apply the input function over each item from the source DataPipe (functional name: ``map``).

    Looking up ``self[index]`` fetches ``datapipe[index]`` and returns the
    result of calling ``fn`` on it; the length is that of the source
    DataPipe. ``fn`` may be any regular Python function or partial object.
    Lambda functions are not recommended because pickle does not support them.

    Args:
        datapipe: Source MapDataPipe
        fn: Function applied to each item (defaults to ``default_fn``,
            which returns the item unchanged)

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    datapipe: MapDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: MapDataPipe,
        fn: Callable = default_fn,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # NOTE(review): presumably flags callables that cannot be pickled
        # (e.g. lambdas) at construction time; see _check_unpickable_fn.
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

    def __len__(self) -> int:
        # One output item per source item, so the size is the source's size.
        source = self.datapipe
        return len(source)

    def __getitem__(self, index) -> _T_co:
        # Look up the function first, then the source item, then apply it.
        transform = self.fn
        raw_item = self.datapipe[index]
        return transform(raw_item)