# mypy: allow-untyped-defs
from typing import Callable, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn


__all__ = ["MapperMapDataPipe", "default_fn"]


_T_co = TypeVar("_T_co", covariant=True)


def default_fn(data):
    """Return ``data`` unchanged (the identity transform).

    This is the default ``fn`` of ``MapperMapDataPipe``. It is a named,
    module-level function rather than an inline lambda so that a pipe built
    with the default callable stays picklable.
    """
    return data


@functional_datapipe("map")
class MapperMapDataPipe(MapDataPipe[_T_co]):
    r"""
    Apply ``fn`` to each element of a source MapDataPipe on access (functional name: ``map``).

    Element ``i`` of this pipe is ``fn(datapipe[i])``; nothing is computed
    until an index is requested. The length is the length of the source.

    ``fn`` may be any ordinary Python function or a ``functools.partial``
    object. Lambdas work, but should be avoided because pickle cannot
    serialize them.

    Args:
        datapipe: Source MapDataPipe to read elements from
        fn: Callable applied to every element read from ``datapipe``
            (defaults to the identity function ``default_fn``)

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    # Source pipe whose elements get transformed.
    datapipe: MapDataPipe
    # Transform applied to each element on lookup.
    fn: Callable

    def __init__(
        self,
        datapipe: MapDataPipe,
        fn: Callable = default_fn,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # Pickle-compatibility check on the callable; its exact behavior
        # (warn vs. raise) lives in utils.common — NOTE(review): confirm there.
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

    def __len__(self) -> int:
        # Mapping is one-to-one, so the size is exactly the source's size.
        source = self.datapipe
        return len(source)

    def __getitem__(self, index) -> _T_co:
        # Fetch the raw element first, then transform it lazily on access.
        raw_item = self.datapipe[index]
        return self.fn(raw_item)