# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.util.lazy_loader import LazyLoader

# Prevent circular dependencies with Keras serialization.
generic_utils = LazyLoader(
    'generic_utils', globals(),
    'tensorflow.python.keras.utils.generic_utils')

# Classes that `deserialize_feature_column` can resolve by class name without
# the caller passing them through `custom_objects`.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]


def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
    'key': 'price',
    'shape': (1,),
    'default_value': None,
    'dtype': 'float32',
    'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
    'class_name': 'NumericColumn',
    'config': {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
    }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    # Raw feature keys are serialized as themselves.
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    # Wrap the column's own config together with its class name so that the
    # class can be recovered on deserialization.
    return generic_utils.serialize_keras_class_and_config(
        fc.__class__.__name__, fc.get_config())
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))


def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.
  TODO(b/118939620): Simplify code if Keras utils support object deduping.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. It is updated in place with any newly created
      column.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes, or resolves to something that is not a
    FeatureColumn class).

  Returns:
    A FeatureColumn corresponding to the input `config`.
  """
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  # `custom_objects` may map a name to a function or other non-class object;
  # `issubclass` would raise TypeError on those, so check for a class first
  # to surface the documented ValueError instead.
  if not (isinstance(cls, type) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)


def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]


def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if any element of `configs` is not a valid FeatureColumn
    serialization.
  """
  # Shared across all configs so that columns referenced from several parents
  # (e.g. the same categorical column wrapped twice) are deserialized once.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]


def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without this two FeatureColumns that have the same name and where
  one wraps the other, such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn, will fail to deserialize because they will have the
  same name in columns_by_name, causing the wrong column to be returned.

  Args:
    fc: A FeatureColumn.

  Returns:
    A unique name as a string.
  """
  return fc.__class__.__name__ + ':' + fc.name