# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# Classes that `deserialize_feature_column` resolves by `class_name` when the
# caller's `custom_objects` does not provide them.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]


@tf_export('__internal__.feature_column.serialize_feature_column', v1=[])
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
    'key': 'price',
    'shape': (1,),
    'default_value': None,
    'dtype': 'float32',
    'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
    'class_name': 'NumericColumn',
    'config': {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
    }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))


@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. When provided, it is updated in place with the
      returned column.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes).

  Returns:
    A FeatureColumn corresponding to the input `config`.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = _class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  if not issubclass(cls, fc_lib.FeatureColumn):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(  # pylint: disable=protected-access
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)


def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]


def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  # Shared across all configs so that columns referenced by several parents
  # are deserialized into a single instance.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]


def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without this two FeatureColumns that have the same name and where
  one wraps the other, such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn, will fail to deserialize because they will have the
  same name in columns_by_name, causing the wrong column to be returned.

  Args:
    fc: A FeatureColumn.

  Returns:
    A unique name as a string.
  """
  return fc.__class__.__name__ + ':' + fc.name


def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. Decorated callables are unwrapped first.

  Returns:
    None if `instance` is None; a dict `{'class_name', 'config'}` for objects
    with `get_config`; otherwise the object's `__name__`.

  Raises:
    ValueError: If `instance` has neither `get_config` nor `__name__`.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = instance.__class__.__name__
    config = instance.get_config()
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = _serialize_keras_object(item)
        # Mark nested objects that were turned into dicts so that
        # `_class_and_config_for_serialized_keras_object` knows to rebuild
        # them; genuine dict values are left unmarked.
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Best effort: keep values that cannot be serialized as-is.
        serialization_config[key] = item

    return {'class_name': name, 'config': serialization_config}
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize', instance)


def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: None, a `{'class_name', 'config'}` dict, a registered name
      string, or an already-deserialized function.
    module_objects: Optional dict from name to class/function, consulted after
      `custom_objects`.
    custom_objects: Optional dict from name to class/function, consulted first.
    printable_module_name: Name used in error messages.

  Returns:
    The deserialized object, or None if `identifier` is None.

  Raises:
    ValueError: If `identifier` cannot be interpreted or refers to an unknown
      name.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Pass a copy so `from_config` cannot mutate the caller's dict.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    # Then `cls` may be a function returning a class.
    # in this case by convention `config` holds
    # the kwargs of the function.
    return cls(**cls_config)
  elif isinstance(identifier, six.string_types):
    object_name = identifier
    # `_get_registered_object` tolerates `module_objects=None` (the default),
    # so unknown names reach the ValueError below instead of failing with an
    # AttributeError on `None.get`.
    obj = _get_registered_object(
        object_name,
        custom_objects=custom_objects,
        module_objects=module_objects)
    if obj is None:
      raise ValueError(
          'Unknown ' + printable_module_name + ': ' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))


def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with keys 'class_name' and 'config'.
    module_objects: Optional dict from name to class/function, consulted after
      `custom_objects` when resolving 'class_name'.
    custom_objects: Optional dict from name to class/function, consulted first
      when resolving 'class_name' and string-valued config entries.
    printable_module_name: Name used in error messages.

  Returns:
    A tuple `(cls, cls_config)`. `cls_config` is a shallow copy of
    `config['config']` in which passively serialized entries and names of
    functions registered in `custom_objects` are replaced by the deserialized
    objects. The input `config` is not modified.

  Raises:
    ValueError: If `config` is malformed or 'class_name' cannot be resolved.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(class_name, custom_objects=custom_objects,
                               module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  # Work on a shallow copy so deserializing does not rewrite the caller's
  # config (which would also break re-serializing or re-using it).
  try:
    cls_config = dict(config['config'])
  except (TypeError, ValueError):
    raise ValueError('Improper config format: ' + str(config))

  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered
  cls_config.update(deserialized_objects)

  return (cls, cls_config)


def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name` in `custom_objects`, then `module_objects`.

  Args:
    name: The name to resolve.
    custom_objects: Optional dict searched first.
    module_objects: Optional dict searched second.

  Returns:
    The registered object, or None if `name` is in neither dict.
  """
  if custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None