• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from typing import Any, Optional
3
4
# Handle to the ``pandas`` module; remains ``None`` until an import succeeds.
_pandas: Any = None
# Cached availability flag: ``None`` means "not probed yet"; otherwise it holds
# the result of the first import attempt made by ``_with_pandas``.
_WITH_PANDAS: Optional[bool] = None


def _try_import_pandas() -> bool:
    """Try to import pandas, storing the module in ``_pandas`` on success.

    Returns:
        ``True`` if ``import pandas`` succeeded, ``False`` if it raised
        ``ImportError``.
    """
    global _pandas
    try:
        import pandas  # type: ignore[import]
    except ImportError:
        return False
    _pandas = pandas
    return True


# pandas used only for prototyping, will be shortly replaced with TorchArrow
def _with_pandas() -> bool:
    """Report whether pandas is available, probing for it at most once.

    The first call runs ``_try_import_pandas`` and memoizes the outcome in
    ``_WITH_PANDAS``; subsequent calls return the memoized value.
    """
    global _WITH_PANDAS
    if _WITH_PANDAS is not None:
        return _WITH_PANDAS
    _WITH_PANDAS = _try_import_pandas()
    return _WITH_PANDAS
26
27
class PandasWrapper:
    """Default dataframe backend that delegates every operation to pandas.

    All operations are classmethods, so the class object itself is what gets
    installed as the active wrapper. pandas is imported lazily through
    ``_with_pandas()``. The type predicates (``is_dataframe``/``is_column``)
    return ``False`` when pandas is missing; every other operation raises
    ``RuntimeError``.
    """

    @classmethod
    def _require_pandas(cls):
        """Raise ``RuntimeError`` unless pandas can be imported."""
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")

    @classmethod
    def create_dataframe(cls, data, columns):
        """Build a ``pandas.DataFrame`` from *data* labelled with *columns*.

        Raises:
            RuntimeError: if pandas is not installed.
        """
        cls._require_pandas()
        return _pandas.DataFrame(data, columns=columns)  # type: ignore[union-attr]

    @classmethod
    def is_dataframe(cls, data):
        """Return ``True`` if *data* is a ``pandas.DataFrame``.

        Returns ``False`` (rather than raising) when pandas is unavailable.
        """
        if not _with_pandas():
            return False
        # ``pandas.DataFrame`` is the public alias of
        # ``pandas.core.frame.DataFrame``; using it avoids depending on pandas'
        # internal module layout.
        return isinstance(data, _pandas.DataFrame)  # type: ignore[union-attr]

    @classmethod
    def is_column(cls, data):
        """Return ``True`` if *data* is a ``pandas.Series``.

        Returns ``False`` (rather than raising) when pandas is unavailable.
        """
        if not _with_pandas():
            return False
        # Public alias of ``pandas.core.series.Series``.
        return isinstance(data, _pandas.Series)  # type: ignore[union-attr]

    @classmethod
    def iterate(cls, data):
        """Yield the rows of *data* via ``itertuples(index=False)``.

        This is a generator, so the pandas availability check (and its
        ``RuntimeError``) happens on the first ``next()``, not when
        ``iterate`` itself is called.
        """
        cls._require_pandas()
        yield from data.itertuples(index=False)

    @classmethod
    def concat(cls, buffer):
        """Concatenate the frames in *buffer* with ``pandas.concat``.

        Raises:
            RuntimeError: if pandas is not installed.
        """
        cls._require_pandas()
        return _pandas.concat(buffer)  # type: ignore[union-attr]

    @classmethod
    def get_item(cls, data, idx):
        """Return the row at position *idx* as a one-row slice of *data*.

        Negative positions count from the end, as with ordinary slicing.

        Raises:
            RuntimeError: if pandas is not installed.
        """
        cls._require_pandas()
        stop = idx + 1
        # For idx == -1 the naive slice would be data[-1:0], which is always
        # empty; an open-ended stop selects the last row instead.
        return data[idx : None if stop == 0 else stop]

    @classmethod
    def get_len(cls, df):
        """Return the number of rows in *df* (the length of its index).

        Raises:
            RuntimeError: if pandas is not installed.
        """
        cls._require_pandas()
        return len(df.index)

    @classmethod
    def get_columns(cls, df):
        """Return the column labels of *df* as a plain Python list.

        Raises:
            RuntimeError: if pandas is not installed.
        """
        cls._require_pandas()
        # ``tolist()`` already returns a fresh list; no extra ``list()`` needed.
        return df.columns.values.tolist()
76
77
# Active dataframe backend consulted (via get_df_wrapper) by the module-level
# helpers in this file. To plug in your own implementation, install it with
# dataframe_wrapper.set_df_wrapper(new_wrapper_class).
default_wrapper = PandasWrapper
80
81
def get_df_wrapper():
    """Return the backend currently stored in the module-level ``default_wrapper``."""
    return default_wrapper


def set_df_wrapper(wrapper):
    """Replace the module-level ``default_wrapper`` with *wrapper*.

    The object is stored as-is, with no validation; later ``get_df_wrapper()``
    calls (and thus the module-level helpers) return it.
    """
    global default_wrapper
    default_wrapper = wrapper
89
90
def create_dataframe(data, columns=None):
    """Create a dataframe from *data* through the active wrapper.

    Both *data* and *columns* are forwarded unchanged to the wrapper's
    ``create_dataframe``.
    """
    return get_df_wrapper().create_dataframe(data, columns)
94
95
def is_dataframe(data):
    """Ask the active wrapper whether *data* is one of its dataframes."""
    return get_df_wrapper().is_dataframe(data)
99
100
def get_columns(data):
    """Return the column labels of *data* as reported by the active wrapper."""
    return get_df_wrapper().get_columns(data)
104
105
def is_column(data):
    """Ask the active wrapper whether *data* is one of its column objects."""
    return get_df_wrapper().is_column(data)
109
110
def concat(buffer):
    """Concatenate the dataframes in *buffer* using the active wrapper."""
    return get_df_wrapper().concat(buffer)
114
115
def iterate(data):
    """Return whatever row iterator the active wrapper produces for *data*."""
    return get_df_wrapper().iterate(data)
119
120
def get_item(data, idx):
    """Fetch the item at position *idx* of *data* through the active wrapper."""
    return get_df_wrapper().get_item(data, idx)
124
125
def get_len(df):
    """Return the length of *df* as computed by the active wrapper."""
    return get_df_wrapper().get_len(df)
129