• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes that list&validate all attributes to serialize to SavedModel."""
16
17import abc
18
19from tensorflow.python.keras.saving.saved_model import json_utils
20from tensorflow.python.keras.saving.saved_model import utils
21
22
class SavedModelSaver(metaclass=abc.ABCMeta):
  """Saver defining the methods and properties used to serialize Keras objects.

  A saver wraps a single object (stored as `self.obj`) and describes what to
  write for it in a SavedModel:
  - its object identifier;
  - its JSON metadata;
  - the extra trackable objects to attach;
  - the extra functions to attach.
  """

  def __init__(self, obj):
    self.obj = obj

  # `@property` + `@abc.abstractmethod` is the supported replacement for the
  # deprecated `abc.abstractproperty`; subclasses must still override it.
  @property
  @abc.abstractmethod
  def object_identifier(self):
    """String stored in object identifier field in the SavedModel proto.

    Returns:
      A string with the object identifier, which is used at load time.
    """
    raise NotImplementedError

  @property
  def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      A serialized JSON string (produced by `json_utils.Encoder`) encoding
      `python_properties`, i.e. the information necessary for recreating this
      object.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    return json_utils.Encoder().encode(self.python_properties)

  def trackable_children(self, serialization_cache):
    """Lists all Trackable children connected to this object.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A new dictionary merging `objects_to_serialize` with
      `functions_to_serialize` (function entries overwrite object entries that
      share a name). Returns an empty dictionary when
      `utils.should_save_traces()` is falsy.
    """
    if not utils.should_save_traces():
      return {}

    # Copy instead of updating in place. Subclasses are told to cache their
    # results in `serialization_cache`; mutating the returned mapping would
    # leak function entries into that cached object map.
    children = dict(self.objects_to_serialize(serialization_cache))
    children.update(self.functions_to_serialize(serialization_cache))
    return children

  @property
  @abc.abstractmethod
  def python_properties(self):
    """Returns dictionary of python properties to save in the metadata.

    This dictionary must be serializable and deserializable to/from JSON.

    When loading, the items in this dict are used to initialize the object and
    define attributes in the revived object.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def objects_to_serialize(self, serialization_cache):
    """Returns dictionary of extra checkpointable objects to serialize.

    See `functions_to_serialize` for an explanation of this function's
    effects.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to checkpointable objects.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def functions_to_serialize(self, serialization_cache):
    """Returns extra functions to include when serializing a Keras object.

    Normally, when exporting an object to SavedModel, only the functions and
    objects defined by the user are saved. For example:

    ```
    obj = tf.Module()
    obj.v = tf.Variable(1.)

    @tf.function
    def foo(...): ...

    obj.foo = foo

    w = tf.Variable(1.)

    tf.saved_model.save(obj, 'path/to/saved/model')
    loaded = tf.saved_model.load('path/to/saved/model')

    loaded.v  # Variable with the same value as obj.v
    loaded.foo  # Equivalent to obj.foo
    loaded.w  # AttributeError
    ```

    Assigning trackable objects to attributes creates a graph, which is used
    for both checkpointing and SavedModel serialization.

    When the graph generated from attribute tracking is insufficient, extra
    objects and functions may be added at serialization time. For example,
    most models do not have their call function wrapped with a @tf.function
    decorator. As a result, `model.call` is not saved. Keras objects must be
    revivable from the SavedModel format, so the call function is added as an
    extra function to serialize.

    This function and `objects_to_serialize` are called multiple times when
    exporting to SavedModel. Please use the cache to avoid generating new
    functions and objects. A fresh cache is created for each SavedModel export.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to `Function` or
      `ConcreteFunction`.
    """
    raise NotImplementedError
135