# mypy: allow-untyped-defs
import random

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe


__all__ = [
    "ConcatDataFramesPipe",
    "DataFramesAsTuplesPipe",
    "ExampleAggregateAsDataFrames",
    "FilterDataFramesPipe",
    "PerRowDataFramesPipe",
    "ShuffleDataFramesPipe",
]


@functional_datapipe("_dataframes_as_tuples")
class DataFramesAsTuplesPipe(IterDataPipe):
    """Iterates over each DataFrame from the source datapipe and yields its rows as tuples."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for df in self.source_datapipe:
            # for record in df.to_records(index=False):
            yield from df_wrapper.iterate(df)


@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True)
class PerRowDataFramesPipe(DFIterDataPipe):
    """Splits each DataFrame from the source datapipe into single-row DataFrame slices."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for df in self.source_datapipe:
            # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup
            for i in range(len(df)):
                yield df[i : i + 1]


@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True)
class ConcatDataFramesPipe(DFIterDataPipe):
    """Concatenates every ``batch`` consecutive DataFrames into one; a final partial batch is flushed at the end."""

    def __init__(self, source_datapipe, batch=3):
        self.source_datapipe = source_datapipe
        self.n_batch = batch

    def __iter__(self):
        buffer = []
        for df in self.source_datapipe:
            buffer.append(df)
            if len(buffer) == self.n_batch:
                yield df_wrapper.concat(buffer)
                buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)


@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True)
class ShuffleDataFramesPipe(DFIterDataPipe):
    """Buffers all rows from the source, shuffles them, and re-emits DataFrames sized like the first input."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        size = None
        all_buffer = []
        # Explode every incoming DataFrame into single rows.
        for df in self.source_datapipe:
            if size is None:
                size = df_wrapper.get_len(df)
            for i in range(df_wrapper.get_len(df)):
                all_buffer.append(df_wrapper.get_item(df, i))
        random.shuffle(all_buffer)
        # Repack shuffled rows into DataFrames of the original size.
        buffer = []
        for df in all_buffer:
            buffer.append(df)
            if len(buffer) == size:
                yield df_wrapper.concat(buffer)
                buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)


@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True)
class FilterDataFramesPipe(DFIterDataPipe):
    """Applies ``filter_fn`` to each row and repacks the rows that pass into DataFrames sized like the first input."""

    def __init__(self, source_datapipe, filter_fn):
        self.source_datapipe = source_datapipe
        self.filter_fn = filter_fn

    def __iter__(self):
        size = None
        all_buffer = []
        filter_res = []
        for df in self.source_datapipe:
            if size is None:
                size = len(df.index)
            for i in range(len(df.index)):
                all_buffer.append(df[i : i + 1])
                filter_res.append(self.filter_fn(df.iloc[i]))

        buffer = []
        for df, res in zip(all_buffer, filter_res):
            if res:
                buffer.append(df)
            if len(buffer) == size:
                yield df_wrapper.concat(buffer)
                buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)


@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True)
class ExampleAggregateAsDataFrames(DFIterDataPipe):
    """Aggregates plain items from the source datapipe into DataFrames of ``dataframe_size`` rows."""

    def __init__(self, source_datapipe, dataframe_size=10, columns=None):
        self.source_datapipe = source_datapipe
        self.columns = columns
        self.dataframe_size = dataframe_size

    def _as_list(self, item):
        try:
            return list(item)
        except (
            Exception
        ):  # TODO(VitalyFedyunin): Replace with better iterable exception
            return [item]

    def __iter__(self):
        aggregate = []
        for item in self.source_datapipe:
            aggregate.append(self._as_list(item))
            if len(aggregate) == self.dataframe_size:
                yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
                aggregate = []
        if len(aggregate) > 0:
            yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
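

# Usage sketch (illustrative only, not part of this module): the pipes above can
# be composed by direct construction. This assumes the default pandas-backed
# df_wrapper and uses IterableWrapper from torch.utils.data.datapipes.iter as a
# toy source; the column names "x" and "parity" are made up for the example.
#
#     from torch.utils.data.datapipes.iter import IterableWrapper
#
#     source = IterableWrapper([(i, i % 2) for i in range(25)])
#     # Pack plain tuples into DataFrames of 10 rows each (25 items -> 10, 10, 5).
#     dfs = ExampleAggregateAsDataFrames(
#         source, dataframe_size=10, columns=["x", "parity"]
#     )
#     # Unpack the DataFrames back into one tuple per row.
#     rows = DataFramesAsTuplesPipe(dfs)
#     for row in rows:
#         print(row)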