# mypy: allow-untyped-defs
from typing import List, Sized, Type, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe


__all__ = ["BatcherMapDataPipe"]


_T = TypeVar("_T")


@functional_datapipe("batch")
class BatcherMapDataPipe(MapDataPipe[DataChunk]):
    r"""
    Create mini-batches of data (functional name: ``batch``).

    An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
    or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.

    Batch ``index`` is assembled from the source elements at positions
    ``index * batch_size`` up to (but excluding) ``(index + 1) * batch_size``.

    Args:
        datapipe: Map DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        wrapper_class: Class called with the list of collected elements to
            build each returned batch (defaults to ``DataChunk``)

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> batch_dp = dp.batch(batch_size=2)
        >>> list(batch_dp)
        [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    """

    datapipe: MapDataPipe
    batch_size: int
    drop_last: bool
    wrapper_class: Type[DataChunk]

    def __init__(
        self,
        datapipe: MapDataPipe[_T],
        batch_size: int,
        drop_last: bool = False,
        wrapper_class: Type[DataChunk] = DataChunk,
    ) -> None:
        """Store the source datapipe and the batching configuration.

        Raises:
            AssertionError: if ``batch_size`` is not larger than 0.
        """
        # NOTE: kept as an ``assert`` so the exception type seen by existing
        # callers does not change; be aware it is skipped under ``python -O``.
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

    def __getitem__(self, index) -> DataChunk:
        """Return the batch at position ``index``.

        Raises:
            IndexError: if no source element exists for this batch, or if the
                batch is incomplete while ``drop_last`` is ``True``.
        """
        batch: List = []
        indices = range(index * self.batch_size, (index + 1) * self.batch_size)
        # Only the element fetches are guarded: an IndexError raised while
        # building the wrapper must not be mistaken for "ran off the end".
        # An explicit loop (not a comprehension) is required so the elements
        # collected before the IndexError survive as the partial last batch.
        try:
            for i in indices:
                batch.append(self.datapipe[i])
        except IndexError as e:
            if self.drop_last or not batch:
                raise IndexError(f"Index {index} is out of bound.") from e
            # Otherwise fall through and return the partial last batch.
        return self.wrapper_class(batch)

    def __len__(self) -> int:
        """Return the number of batches.

        Rounds down when ``drop_last`` is ``True`` and up otherwise.

        Raises:
            TypeError: if the source datapipe does not implement ``__len__``.
        """
        if not isinstance(self.datapipe, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        source_len = len(self.datapipe)
        if self.drop_last:
            return source_len // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (source_len + self.batch_size - 1) // self.batch_size