# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes that list&validate all attributes to serialize to SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.training.tracking import tracking


@six.add_metaclass(abc.ABCMeta)
class SavedModelSaver(object):
  """Saver defining the methods and properties used to serialize Keras objects.

  Subclasses must implement `object_identifier`, `python_properties`,
  `objects_to_serialize` and `functions_to_serialize`.

  Attributes:
    obj: The object passed to the constructor; it is stored unchanged and is
      the object whose user-defined functions are collected in
      `list_functions_for_serialization`.
  """

  def __init__(self, obj):
    """Stores the object that this saver serializes.

    Args:
      obj: The object to serialize.
    """
    self.obj = obj

  @abc.abstractproperty
  def object_identifier(self):
    """String stored in object identifier field in the SavedModel proto.

    Returns:
      A string with the object identifier, which is used at load time.
    """
    raise NotImplementedError

  @property
  def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      A serialized JSON storing information necessary for recreating this layer.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    return json_utils.Encoder().encode(self.python_properties)

  def list_extra_dependencies_for_serialization(self, serialization_cache):
    """Lists extra dependencies to serialize to SavedModel.

    By overriding this method, extra dependencies can be attached to the
    serialized Layer. For example, this is used to save the list of `variables`
    and `trainable_variables`, which are python properties in a Layer object,
    but are represented as a static list in the SavedModel.

    Args:
      serialization_cache: A dictionary shared between all objects in the same
        object graph. This object is passed to both
        `_list_extra_dependencies_for_serialization` and
        `_list_functions_for_serialization`.

    Returns:
      A dictionary mapping attribute names to trackable objects. The entire list
      of attributes is listed in the `saved_model._LayerAttributes` class. An
      empty dictionary is returned when `utils.should_save_traces()` is false.
    """
    if not utils.should_save_traces():
      return {}

    return self.objects_to_serialize(serialization_cache)

  def list_functions_for_serialization(self, serialization_cache):
    """Lists extra functions to serialize to the SavedModel.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A new dictionary mapping attribute names to `Function` or
      `ConcreteFunction`. An empty dictionary is returned when
      `utils.should_save_traces()` is false.
    """
    if not utils.should_save_traces():
      return {}

    # Copy the mapping returned by the subclass: implementations are encouraged
    # to reuse cached results, and merging into that object in place would
    # leak the AutoTrackable functions into the cached dictionary.
    fns = dict(self.functions_to_serialize(serialization_cache))

    # The parent AutoTrackable class saves all user-defined tf.functions, and
    # returns them in _list_functions_for_serialization(). Add these functions
    # to the dict. On a name collision the AutoTrackable entry replaces the one
    # from `functions_to_serialize`.
    fns.update(
        tracking.AutoTrackable._list_functions_for_serialization(  # pylint:disable=protected-access
            self.obj, serialization_cache))
    return fns

  @abc.abstractproperty
  def python_properties(self):
    """Returns dictionary of python properties to save in the metadata.

    This dictionary must be serializable and deserializable to/from JSON.

    When loading, the items in this dict are used to initialize the object and
    define attributes in the revived object.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def objects_to_serialize(self, serialization_cache):
    """Returns dictionary of extra checkpointable objects to serialize.

    See `functions_to_serialize` for an explanation of this function's
    effects.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to checkpointable objects.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def functions_to_serialize(self, serialization_cache):
    """Returns extra functions to include when serializing a Keras object.

    Normally, when exporting an object to SavedModel, only the functions and
    objects defined by the user are saved. For example:

    ```
    obj = tf.Module()
    obj.v = tf.Variable(1.)

    @tf.function
    def foo(...): ...

    obj.foo = foo

    w = tf.Variable(1.)

    tf.saved_model.save(obj, 'path/to/saved/model')
    loaded = tf.saved_model.load('path/to/saved/model')

    loaded.v  # Variable with the same value as obj.v
    loaded.foo  # Equivalent to obj.foo
    loaded.w  # AttributeError
    ```

    Assigning trackable objects to attributes creates a graph, which is used for
    both checkpointing and SavedModel serialization.

    When the graph generated from attribute tracking is insufficient, extra
    objects and functions may be added at serialization time. For example,
    most models do not have their call function wrapped with a @tf.function
    decorator. This results in `model.call` not being saved. Since Keras objects
    should be revivable from the SavedModel format, the call function is added
    as an extra function to serialize.

    This function and `objects_to_serialize` are called multiple times when
    exporting to SavedModel. Please use the cache to avoid generating new
    functions and objects. A fresh cache is created for each SavedModel export.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to `Function` or
      `ConcreteFunction`.
    """
    raise NotImplementedError