# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction."""

# This file was originally under tf/python/feature_column, and was moved to
# Keras package in order to remove the reverse dependency from TF to Keras.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
# `collections.abc` is a submodule; it is only guaranteed to be reachable as
# an attribute of `collections` once it has been imported explicitly.
import collections.abc
import re

from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers. Subclasses must implement
  `_target_shape`, which is used both by `compute_output_shape` and by
  `_process_dense_tensor`.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable: Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the layer.
    partitioner: Optional partitioner passed to the enclosing
      `variable_scope` when the columns create their state in `build`.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    # Validated, de-duplicated and sorted by column name.
    self._feature_columns = _normalize_feature_columns(
        feature_columns)
    self._state_manager = feature_column_v2._StateManagerImpl(  # pylint: disable=protected-access
        self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    """Creates the state of every column under `<layer name>/<column name>`."""
    for column in self._feature_columns:
      with variable_scope.variable_scope(
          self.name, partitioner=self._partitioner):
        # Column names are user supplied and may contain characters that
        # are not valid in a variable scope name.
        with variable_scope.variable_scope(
            _sanitize_column_name_for_variable_scope(column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    Abstract: must be overridden by subclasses.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('Calling an abstract method.')

  def _output_shape(self, input_shape, num_elements):
    """Backward-compatible alias that delegates to `_target_shape`."""
    return self._target_shape(input_shape, num_elements)

  def compute_output_shape(self, input_shape):
    """Returns the output shape for the concatenation of all columns.

    Args:
      input_shape: Shape information forwarded to `_target_shape`.

    Returns:
      The value of `_target_shape` with the last dimension equal to the total
      number of elements across all columns' `variable_shape`.
    """
    total_elements = sum(
        column.variable_shape.num_elements() for column in self._feature_columns)
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns.

    Args:
      output_tensors: Dense tensors, one per entry of `self._feature_columns`
        and in the same order (used only for error messages).

    Returns:
      The tensors concatenated along the last axis.

    Raises:
      ValueError: if the static batch sizes of the tensors are incompatible.
    """
    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    """Returns the layer config including serialized columns and partitioner."""
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = [serialization.serialize_feature_column(fc)
                      for fc in self._feature_columns]
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)

    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the layer from a config produced by `get_config`.

    A config without a `'partitioner'` entry is treated as having no
    partitioner instead of raising `KeyError`.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    # Shared across calls so columns referenced by several configs are
    # deserialized once.
    columns_by_name = {}
    config_cp['feature_columns'] = [serialization.deserialize_feature_column(
        c, custom_objects, columns_by_name) for c in config['feature_columns']]
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config.get('partitioner'), custom_objects)

    return cls(**config_cp)


# Any character outside this set is replaced when building variable scopes.
_INVALID_SCOPE_CHAR_RE = re.compile(r'[^A-Za-z0-9_.\-]')


def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes.

  Args:
    name: The feature column name.

  Returns:
    `name` with every character other than ASCII letters, digits, `_`, `.`
    and `-` replaced by `_`.
  """
  return _INVALID_SCOPE_CHAR_RE.sub('_', name)


def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Tensors whose first dimension is statically unknown are skipped.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  expected_batch_size = None
  expected_batch_size_column_index = None
  for i, tensor in enumerate(tensors):
    # batch_size is a Dimension object.
    batch_size = tensor_shape.Dimension(
        tensor_shape.dimension_value(tensor.shape[0]))
    if batch_size.value is None:
      continue
    if expected_batch_size is None:
      # The first statically known batch size becomes the reference.
      expected_batch_size_column_index = i
      expected_batch_size = batch_size
    elif not expected_batch_size.is_compatible_with(batch_size):
      raise ValueError(
          'Batch size (first dimension) of each feature must be same. '
          'Batch size of columns ({}, {}): ({}, {})'.format(
              columns[expected_batch_size_column_index].name, columns[i].name,
              expected_batch_size, batch_size))


def _normalize_feature_columns(feature_columns):
  """Normalizes the `feature_columns` input.

  This method converts the `feature_columns` to list type as best as it can. In
  addition, verifies the type and other parts of feature_columns, required by
  downstream library.

  Args:
    feature_columns: The raw feature columns, usually passed by users.

  Returns:
    The normalized feature column list, sorted by column name.

  Raises:
    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
  """
  if isinstance(feature_columns, feature_column_v2.FeatureColumn):
    feature_columns = [feature_columns]

  # Materialize one-shot iterators (e.g. generators) so they can be traversed
  # more than once below.
  if isinstance(feature_columns, collections.abc.Iterator):
    feature_columns = list(feature_columns)

  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  for column in feature_columns:
    if not isinstance(column, feature_column_v2.FeatureColumn):
      raise ValueError('Items of feature_columns must be a FeatureColumn. '
                       'Given (type {}): {}.'.format(type(column), column))
  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')
  name_to_column = {}
  for column in feature_columns:
    if column.name in name_to_column:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer to '
                       'same base feature. Either one must be discarded or a '
                       'duplicated but renamed item must be inserted in '
                       'features dict.'.format(column,
                                               name_to_column[column.name]))
    name_to_column[column.name] = column

  return sorted(feature_columns, key=lambda x: x.name)