# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


# Classes that deserialize_feature_column() can resolve by class name without
# the caller passing them in `custom_objects`. The same name->class mapping is
# forwarded as `module_objects` to nested (passively serialized) config items,
# which is why a non-FeatureColumn class (init_ops.TruncatedNormal) is listed:
# presumably so initializer configs nested inside column configs resolve.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]


def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
    'key': 'price',
    'shape': (1,),
    'default_value': None,
    'dtype': 'float32',
    'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
    'class_name': 'NumericColumn',
    'config': {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
    }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))


def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.
  TODO(b/118939620): Simplify code if Keras utils support object deduping.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. It is updated in place with the returned column.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes or to something that is not a FeatureColumn
    class).

  Returns:
    A FeatureColumn corresponding to the input `config`, or `config` itself
    when it is a raw string key.
  """
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = _class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  # `cls` may come from `custom_objects` and need not be a class at all (e.g.
  # a function). Check that first so such input raises the documented
  # ValueError rather than a TypeError from issubclass().
  if not (tf_inspect.isclass(cls) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(  # pylint: disable=protected-access
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)


def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns (raw string keys are passed
      through unchanged).

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if any element is neither a string key nor a FeatureColumn.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]


def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`. Columns sharing the same class and name are
  deduplicated: every occurrence resolves to the same FeatureColumn object.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if any element of `configs` has an invalid format or refers to
    an unknown or non-FeatureColumn class.
  """
  # Shared across all configs so identical columns are only built once.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]


def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without this two FeatureColumns that have the same name and where
  one wraps the other, such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn, will fail to deserialize because they will have the
  same name in columns_by_name, causing the wrong column to be returned.

  Args:
    fc: A FeatureColumn.

  Returns:
    A unique name as a string, of the form '<ClassName>:<column name>'.
  """
  return fc.__class__.__name__ + ':' + fc.name


def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. Decorator wrappers are stripped first
      via `tf_decorator.unwrap`.

  Returns:
    None if the (unwrapped) `instance` is None; a
    `{'class_name': ..., 'config': ...}` dict if it has `get_config()`;
    otherwise its `__name__` (e.g. for functions and classes).

  Raises:
    ValueError: if `instance` has neither `get_config` nor `__name__`.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = instance.__class__.__name__
    config = instance.get_config()
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes).
      try:
        serialized_item = _serialize_keras_object(item)
        # Mark dicts produced from non-dict objects so that
        # _class_and_config_for_serialized_keras_object() knows to turn them
        # back into objects on deserialization.
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Values that are neither Keras-serializable nor named (e.g. numbers)
        # are stored unchanged.
        serialization_config[key] = item

    return {'class_name': name, 'config': serialization_config}
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize', instance)


def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: None, a Keras config dict (`{'class_name', 'config'}`), the
      name of a registered object, or an already-deserialized function.
    module_objects: Optional dict from name to class/function, consulted after
      `custom_objects`.
    custom_objects: Optional dict from name to class/function, consulted first.
    printable_module_name: Name used in error messages.

  Returns:
    None if `identifier` is None. For a config dict, the result of
    `cls.from_config(...)` (or `cls(**config)` when `cls` has no
    `from_config`). For a name, a no-arg instance if it resolves to a class,
    otherwise the resolved object. A function identifier is returned as is.

  Raises:
    ValueError: if the identifier cannot be resolved or interpreted.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Hand over a fresh copy of the custom objects mapping.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    # Then `cls` may be a function returning a class. In this case, by
    # convention, `cls_config` holds the kwargs of the function.
    return cls(**cls_config)

  if isinstance(identifier, six.string_types):
    # _get_registered_object() tolerates `module_objects=None`; a direct
    # `module_objects.get(...)` would raise AttributeError instead of the
    # ValueError below.
    obj = _get_registered_object(identifier,
                                 custom_objects=custom_objects,
                                 module_objects=module_objects)
    if obj is None:
      raise ValueError(
          'Unknown ' + printable_module_name + ': ' + identifier)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))


def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with 'class_name' and 'config' keys.
    module_objects: Optional dict from name to class/function, consulted after
      `custom_objects` when resolving 'class_name'.
    custom_objects: Optional dict from name to class/function, consulted first.
    printable_module_name: Name used in error messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls_config` is a new dict built from
    `config['config']`. Passively serialized items are deserialized, and
    strings naming a function in `custom_objects` are replaced by that
    function. The input `config` is not modified.

  Raises:
    ValueError: if `config` is malformed or 'class_name' cannot be resolved.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(class_name, custom_objects=custom_objects,
                               module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  # Build a new dict rather than writing back into config['config'], so the
  # caller's serialized config never ends up holding live objects (which would
  # then be shared if the same config is deserialized again).
  cls_config = {}
  for key, item in config['config'].items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      item = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        item = registered
    cls_config[key] = item

  return (cls, cls_config)


def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name`, preferring `custom_objects` over `module_objects`.

  Args:
    name: The object name to resolve.
    custom_objects: Optional dict from name to object, checked first.
    module_objects: Optional dict from name to object, checked second.

  Returns:
    The registered object, or None if `name` is found in neither mapping
    (either mapping may be None or empty).
  """
  if custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None