• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FeatureColumn serialization, deserialization logic."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.python.feature_column import feature_column_v2 as fc_lib
24from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
25from tensorflow.python.ops import init_ops
26from tensorflow.python.util import tf_decorator
27from tensorflow.python.util import tf_inspect
28from tensorflow.python.util.tf_export import tf_export
29
30
# Classes that deserialize_feature_column() can resolve by class name without
# the caller supplying them through `custom_objects`.
# NOTE(review): init_ops.TruncatedNormal comes from init_ops rather than a
# feature column module -- presumably it is listed so that an initializer
# nested inside a column config can be resolved by name; confirm.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn,
    fc_lib.CrossedColumn,
    fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn,
    fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn,
    fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn,
    fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn,
    fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn,
    fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal,
    sfc_lib.SequenceNumericColumn,
]
40
41
@tf_export('__internal__.feature_column.serialize_feature_column', v1=[])
def serialize_feature_column(fc):
  """Converts one FeatureColumn (or raw string key) to its serialized form.

  Intended for serializing parent FeatureColumns from inside an
  implementation of FeatureColumn.get_config(); in other situations
  serialize_feature_columns() is the preferred entry point.

  The result records the class name of the column next to its config, which
  makes deserialization possible without knowing the class up front. For
  example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    A dict with 'class_name' and 'config' entries when `fc` is a
    FeatureColumn; a string key is returned unchanged.

  Raises:
    ValueError: if `fc` is neither a string nor a FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    # Raw feature keys are passed through untouched.
    return fc
  if not isinstance(fc, fc_lib.FeatureColumn):
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))

  serialized = {'class_name': fc.__class__.__name__}
  serialized['config'] = fc.get_config()
  return serialized
89
90
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. Newly created columns are added to it, keyed by
      class name and column name. A fresh dict is used when None.

  Raises:
    ValueError: if `config` has invalid format (e.g: expected keys missing,
      or refers to unknown classes), or if the object it resolves to is not a
      FeatureColumn class.

  Returns:
    A FeatureColumn corresponding to the input `config`. If `config` is a
    string it is returned unchanged.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      column_cls.__name__: column_cls for column_cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = _class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  # `custom_objects` may legitimately map names to plain functions, and
  # issubclass() raises TypeError when its first argument is not a class.
  # Checking isclass() first keeps the failure a ValueError, as documented.
  if not (tf_inspect.isclass(cls) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)
147
148
def serialize_feature_columns(feature_columns):
  """Serializes every FeatureColumn in `feature_columns`.

  Each element is passed through `serialize_feature_column`; the resulting
  list can be handed to `deserialize_feature_columns` to rebuild the
  original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    A list with the Keras-style serialization of each input column, in the
    same order as the input.

  Raises:
    ValueError: if an element is rejected by `serialize_feature_column`,
      i.e. it is neither a string key nor a FeatureColumn.
  """
  return list(map(serialize_feature_column, feature_columns))
166
167
def deserialize_feature_columns(configs, custom_objects=None):
  """Rebuilds FeatureColumns from a list of serialized configs.

  This is the inverse of `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects, one per entry of `configs` and in the
    same order.

  Raises:
    ValueError: if a config is rejected by `deserialize_feature_column`
      (e.g. improper format or unknown class).
  """
  # A single registry is shared across the whole list so that configs which
  # resolve to the same class and name yield the same column instance.
  shared_columns = {}
  return [
      deserialize_feature_column(
          column_config,
          custom_objects=custom_objects,
          columns_by_name=shared_columns)
      for column_config in configs
  ]
191
192
193def _column_name_with_class_name(fc):
194  """Returns a unique name for the feature column used during deduping.
195
196  Without this two FeatureColumns that have the same name and where
197  one wraps the other, such as an IndicatorColumn wrapping a
198  SequenceCategoricalColumn, will fail to deserialize because they will have the
199  same name in columns_by_name, causing the wrong column to be returned.
200
201  Args:
202    fc: A FeatureColumn.
203
204  Returns:
205    A unique name as a string.
206  """
207  return fc.__class__.__name__ + ':' + fc.name
208
209
def _serialize_keras_object(instance):
  """Serializes a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. Any TF decorator wrapping is removed
      with `tf_decorator.unwrap` before inspection.

  Returns:
    None if the unwrapped object is None; a dict with 'class_name' and
    'config' entries if the object exposes `get_config`; otherwise the
    object's `__name__`.

  Raises:
    ValueError: if the object has neither `get_config` nor `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    # No config available: fall back to the bare name (e.g. of a function).
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  serialized_config = {}
  for field, value in target.get_config().items():
    if isinstance(value, six.string_types):
      # Strings are stored verbatim.
      serialized_config[field] = value
      continue

    # Any object of a different type needs to be converted to string or dict
    # for serialization (e.g. custom functions, custom classes).
    try:
      serialized_value = _serialize_keras_object(value)
    except ValueError:
      # Values that cannot be serialized recursively are kept as they are.
      serialized_value = value
    else:
      # Mark dicts produced by this function (as opposed to values that were
      # already dicts) so they can be recognised during deserialization.
      if isinstance(serialized_value, dict) and not isinstance(value, dict):
        serialized_value['__passive_serialization__'] = True
    serialized_config[field] = serialized_value

  return {'class_name': class_name, 'config': serialized_config}
