# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for preprocessing data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tensor_forest.python.ops import tensor_forest_ops

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import tf_logging as logging
30# Data column types for indicating categorical or other non-float values.
31DATA_FLOAT = 0
32DATA_CATEGORICAL = 1
33
34DTYPE_TO_FTYPE = {
35    dtypes.string: DATA_CATEGORICAL,
36    dtypes.int32: DATA_CATEGORICAL,
37    dtypes.int64: DATA_CATEGORICAL,
38    dtypes.float32: DATA_FLOAT,
39    dtypes.float64: DATA_FLOAT
40}
41
42
43def CastToFloat(tensor):
44  if tensor.dtype == dtypes.string:
45    return tensor_forest_ops.reinterpret_string_to_float(tensor)
46  elif tensor.dtype.is_integer:
47    return math_ops.cast(tensor, dtypes.float32)
48  else:
49    return tensor
50
51
52# TODO(gilberth): If protos are ever allowed in dynamically loaded custom
53# op libraries, convert this to a proto like a sane person.
54class TensorForestDataSpec(object):
55
56  def __init__(self):
57    self.sparse = DataColumnCollection()
58    self.dense = DataColumnCollection()
59    self.dense_features_size = 0
60
61  def SerializeToString(self):
62    return 'dense_features_size: %d dense: [%s] sparse: [%s]' % (
63        self.dense_features_size, self.dense.SerializeToString(),
64        self.sparse.SerializeToString())
65
66
67class DataColumnCollection(object):
68  """Collection of DataColumns, meant to mimic a proto repeated field."""
69
70  def __init__(self):
71    self.cols = []
72
73  def add(self):  # pylint: disable=invalid-name
74    self.cols.append(DataColumn())
75    return self.cols[-1]
76
77  def size(self):  # pylint: disable=invalid-name
78    return len(self.cols)
79
80  def SerializeToString(self):
81    ret = ''
82    for c in self.cols:
83      ret += '{%s}' % c.SerializeToString()
84    return ret
85
86
87class DataColumn(object):
88
89  def __init__(self):
90    self.name = ''
91    self.original_type = ''
92    self.size = 0
93
94  def SerializeToString(self):
95    return 'name: {0} original_type: {1} size: {2}'.format(self.name,
96                                                           self.original_type,
97                                                           self.size)
98
99
100def GetColumnName(column_key, col_num):
101  if isinstance(column_key, str):
102    return column_key
103  else:
104    return getattr(column_key, 'column_name', str(col_num))
105
106
107def ParseDataTensorOrDict(data):
108  """Return a tensor to use for input data.
109
110  The incoming features can be a dict where keys are the string names of the
111  columns, which we turn into a single 2-D tensor.
112
113  Args:
114    data: `Tensor` or `dict` of `Tensor` objects.
115
116  Returns:
117    A 2-D tensor for input to tensor_forest, a keys tensor for the
118    tf.Examples if they exist, and a list of the type of each column
119    (e.g. continuous float, categorical).
120  """
121  data_spec = TensorForestDataSpec()
122  if isinstance(data, dict):
123    dense_features_size = 0
124    dense_features = []
125    sparse_features = []
126    for k in sorted(data.keys()):
127      is_sparse = isinstance(data[k], sparse_tensor.SparseTensor)
128      if is_sparse:
129        # TODO(gilberth): support sparse continuous.
130        if data[k].dtype == dtypes.float32:
131          logging.info('TensorForest does not support sparse continuous.')
132          continue
133        elif data_spec.sparse.size() == 0:
134          col_spec = data_spec.sparse.add()
135          col_spec.original_type = DATA_CATEGORICAL
136          col_spec.name = 'all_sparse'
137          col_spec.size = -1
138        sparse_features.append(
139            sparse_tensor.SparseTensor(data[
140                k].indices, CastToFloat(data[k].values), data[k].dense_shape))
141      else:
142        col_spec = data_spec.dense.add()
143
144        col_spec.original_type = DTYPE_TO_FTYPE[data[k].dtype]
145        col_spec.name = GetColumnName(k, len(dense_features))
146        # the second dimension of get_shape should always be known.
147        shape = data[k].get_shape()
148        if len(shape) == 1:
149          col_spec.size = 1
150        else:
151          col_spec.size = shape[1].value
152
153        dense_features_size += col_spec.size
154        dense_features.append(CastToFloat(data[k]))
155
156    processed_dense_features = None
157    processed_sparse_features = None
158    if dense_features:
159      processed_dense_features = array_ops.concat(dense_features, 1)
160      data_spec.dense_features_size = dense_features_size
161    if sparse_features:
162      processed_sparse_features = sparse_ops.sparse_concat(1, sparse_features)
163    logging.info(data_spec.SerializeToString())
164    return processed_dense_features, processed_sparse_features, data_spec
165  elif isinstance(data, sparse_tensor.SparseTensor):
166    col_spec = data_spec.sparse.add()
167    col_spec.name = 'sparse_features'
168    col_spec.original_type = DTYPE_TO_FTYPE[data.dtype]
169    col_spec.size = -1
170    data_spec.dense_features_size = 0
171    return None, data, data_spec
172  else:
173    data = ops.convert_to_tensor(data)
174    col_spec = data_spec.dense.add()
175    col_spec.name = 'dense_features'
176    col_spec.original_type = DTYPE_TO_FTYPE[data.dtype]
177    col_spec.size = data.get_shape()[1]
178    data_spec.dense_features_size = col_spec.size
179    return data, None, data_spec
180
181
182def ParseLabelTensorOrDict(labels):
183  """Return a tensor to use for input labels to tensor_forest.
184
185  The incoming targets can be a dict where keys are the string names of the
186  columns, which we turn into a single 1-D tensor for classification or
187  2-D tensor for regression.
188
189  Converts sparse tensors to dense ones.
190
191  Args:
192    labels: `Tensor` or `dict` of `Tensor` objects.
193
194  Returns:
195    A 2-D tensor for labels/outputs.
196  """
197  if isinstance(labels, dict):
198    return math_ops.cast(
199        array_ops.concat(
200            [
201                sparse_ops.sparse_tensor_to_dense(
202                    labels[k], default_value=-1) if isinstance(
203                        labels, sparse_tensor.SparseTensor) else labels[k]
204                for k in sorted(labels.keys())
205            ],
206            1),
207        dtypes.float32)
208  else:
209    if isinstance(labels, sparse_tensor.SparseTensor):
210      return math_ops.cast(
211          sparse_ops.sparse_tensor_to_dense(labels, default_value=-1),
212          dtypes.float32)
213    else:
214      return math_ops.cast(labels, dtypes.float32)
215