• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Methods to allow pandas.DataFrame (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn
28from tensorflow.python.util.deprecation import deprecated
29
30try:
31  # pylint: disable=g-import-not-at-top
32  import pandas as pd
33  HAS_PANDAS = True
34except IOError:
35  # Pandas writes a temporary file during import. If it fails, don't use pandas.
36  HAS_PANDAS = False
37except ImportError:
38  HAS_PANDAS = False
39
# Names of the column dtypes accepted by the extract_* helpers, each mapped to
# a short type tag. Within this file only the keys are consulted (membership
# tests on `dtype.name`); the tag values are kept exactly as they were for any
# external reader of this table.
PANDAS_DTYPES = dict(
    int8='int',
    int16='int',
    int32='int',
    int64='int',
    uint8='int',
    uint16='int',
    uint32='int',
    uint64='int',
    float16='float',
    float32='float',
    float64='float',
    bool='i',
)
54
55
@deprecated(None, 'Please use tf.compat.v1.estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
                    y=None,
                    batch_size=128,
                    num_epochs=1,
                    shuffle=True,
                    queue_capacity=1000,
                    num_threads=1,
                    target_column='target'):
  """Thin wrapper around the core `pandas_input_fn`.

  Every argument is forwarded unchanged, by keyword, to the core
  implementation. This wrapper exists because its default for `shuffle`
  (`True`) differs from the core version's default.

  Args:
    x: forwarded to the core function as `x`.
    y: forwarded to the core function as `y`.
    batch_size: forwarded to the core function as `batch_size`.
    num_epochs: forwarded to the core function as `num_epochs`.
    shuffle: forwarded to the core function as `shuffle`; defaults to `True`.
    queue_capacity: forwarded to the core function as `queue_capacity`.
    num_threads: forwarded to the core function as `num_threads`.
    target_column: forwarded to the core function as `target_column`.

  Returns:
    Whatever the core `pandas_input_fn` returns for these arguments.
  """
  forwarded = dict(
      x=x,
      y=y,
      batch_size=batch_size,
      shuffle=shuffle,
      num_epochs=num_epochs,
      queue_capacity=queue_capacity,
      num_threads=num_threads,
      target_column=target_column,
  )
  return core_pandas_input_fn(**forwarded)
74
75
@deprecated(None, 'Please access pandas data directly.')
def extract_pandas_data(data):
  """Extract data from pandas.DataFrame for predictors.

  Given a DataFrame, will extract the values and cast them to float. The
  DataFrame is expected to contain values of type int, float or bool.
  Anything that is not a DataFrame is returned unchanged.

  Args:
    data: `pandas.DataFrame` containing the data to be extracted.

  Returns:
    A numpy `ndarray` of the DataFrame's values as floats, or `data` itself
    when it is not a `pandas.DataFrame`.

  Raises:
    ValueError: if data contains types other than int, float or bool.
  """
  if not isinstance(data, pd.DataFrame):
    return data

  # Pair labels with dtypes positionally instead of looking up `data[column]`:
  # when column labels are duplicated that lookup yields a DataFrame, which
  # has no `.dtype` attribute and would fail with AttributeError.
  bad_columns = [(column, dtype.name)
                 for column, dtype in zip(data.columns, data.dtypes)
                 if dtype.name not in PANDAS_DTYPES]

  if bad_columns:
    error_report = ["'" + str(column) + "' type='" + dtype_name + "'"
                    for column, dtype_name in bad_columns]
    raise ValueError('Data types for extracting pandas data must be int, '
                     'float, or bool. Found: ' + ', '.join(error_report))

  return data.values.astype('float')
105
106
107@deprecated(None, 'Please access pandas data directly.')
def extract_pandas_matrix(data):
  """Extracts numpy matrix from pandas DataFrame.

  Args:
    data: `pandas.DataFrame` containing the data to be extracted. Any other
      object is returned unchanged.

  Returns:
    A numpy `ndarray` of the DataFrame's values, or `data` itself when it is
    not a `pandas.DataFrame`.
  """
  if not isinstance(data, pd.DataFrame):
    return data

  # `DataFrame.as_matrix()` was deprecated in pandas 0.23 and removed in 1.0.
  # `.values` yields the same ndarray and is available in old and new pandas.
  return data.values
121
122
@deprecated(None, 'Please access pandas data directly.')
def extract_pandas_labels(labels):
  """Pull the label values out of a single-column pandas.DataFrame.

  Only `pandas.DataFrame` inputs are validated and converted. Any other
  object is handed back untouched; that includes `pandas.Series`, which is
  not a DataFrame subclass and therefore does not match the type check.

  Args:
    labels: `pandas.DataFrame` holding one column of labels, or any other
      object to be passed through as is.

  Returns:
    The DataFrame's values as a numpy `ndarray`, or `labels` itself when it
    is not a `pandas.DataFrame`.

  Raises:
    ValueError: if the DataFrame has more than one column, or if its column
      dtype is not int, float or bool.
  """
  if not isinstance(labels, pd.DataFrame):
    return labels

  if len(labels.columns) > 1:
    raise ValueError('Only one column for labels is allowed.')

  # Columns whose dtype name is missing from the accepted-dtype table.
  offending = [name for name in labels
               if labels[name].dtype.name not in PANDAS_DTYPES]
  if offending:
    details = ', '.join(
        "'" + str(name) + "' type=" + str(labels[name].dtype.name)
        for name in offending)
    raise ValueError('Data types for extracting labels must be int, '
                     'float, or bool. Found: ' + details)

  return labels.values
154