239
240
241def _deserialize_keras_object(identifier,
242                              module_objects=None,
243                              custom_objects=None,
244                              printable_module_name='object'):
245  """Turns the serialized form of a Keras object back into an actual object."""
246  if identifier is None:
247    return None
248
249  if isinstance(identifier, dict):
250    # In this case we are dealing with a Keras config dictionary.
251    config = identifier
252    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
253        config, module_objects, custom_objects, printable_module_name)
254
255    if hasattr(cls, 'from_config'):
256      arg_spec = tf_inspect.getfullargspec(cls.from_config)
257      custom_objects = custom_objects or {}
258
259      if 'custom_objects' in arg_spec.args:
260        return cls.from_config(
261            cls_config,
262            custom_objects=dict(
263                list(custom_objects.items())))
264      return cls.from_config(cls_config)
265    else:
266      # Then `cls` may be a function returning a class.
267      # in this case by convention `config` holds
268      # the kwargs of the function.
269      custom_objects = custom_objects or {}
270      return cls(**cls_config)
271  elif isinstance(identifier, six.string_types):
272    object_name = identifier
273    if custom_objects and object_name in custom_objects:
274      obj = custom_objects.get(object_name)
275    else:
276      obj = module_objects.get(object_name)
277      if obj is None:
278        raise ValueError(
279            'Unknown ' + printable_module_name + ': ' + object_name)
280    # Classes passed by name are instantiated with no args, functions are
281    # returned as-is.
282    if tf_inspect.isclass(obj):
283      return obj()
284    return obj
285  elif tf_inspect.isfunction(identifier):
286    # If a function has already been deserialized, return as is.
287    return identifier
288  else:
289    raise ValueError('Could not interpret serialized %s: %s' %
290                     (printable_module_name, identifier))
291
292
293def _class_and_config_for_serialized_keras_object(
294    config,
295    module_objects=None,
296    custom_objects=None,
297    printable_module_name='object'):
298  """Returns the class name and config for a serialized keras object."""
299  if (not isinstance(config, dict) or 'class_name' not in config or
300      'config' not in config):
301    raise ValueError('Improper config format: ' + str(config))
302
303  class_name = config['class_name']
304  cls = _get_registered_object(class_name, custom_objects=custom_objects,
305                               module_objects=module_objects)
306  if cls is None:
307    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
308
309  cls_config = config['config']
310
311  deserialized_objects = {}
312  for key, item in cls_config.items():
313    if isinstance(item, dict) and '__passive_serialization__' in item:
314      deserialized_objects[key] = _deserialize_keras_object(
315          item,
316          module_objects=module_objects,
317          custom_objects=custom_objects,
318          printable_module_name='config_item')
319    elif (isinstance(item, six.string_types) and
320          tf_inspect.isfunction(_get_registered_object(item, custom_objects))):
321      # Handle custom functions here. When saving functions, we only save the
322      # function's name as a string. If we find a matching string in the custom
323      # objects during deserialization, we convert the string back to the
324      # original function.
325      # Note that a potential issue is that a string field could have a naming
326      # conflict with a custom function name, but this should be a rare case.
327      # This issue does not occur if a string field has a naming conflict with
328      # a custom object, since the config of an object will always be a dict.
329      deserialized_objects[key] = _get_registered_object(item, custom_objects)
330  for key, item in deserialized_objects.items():
331    cls_config[key] = deserialized_objects[key]
332
333  return (cls, cls_config)
334
335
336def _get_registered_object(name, custom_objects=None, module_objects=None):
337  if custom_objects and name in custom_objects:
338    return custom_objects[name]
339  elif module_objects and name in module_objects:
340    return module_objects[name]
341  return None
342
343