• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FeatureColumn serialization, deserialization logic."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.python.feature_column import feature_column_v2 as fc_lib
24from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
25from tensorflow.python.ops import init_ops
26from tensorflow.python.util.lazy_loader import LazyLoader
27
# Keras' generic_utils is reached through a LazyLoader rather than a plain
# import to prevent circular dependencies with Keras serialization.
generic_utils = LazyLoader(
    'generic_utils',
    globals(),
    'tensorflow.python.keras.utils.generic_utils',
)

# Classes that `deserialize_feature_column` can resolve by class name without
# the caller having to supply them through `custom_objects`.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn,
    fc_lib.CrossedColumn,
    fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn,
    fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn,
    fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn,
    fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn,
    fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn,
    fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal,
    sfc_lib.SequenceNumericColumn,
]
42
43
def serialize_feature_column(fc):
  """Serializes a single FeatureColumn, or passes a raw string key through.

  Intended for serializing parent FeatureColumns from within an
  implementation of FeatureColumn.get_config(); for a list of columns prefer
  serialize_feature_columns().

  The result records the FeatureColumn class name next to its config, so the
  column can later be deserialized without knowing its class up front.
  For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    The Keras serialization of the FeatureColumn; a string key is returned
    unchanged.

  Raises:
    ValueError: if `fc` is neither a string nor a FeatureColumn.
  """
  # Raw feature keys are already in their serialized form.
  if isinstance(fc, six.string_types):
    return fc

  if not isinstance(fc, fc_lib.FeatureColumn):
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))

  # Pair the class name with the column's own config.
  serializer = generic_utils.serialize_keras_class_and_config
  return serializer(fc.__class__.__name__, fc.get_config())
91
92
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.
  TODO(b/118939620): Simplify code if Keras utils support object deduping.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. Newly deserialized columns are added to it, keyed
      by class name and column name. A fresh dict is used when None.

  Raises:
    ValueError: if `config` has invalid format (e.g: expected keys missing,
      or refers to unknown classes), or if the object it resolves to is not a
      FeatureColumn class.

  Returns:
    A FeatureColumn corresponding to the input `config`, or `config` itself
    when it is a raw string key.
  """
  # Raw feature keys are serialized as plain strings; hand them back as-is.
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  # `custom_objects` may also contain functions. issubclass() raises TypeError
  # for a non-class first argument, so check that first and report the
  # documented ValueError instead.
  if not (isinstance(cls, type) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(  # pylint: disable=protected-access
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused). Otherwise new_instance is registered under
  # that name and returned.
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)
148
149
def serialize_feature_columns(feature_columns):
  """Serializes every FeatureColumn in `feature_columns`.

  Each element is converted with `serialize_feature_column`, in order. The
  resulting Keras-style configs can be passed to
  `deserialize_feature_columns` to reconstruct the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    A list with the Keras serialization of each input FeatureColumn.

  Raises:
    ValueError: if an element is neither a FeatureColumn nor a string key.
  """
  serialized = []
  for column in feature_columns:
    serialized.append(serialize_feature_column(column))
  return serialized
167
168
def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumn configs.

  Each config is converted with `deserialize_feature_column`, in order. One
  `columns_by_name` dict is shared across the whole list, so a column that was
  already deserialized under the same class and name is re-used rather than
  duplicated.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError: if a config does not resolve to a FeatureColumn class.
  """
  # Shared across all configs so that repeated columns are deduplicated.
  seen_columns = {}
  deserialized = []
  for column_config in configs:
    deserialized.append(
        deserialize_feature_column(column_config, custom_objects,
                                   seen_columns))
  return deserialized
192
193
def _column_name_with_class_name(fc):
  """Builds the key under which a feature column is tracked during deduping.

  The key is the column's class name and its `name`, joined by a colon.
  Without the class name, two FeatureColumns that share a name because one
  wraps the other (such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn) would collide in columns_by_name, causing the
  wrong column to be returned and deserialization to fail.

  Args:
    fc: A FeatureColumn.

  Returns:
    A string of the form '<ClassName>:<column name>'.
  """
  return ':'.join((fc.__class__.__name__, fc.name))
209