# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction."""

# This file was originally under tf/python/feature_column, and was moved to
# Keras package in order to remove the reverse dependency from TF to Keras.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers. Subclasses describe how the layer output
  and each column's dense tensor are shaped by overriding `_target_shape`
  (or, equivalently, `_output_shape`).

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable: Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the layer.
    partitioner: Optional partitioner installed on the variable scope in which
      the feature columns create their state. Defaults to `None`.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = feature_column_v2._normalize_feature_columns(  # pylint: disable=protected-access
        feature_columns)
    self._state_manager = feature_column_v2._StateManagerImpl(  # pylint: disable=protected-access
        self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    """Creates the state (e.g. variables) of every feature column.

    Each column's state is created inside a variable scope named after this
    layer (carrying the configured partitioner) and, nested within it, a scope
    named after the sanitized column name. The input shape is ignored.
    """
    for column in self._feature_columns:
      # The outer scope is re-entered per column (rather than hoisted out of
      # the loop) to keep the original scope-entry sequence unchanged.
      with variable_scope.variable_scope(
          self.name, partitioner=self._partitioner):
        with variable_scope.variable_scope(
            feature_column_v2._sanitize_column_name_for_variable_scope(  # pylint: disable=protected-access
                column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    This is the hook called by `compute_output_shape` and
    `_process_dense_tensor`. By default it delegates to `_output_shape`, so a
    subclass may override either method.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.

    Raises:
      NotImplementedError: if neither this method nor `_output_shape` is
        overridden by a subclass.
    """
    return self._output_shape(input_shape, num_elements)

  def _output_shape(self, input_shape, num_elements):
    """Abstract shape hook; reached through `_target_shape` unless overridden.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('Calling an abstract method.')

  def compute_output_shape(self, input_shape):
    """Returns the output shape for inputs of shape `input_shape`.

    The last dimension is the total number of elements of all columns'
    variable shapes, because their dense outputs are concatenated along the
    last axis (see `_verify_and_concat_tensors`).
    """
    total_elements = sum(
        column.variable_shape.num_elements()
        for column in self._feature_columns)
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns."""
    feature_column_v2._verify_static_batch_size_equality(  # pylint: disable=protected-access
        output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    """Returns the layer config, including serialized columns and partitioner.

    Keys from this layer override same-named keys of the base `Layer` config.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)

    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a layer from a config produced by `get_config`.

    Args:
      config: Dict as returned by `get_config`. It is not mutated.
      custom_objects: Optional dict of custom objects used while deserializing
        the feature columns and the partitioner.

    Returns:
      A new instance of `cls`.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)
    # Use .get so configs written without a 'partitioner' entry still load;
    # a missing entry is treated as "no partitioner" (None).
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config.get('partitioner'), custom_objects)

    return cls(**config_cp)