# mypy: allow-untyped-defs
from typing import List, Sized, Type, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe


__all__ = ["BatcherMapDataPipe"]


_T = TypeVar("_T")


@functional_datapipe("batch")
class BatcherMapDataPipe(MapDataPipe[DataChunk]):
    r"""
    Create mini-batches of data (functional name: ``batch``).

    An outer dimension of size ``batch_size`` is added to the data. If ``drop_last``
    is set to ``False``, the last batch holds the remaining ``length % batch_size``
    elements; if it is set to ``True``, that partial batch is dropped.

    Args:
        datapipe: Map DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> batch_dp = dp.batch(batch_size=2)
        >>> list(batch_dp)
        [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
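        >>> # Illustrative continuation: with ``drop_last=True`` the trailing
        >>> # partial batch (here, ``[9]``) is discarded.
        >>> batch_dp = dp.batch(batch_size=3, drop_last=True)
        >>> list(batch_dp)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]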
    """

    datapipe: MapDataPipe
    batch_size: int
    drop_last: bool

    def __init__(
        self,
        datapipe: MapDataPipe[_T],
        batch_size: int,
        drop_last: bool = False,
        wrapper_class: Type[DataChunk] = DataChunk,
    ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

    def __getitem__(self, index) -> DataChunk:
        batch: List = []
        # Map the batch index to a contiguous range of indices in the source datapipe
        indices = range(index * self.batch_size, (index + 1) * self.batch_size)
        try:
            for i in indices:
                batch.append(self.datapipe[i])
            return self.wrapper_class(batch)
        except IndexError as e:
            # The source ran out mid-batch: return the partial batch unless
            # ``drop_last`` is set, in which case the batch index is invalid.
            if not self.drop_last and len(batch) > 0:
                return self.wrapper_class(batch)
            else:
                raise IndexError(f"Index {index} is out of bounds.") from e

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            if self.drop_last:
                return len(self.datapipe) // self.batch_size
            else:
                # Ceiling division: a trailing partial batch counts as one more batch
                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have a valid length")

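
# Illustrative usage sketch: a quick check of the batching, ``drop_last``, and
# ``__len__`` behaviour defined above. Assumes ``SequenceWrapper`` is importable
# from ``torch.utils.data.datapipes.map``; this block is not part of the class API.
if __name__ == "__main__":
    from torch.utils.data.datapipes.map import SequenceWrapper

    source = SequenceWrapper(range(10))

    keep_partial = BatcherMapDataPipe(source, batch_size=3)
    drop_partial = BatcherMapDataPipe(source, batch_size=3, drop_last=True)

    # 10 elements in batches of 3 -> 4 batches, the last one partial ([9])
    assert len(keep_partial) == 4
    assert list(keep_partial[3]) == [9]

    # With drop_last=True the partial trailing batch is discarded -> 3 batches
    assert len(drop_partial) == 3

    print([list(keep_partial[i]) for i in range(len(keep_partial))])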