1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper classes that list&validate all attributes to serialize to SavedModel.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23from tensorflow.python.keras.saving.saved_model import json_utils 24from tensorflow.python.keras.saving.saved_model import utils 25from tensorflow.python.training.tracking import tracking 26 27 28class SavedModelSaver(object, metaclass=abc.ABCMeta): 29 """Saver defining the methods and properties used to serialize Keras objects. 30 """ 31 32 def __init__(self, obj): 33 self.obj = obj 34 35 @abc.abstractproperty 36 def object_identifier(self): 37 """String stored in object identifier field in the SavedModel proto. 38 39 Returns: 40 A string with the object identifier, which is used at load time. 41 """ 42 raise NotImplementedError 43 44 @property 45 def tracking_metadata(self): 46 """String stored in metadata field in the SavedModel proto. 47 48 Returns: 49 A serialized JSON storing information necessary for recreating this layer. 50 """ 51 # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an 52 # object is in the python property) 53 return json_utils.Encoder().encode(self.python_properties) 54 55 def list_extra_dependencies_for_serialization(self, serialization_cache): 56 """Lists extra dependencies to serialize to SavedModel. 57 58 By overriding this method, extra dependencies can be attached to the 59 serialized Layer. For example, this is used to save the list of `variables` 60 and `trainable_variables`, which are python properties in a Layer object, 61 but are represented as a static list in the SavedModel. 62 63 Args: 64 serialization_cache: A dictionary shared between all objects in the same 65 object graph. This object is passed to both 66 `_list_extra_dependencies_for_serialization` and 67 `_list_functions_for_serialization`. 68 69 Returns: 70 A dictionary mapping attribute names to trackable objects. The entire list 71 of attributes are listed in the `saved_model._LayerAttributes` class. 72 """ 73 if not utils.should_save_traces(): 74 return {} 75 76 return self.objects_to_serialize(serialization_cache) 77 78 def list_functions_for_serialization(self, serialization_cache): 79 """Lists extra functions to serialize to the SavedModel. 80 81 Args: 82 serialization_cache: Dictionary passed to all objects in the same object 83 graph during serialization. 84 85 Returns: 86 A dictionary mapping attribute names to `Function` or 87 `ConcreteFunction`. 88 """ 89 if not utils.should_save_traces(): 90 return {} 91 92 fns = self.functions_to_serialize(serialization_cache) 93 94 # The parent AutoTrackable class saves all user-defined tf.functions, and 95 # returns them in _list_functions_for_serialization(). Add these functions 96 # to the dict. 97 fns.update( 98 tracking.AutoTrackable._list_functions_for_serialization( # pylint:disable=protected-access 99 self.obj, serialization_cache)) 100 return fns 101 102 @abc.abstractproperty 103 def python_properties(self): 104 """Returns dictionary of python properties to save in the metadata. 105 106 This dictionary must be serializable and deserializable to/from JSON. 107 108 When loading, the items in this dict are used to initialize the object and 109 define attributes in the revived object. 110 """ 111 raise NotImplementedError 112 113 @abc.abstractmethod 114 def objects_to_serialize(self, serialization_cache): 115 """Returns dictionary of extra checkpointable objects to serialize. 116 117 See `functions_to_serialize` for an explanation of this function's 118 effects. 119 120 Args: 121 serialization_cache: Dictionary passed to all objects in the same object 122 graph during serialization. 123 124 Returns: 125 A dictionary mapping attribute names to checkpointable objects. 126 """ 127 raise NotImplementedError 128 129 @abc.abstractmethod 130 def functions_to_serialize(self, serialization_cache): 131 """Returns extra functions to include when serializing a Keras object. 132 133 Normally, when calling exporting an object to SavedModel, only the 134 functions and objects defined by the user are saved. For example: 135 136 ``` 137 obj = tf.Module() 138 obj.v = tf.Variable(1.) 139 140 @tf.function 141 def foo(...): ... 142 143 obj.foo = foo 144 145 w = tf.Variable(1.) 146 147 tf.saved_model.save(obj, 'path/to/saved/model') 148 loaded = tf.saved_model.load('path/to/saved/model') 149 150 loaded.v # Variable with the same value as obj.v 151 loaded.foo # Equivalent to obj.foo 152 loaded.w # AttributeError 153 ``` 154 155 Assigning trackable objects to attributes creates a graph, which is used for 156 both checkpointing and SavedModel serialization. 157 158 When the graph generated from attribute tracking is insufficient, extra 159 objects and functions may be added at serialization time. For example, 160 most models do not have their call function wrapped with a @tf.function 161 decorator. This results in `model.call` not being saved. Since Keras objects 162 should be revivable from the SavedModel format, the call function is added 163 as an extra function to serialize. 164 165 This function and `objects_to_serialize` is called multiple times when 166 exporting to SavedModel. Please use the cache to avoid generating new 167 functions and objects. A fresh cache is created for each SavedModel export. 168 169 Args: 170 serialization_cache: Dictionary passed to all objects in the same object 171 graph during serialization. 172 173 Returns: 174 A dictionary mapping attribute names to `Function` or 175 `ConcreteFunction`. 176 """ 177 raise NotImplementedError 178