• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import random
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
6from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
7
8
# Public API of this module: the DataFrame-tracing datapipes defined below.
# Kept sorted alphabetically; update when adding a new pipe class.
__all__ = [
    "ConcatDataFramesPipe",
    "DataFramesAsTuplesPipe",
    "ExampleAggregateAsDataFrames",
    "FilterDataFramesPipe",
    "PerRowDataFramesPipe",
    "ShuffleDataFramesPipe",
]
17
18
@functional_datapipe("_dataframes_as_tuples")
class DataFramesAsTuplesPipe(IterDataPipe):
    """Unbatch incoming dataframes, yielding one tuple per row.

    Row iteration is delegated to ``df_wrapper.iterate`` so the pipe stays
    agnostic of the concrete dataframe backend.
    """

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for frame in self.source_datapipe:
            yield from df_wrapper.iterate(frame)
28
29
@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True)
class PerRowDataFramesPipe(DFIterDataPipe):
    """Split each incoming dataframe into single-row dataframe slices."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup
        for frame in self.source_datapipe:
            row_count = len(frame)
            for row in range(row_count):
                # One-row slice keeps the result a dataframe rather than a row object.
                yield frame[row : row + 1]
40
41
@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True)
class ConcatDataFramesPipe(DFIterDataPipe):
    """Concatenate every ``batch`` consecutive dataframes into one.

    A final, smaller concatenation is emitted for any leftover frames.
    """

    def __init__(self, source_datapipe, batch=3):
        self.source_datapipe = source_datapipe
        self.n_batch = batch

    def __iter__(self):
        pending = []
        for frame in self.source_datapipe:
            pending.append(frame)
            if len(pending) < self.n_batch:
                continue
            yield df_wrapper.concat(pending)
            pending = []
        if pending:
            # Flush the trailing partial batch, if any.
            yield df_wrapper.concat(pending)
57
58
@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True)
class ShuffleDataFramesPipe(DFIterDataPipe):
    """Shuffle rows globally across all incoming dataframes.

    Every row of every source frame is buffered in memory, shuffled with
    ``random.shuffle``, and re-emitted in dataframes whose length matches the
    first source frame (the last emitted frame may be shorter).
    """

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        batch_size = None
        rows = []
        for frame in self.source_datapipe:
            if batch_size is None:
                # Output batch size follows the first input frame.
                batch_size = df_wrapper.get_len(frame)
            rows.extend(
                df_wrapper.get_item(frame, idx)
                for idx in range(df_wrapper.get_len(frame))
            )
        random.shuffle(rows)
        chunk = []
        for row in rows:
            chunk.append(row)
            if len(chunk) == batch_size:
                yield df_wrapper.concat(chunk)
                chunk = []
        if chunk:
            yield df_wrapper.concat(chunk)
81
82
@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True)
class FilterDataFramesPipe(DFIterDataPipe):
    """Filter rows of incoming dataframes with ``filter_fn``.

    All rows from the source are buffered, the predicate is evaluated once
    per row, and surviving rows are re-concatenated into dataframes whose
    length matches the first source frame (the last emitted frame may be
    shorter).

    Args:
        source_datapipe: iterable yielding dataframes.
        filter_fn: callable invoked with each row (``df.iloc[i]``); a truthy
            return keeps the row.
    """

    def __init__(self, source_datapipe, filter_fn):
        self.source_datapipe = source_datapipe
        self.filter_fn = filter_fn

    def __iter__(self):
        size = None
        all_buffer = []
        filter_res = []
        for df in self.source_datapipe:
            if size is None:
                # Output batch size follows the first input frame; use the
                # backend-agnostic wrapper like the sibling pipes do.
                size = df_wrapper.get_len(df)
            # NOTE(review): ``df.iloc`` is pandas-only; the other pipes in this
            # module route row access through ``df_wrapper`` — TODO migrate this
            # too when dropping pandas.
            for i in range(df_wrapper.get_len(df)):
                all_buffer.append(df[i : i + 1])
                filter_res.append(self.filter_fn(df.iloc[i]))

        buffer = []
        for df, res in zip(all_buffer, filter_res):
            if res:
                buffer.append(df)
                if len(buffer) == size:
                    yield df_wrapper.concat(buffer)
                    buffer = []
        if buffer:
            # Flush remaining kept rows as a final, possibly shorter frame.
            yield df_wrapper.concat(buffer)
109
110
@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True)
class ExampleAggregateAsDataFrames(DFIterDataPipe):
    """Aggregate items into dataframes of up to ``dataframe_size`` rows.

    Each source item is coerced to a list of column values (a non-iterable
    item becomes a single-column row). A final, possibly shorter dataframe is
    emitted for any remainder.

    Args:
        source_datapipe: iterable yielding row items.
        dataframe_size: maximum number of rows per emitted dataframe.
        columns: optional column names forwarded to
            ``df_wrapper.create_dataframe``.
    """

    def __init__(self, source_datapipe, dataframe_size=10, columns=None):
        self.source_datapipe = source_datapipe
        self.columns = columns
        self.dataframe_size = dataframe_size

    def _as_list(self, item):
        """Return ``item`` as a list of column values; wrap non-iterables."""
        try:
            return list(item)
        except TypeError:
            # ``list()`` raises TypeError for non-iterables. Previously this
            # was a blanket ``except Exception`` (flagged by the original
            # TODO), which also hid genuine errors raised while iterating
            # ``item``; narrowed so those propagate.
            return [item]

    def __iter__(self):
        aggregate = []
        for item in self.source_datapipe:
            aggregate.append(self._as_list(item))
            if len(aggregate) == self.dataframe_size:
                yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
                aggregate = []
        if aggregate:
            yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
135