• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utils for creating and loading the Layer metadata for SavedModel.

These are required to retain the original format of the build input shape, since
layers and models may have different build behaviors depending on whether the
shape is a list, tuple, or TensorShape. For example, Network.build() will create
separate inputs if the given input_shape is a list, and will create a single
input if the given shape is a tuple. The JSON helpers below therefore tag tuples,
TensorShapes, TypeSpecs and Ellipsis so that `decode` can restore them.
"""
23
24import collections
25import enum
26import json
27import numpy as np
28import wrapt
29
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import type_spec
33
34
class Encoder(json.JSONEncoder):
  """JSON encoder that handles TensorShapes and tuples.

  Before encoding, every tuple in the input is rewritten into a tagged
  `{'class_name': '__tuple__', 'items': ...}` dict (see `_encode_tuple`) so
  that `decode` can turn it back into a tuple instead of a list. Objects the
  base encoder cannot serialize natively are handled by `default`.
  """

  def default(self, obj):  # pylint: disable=method-hidden
    """Encodes objects for types that aren't handled by the default encoder.

    Args:
      obj: An object `json.JSONEncoder` cannot serialize on its own.

    Returns:
      A JSON-serializable stand-in for `obj`. A `TensorShape` becomes a
      `{'class_name': 'TensorShape', 'items': ...}` dict; everything else is
      delegated to `get_json_type`.

    Raises:
      TypeError, ValueError: propagated from `get_json_type` when `obj` cannot
        be serialized.
    """
    if isinstance(obj, tensor_shape.TensorShape):
      # Unknown-rank shapes are recorded with items=None.
      items = obj.as_list() if obj.rank is not None else None
      return {'class_name': 'TensorShape', 'items': items}
    return get_json_type(obj)

  def iterencode(self, o, _one_shot=False):
    """Encodes `o` chunk by chunk after tagging its tuples.

    Overriding `iterencode` (instead of `encode`) makes the tuple tagging apply
    to `json.dump(..., cls=Encoder)` too: `json.dump` calls `iterencode`
    directly, while the inherited `JSONEncoder.encode` (used by `json.dumps`)
    delegates here for every non-string input. Tagging must happen exactly
    once, which is why `encode` is intentionally not overridden as well.
    """
    return super().iterencode(_encode_tuple(o), _one_shot)
47
48
49def _encode_tuple(x):
50  if isinstance(x, tuple):
51    return {'class_name': '__tuple__',
52            'items': tuple(_encode_tuple(i) for i in x)}
53  elif isinstance(x, list):
54    return [_encode_tuple(i) for i in x]
55  elif isinstance(x, dict):
56    return {key: _encode_tuple(value) for key, value in x.items()}
57  else:
58    return x
59
60
61def decode(json_string):
62  return json.loads(json_string, object_hook=_decode_helper)
63
64
65def _decode_helper(obj):
66  """A decoding helper that is TF-object aware."""
67  if isinstance(obj, dict) and 'class_name' in obj:
68    if obj['class_name'] == 'TensorShape':
69      return tensor_shape.TensorShape(obj['items'])
70    elif obj['class_name'] == 'TypeSpec':
71      return type_spec.lookup(obj['type_spec'])._deserialize(  # pylint: disable=protected-access
72          _decode_helper(obj['serialized']))
73    elif obj['class_name'] == '__tuple__':
74      return tuple(_decode_helper(i) for i in obj['items'])
75    elif obj['class_name'] == '__ellipsis__':
76      return Ellipsis
77  return obj
78
79
def get_json_type(obj):
  """Serializes any object to a JSON-serializable structure.

  Args:
      obj: the object to serialize

  Returns:
      JSON-serializable structure representing `obj`. A `TensorShape` of
      unknown rank is represented as `None`.

  Raises:
      TypeError: if `obj` cannot be serialized.
      ValueError: if `obj` is a `TypeSpec` whose class has not been registered.
  """
  # if obj is a serializable Keras class instance
  # e.g. optimizer, layer
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # numpy arrays (including subclasses) become nested Python lists, and
  # numpy scalars become the equivalent Python scalar.
  if isinstance(obj, np.ndarray):
    return obj.tolist()
  if isinstance(obj, np.generic):
    return obj.item()

  # Functions (e.g. a loss function) and classes are stored by name. Callables
  # without a `__name__` (e.g. functools.partial) fall through to the checks
  # below instead of raising AttributeError here.
  if callable(obj) and hasattr(obj, '__name__'):
    return obj.__name__

  if isinstance(obj, tensor_shape.Dimension):
    return obj.value

  if isinstance(obj, tensor_shape.TensorShape):
    # Same convention as Encoder.default: unknown rank -> None.
    return obj.as_list() if obj.rank is not None else None

  if isinstance(obj, dtypes.DType):
    return obj.name

  if isinstance(obj, collections.abc.Mapping):
    return dict(obj)

  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}

  if isinstance(obj, wrapt.ObjectProxy):
    return obj.__wrapped__

  if isinstance(obj, type_spec.TypeSpec):
    # Only the registry lookup can mean "not registered"; errors raised by
    # _serialize() itself must not be reported with that message.
    try:
      type_spec_name = type_spec.get_name(type(obj))
    except ValueError as e:
      raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
                       'class {} has not been registered.'
                       .format(obj, type(obj))) from e
    return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
            'serialized': obj._serialize()}  # pylint: disable=protected-access

  if isinstance(obj, enum.Enum):
    return obj.value

  raise TypeError('Not JSON Serializable: {!r}'.format(obj))
144