# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for creating and loading the Layer metadata for SavedModel.

These are required to retain the original format of the build input shape, since
layers and models may have different build behaviors depending on whether the
shape is a list, tuple, or TensorShape. For example, Network.build() will create
separate inputs if the given input_shape is a list, and will create a single
input if the given shape is a tuple.
"""

import collections
import enum
import json
import numpy as np
import wrapt

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec


class Encoder(json.JSONEncoder):
  """JSON encoder that handles TensorShapes and tuples.

  Tuples in the object being encoded (including those nested in lists and dict
  values) are rewritten as `{'class_name': '__tuple__', 'items': [...]}` so that
  `decode` can restore them as tuples rather than lists. The rewrite is done in
  `iterencode`, which both `json.dumps` (via `encode`) and `json.dump` use, so
  the two entry points produce the same output.

  NOTE: values produced by `default` (e.g. a `get_config()` dict returned by
  `get_json_type`) are not passed through the tuple rewrite, so tuples inside
  them are encoded as plain JSON arrays.
  """

  def default(self, obj):  # pylint: disable=method-hidden
    """Encodes objects for types that aren't handled by the default encoder."""
    if isinstance(obj, tensor_shape.TensorShape):
      # Unknown-rank shapes have no list form; record them as `items: None`,
      # which `_decode_helper` turns back into an unknown-rank TensorShape.
      items = obj.as_list() if obj.rank is not None else None
      return {'class_name': 'TensorShape', 'items': items}
    return get_json_type(obj)

  def iterencode(self, obj, _one_shot=False):
    """Encodes `obj` after tagging its tuples; see the class docstring."""
    # `JSONEncoder.encode` delegates here, so the rewrite runs exactly once for
    # both `json.dumps` and `json.dump`.
    return super(Encoder, self).iterencode(_encode_tuple(obj), _one_shot)


def _encode_tuple(x):
  """Recursively replaces tuples with `__tuple__`-tagged dicts.

  Lists and dict values are traversed; any other value is returned unchanged.
  """
  if isinstance(x, tuple):
    return {'class_name': '__tuple__',
            'items': tuple(_encode_tuple(i) for i in x)}
  elif isinstance(x, list):
    return [_encode_tuple(i) for i in x]
  elif isinstance(x, dict):
    return {key: _encode_tuple(value) for key, value in x.items()}
  else:
    return x


def decode(json_string):
  """Parses `json_string`, restoring the tagged objects written by `Encoder`."""
  return json.loads(json_string, object_hook=_decode_helper)


def _decode_helper(obj):
  """A decoding helper that is TF-object aware.

  Converts dicts tagged with a known `class_name` ('TensorShape', 'TypeSpec',
  '__tuple__', '__ellipsis__') back into the corresponding objects; any other
  value is returned unchanged.
  """
  if isinstance(obj, dict) and 'class_name' in obj:
    if obj['class_name'] == 'TensorShape':
      return tensor_shape.TensorShape(obj['items'])
    elif obj['class_name'] == 'TypeSpec':
      return type_spec.lookup(obj['type_spec'])._deserialize(  # pylint: disable=protected-access
          _decode_helper(obj['serialized']))
    elif obj['class_name'] == '__tuple__':
      return tuple(_decode_helper(i) for i in obj['items'])
    elif obj['class_name'] == '__ellipsis__':
      return Ellipsis
  return obj


def get_json_type(obj):
  """Serializes any object to a JSON-serializable structure.

  Used by `Encoder.default` for objects the standard JSON encoder cannot handle.

  Args:
    obj: the object to serialize

  Returns:
    JSON-serializable structure representing `obj`.

  Raises:
    TypeError: if `obj` cannot be serialized.
    ValueError: if `obj` is a `TypeSpec` whose class has not been registered.
  """
  # if obj is a serializable Keras class instance
  # e.g. optimizer, layer
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # if obj is a numpy array or numpy scalar. isinstance checks (rather than
  # comparing `type(obj).__module__` with 'numpy') also cover ndarray subclasses
  # defined in numpy submodules, and skip numpy objects such as ufuncs or dtypes
  # that have no `item()` method.
  if isinstance(obj, np.ndarray):
    return obj.tolist()
  if isinstance(obj, np.generic):
    return obj.item()

  # misc functions (e.g. loss function) and python classes (which are callable
  # too). Callables without a `__name__` (e.g. `functools.partial`) fall through
  # to the checks below instead of raising AttributeError.
  if callable(obj) and hasattr(obj, '__name__'):
    return obj.__name__

  if isinstance(obj, tensor_shape.Dimension):
    return obj.value

  if isinstance(obj, tensor_shape.TensorShape):
    # `as_list()` raises for unknown-rank shapes; emit None for them, matching
    # `Encoder.default`.
    return obj.as_list() if obj.rank is not None else None

  if isinstance(obj, dtypes.DType):
    return obj.name

  if isinstance(obj, collections.abc.Mapping):
    return dict(obj)

  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}

  if isinstance(obj, wrapt.ObjectProxy):
    # Return the proxied object itself; when called via `Encoder.default`, the
    # JSON encoder then serializes it (invoking `default` again if needed).
    return obj.__wrapped__

  if isinstance(obj, type_spec.TypeSpec):
    # Only the registry lookup signals an unregistered class; keep
    # `_serialize()` outside the try so its own errors are not misreported.
    try:
      type_spec_name = type_spec.get_name(type(obj))
    except ValueError as e:
      raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
                       'class {} has not been registered.'
                       .format(obj, type(obj))) from e
    return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
            'serialized': obj._serialize()}  # pylint: disable=protected-access

  if isinstance(obj, enum.Enum):
    return obj.value

  raise TypeError('Not JSON Serializable: {!r} (of type {})'.format(
      obj, type(obj)))