# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction."""

# This file was originally under tf/python/feature_column, and was moved to
# Keras package in order to remove the reverse dependency from TF to Keras.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
34
class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable:  Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the DenseFeatures.
    partitioner: Optional partitioner used for the variable scope in which
      per-column state is created. Defaults to None.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = _normalize_feature_columns(
        feature_columns)
    # The state manager creates and tracks variables on behalf of the feature
    # columns, registering them on this layer.
    self._state_manager = feature_column_v2._StateManagerImpl(  # pylint: disable=protected-access
        self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    """Creates per-column state (e.g. embedding tables).

    The input shape argument is ignored; each column derives the shape of
    its own state from its configuration.
    """
    for column in self._feature_columns:
      with variable_scope.variable_scope(
          self.name, partitioner=self._partitioner):
        # Nest each column's variables under a sanitized scope so arbitrary
        # user-provided feature names still yield valid variable names.
        with variable_scope.variable_scope(
            _sanitize_column_name_for_variable_scope(column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    NOTE(fix): renamed from `_output_shape` — `compute_output_shape` and
    `_process_dense_tensor` below call `self._target_shape`, so under the old
    name this abstract hook was never reachable and a subclass overriding
    `_output_shape` would hit an AttributeError at runtime.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.
    """
    raise NotImplementedError('Calling an abstract method.')

  def compute_output_shape(self, input_shape):
    """Returns the layer output shape: the widths of all columns combined."""
    total_elements = 0
    for column in self._feature_columns:
      total_elements += column.variable_shape.num_elements()
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns."""
    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    """Returns a serializable config dict for this layer."""
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = [serialization.serialize_feature_column(fc)
                      for fc in self._feature_columns]
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)

    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates a layer from the dict produced by `get_config`."""
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    # Columns may reference each other (e.g. shared embeddings); collect them
    # by name so deserialization can resolve those references.
    columns_by_name = {}
    config_cp['feature_columns'] = [serialization.deserialize_feature_column(
        c, custom_objects, columns_by_name) for c in config['feature_columns']]
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config['partitioner'], custom_objects)

    return cls(**config_cp)
147
148
149def _sanitize_column_name_for_variable_scope(name):
150  """Sanitizes user-provided feature names for use as variable scopes."""
151  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
152  return invalid_char.sub('_', name)
153
154
def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  expected_batch_size = None
  for i, tensor in enumerate(tensors):
    # batch_size is a Dimension object; its `.value` is None when the batch
    # dimension is not statically known, in which case the tensor is skipped.
    batch_size = tensor_shape.Dimension(
        tensor_shape.dimension_value(tensor.shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        # Remember which column fixed the expected batch size so the error
        # message below can name both mismatching columns.
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))
180
181
def _normalize_feature_columns(feature_columns):
  """Coerces and validates the `feature_columns` argument.

  Accepts a single column, an iterator, or any iterable of columns and
  returns a plain list sorted by column name. Validates element types,
  rejects dicts and empty inputs, and forbids duplicate column names.

  Args:
    feature_columns: The raw feature columns, usually passed by users.

  Returns:
    The normalized feature column list.

  Raises:
    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
  """
  # Promote a bare column to a one-element list.
  if isinstance(feature_columns, feature_column_v2.FeatureColumn):
    feature_columns = [feature_columns]

  # Materialize iterators so the input can be traversed more than once.
  if isinstance(feature_columns, collections.abc.Iterator):
    feature_columns = list(feature_columns)

  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  # Reject the first element that is not a FeatureColumn, if any.
  invalid = next(
      (c for c in feature_columns
       if not isinstance(c, feature_column_v2.FeatureColumn)), None)
  if invalid is not None:
    raise ValueError('Items of feature_columns must be a FeatureColumn. '
                     'Given (type {}): {}.'.format(type(invalid), invalid))
  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')

  # Column names must be unique; they usually derive from the base feature.
  seen_by_name = {}
  for column in feature_columns:
    if column.name in seen_by_name:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer to '
                       'same base feature. Either one must be discarded or a '
                       'duplicated but renamed item must be inserted in '
                       'features dict.'.format(column,
                                               seen_by_name[column.name]))
    seen_by_name[column.name] = column

  return sorted(feature_columns, key=lambda column: column.name)
225