• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes that list and validate all attributes to serialize to SavedModel."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23from tensorflow.python.keras.saving.saved_model import json_utils
24from tensorflow.python.keras.saving.saved_model import utils
25from tensorflow.python.training.tracking import tracking
26
27
class SavedModelSaver(object, metaclass=abc.ABCMeta):
  """Saver defining the methods and properties used to serialize Keras objects.

  A saver wraps a single object (available as `self.obj`). Subclasses must
  implement `object_identifier`, `python_properties`, `objects_to_serialize`
  and `functions_to_serialize`; the remaining methods are built on top of them.
  """

  def __init__(self, obj):
    """Creates a saver for `obj`.

    Args:
      obj: The object to serialize. Stored as `self.obj`.
    """
    self.obj = obj

  @property
  @abc.abstractmethod
  def object_identifier(self):
    """String stored in object identifier field in the SavedModel proto.

    Returns:
      A string with the object identifier, which is used at load time.
    """
    raise NotImplementedError

  @property
  def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      A JSON string, produced by encoding `python_properties` with
      `json_utils.Encoder`, storing information necessary for recreating this
      object.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    return json_utils.Encoder().encode(self.python_properties)

  def list_extra_dependencies_for_serialization(self, serialization_cache):
    """Lists extra dependencies to serialize to SavedModel.

    By overriding this method, extra dependencies can be attached to the
    serialized Layer. For example, this is used to save the list of `variables`
    and `trainable_variables`, which are python properties in a Layer object,
    but are represented as a static list in the SavedModel.

    Args:
      serialization_cache: A dictionary shared between all objects in the same
        object graph. This object is passed to both
        `_list_extra_dependencies_for_serialization` and
        `_list_functions_for_serialization`.

    Returns:
      A dictionary mapping attribute names to trackable objects, as returned by
      `objects_to_serialize`. Returns an empty dict when
      `utils.should_save_traces()` is False. The entire list of attributes is
      listed in the `saved_model._LayerAttributes` class.
    """
    if not utils.should_save_traces():
      return {}

    return self.objects_to_serialize(serialization_cache)

  def list_functions_for_serialization(self, serialization_cache):
    """Lists extra functions to serialize to the SavedModel.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A new dictionary mapping attribute names to `Function` or
      `ConcreteFunction`: the entries from `functions_to_serialize` merged with
      the functions found by `AutoTrackable`. Returns an empty dict when
      `utils.should_save_traces()` is False.
    """
    if not utils.should_save_traces():
      return {}

    # Work on a shallow copy: subclasses are encouraged to reuse objects kept
    # in `serialization_cache`, so the dict they return must not be mutated
    # by the update below.
    fns = dict(self.functions_to_serialize(serialization_cache))

    # The parent AutoTrackable class saves all user-defined tf.functions, and
    # returns them in _list_functions_for_serialization(). Add these functions
    # to the dict (on a name clash, the AutoTrackable entry wins).
    fns.update(
        tracking.AutoTrackable._list_functions_for_serialization(  # pylint:disable=protected-access
            self.obj, serialization_cache))
    return fns

  @property
  @abc.abstractmethod
  def python_properties(self):
    """Returns dictionary of python properties to save in the metadata.

    This dictionary must be serializable and deserializable to/from JSON.

    When loading, the items in this dict are used to initialize the object and
    define attributes in the revived object.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def objects_to_serialize(self, serialization_cache):
    """Returns dictionary of extra trackable objects to serialize.

    See `functions_to_serialize` for an explanation of this function's
    effects.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
        A dictionary mapping attribute names to trackable objects.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def functions_to_serialize(self, serialization_cache):
    """Returns extra functions to include when serializing a Keras object.

    Normally, when exporting an object to SavedModel, only the functions and
    objects defined by the user are saved. For example:

    ```
    obj = tf.Module()
    obj.v = tf.Variable(1.)

    @tf.function
    def foo(...): ...

    obj.foo = foo

    w = tf.Variable(1.)

    tf.saved_model.save(obj, 'path/to/saved/model')
    loaded = tf.saved_model.load('path/to/saved/model')

    loaded.v  # Variable with the same value as obj.v
    loaded.foo  # Equivalent to obj.foo
    loaded.w  # AttributeError
    ```

    Assigning trackable objects to attributes creates a graph, which is used for
    both checkpointing and SavedModel serialization.

    When the graph generated from attribute tracking is insufficient, extra
    objects and functions may be added at serialization time. For example,
    most models do not have their call function wrapped with a @tf.function
    decorator. This results in `model.call` not being saved. Since Keras objects
    should be revivable from the SavedModel format, the call function is added
    as an extra function to serialize.

    This function and `objects_to_serialize` are called multiple times when
    exporting to SavedModel. Please use the cache to avoid generating new
    functions and objects. A fresh cache is created for each SavedModel export.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
        A dictionary mapping attribute names to `Function` or
        `ConcreteFunction`.
    """
    raise NotImplementedError
178