• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles types registrations for tf.saved_model.load."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import versions_pb2
22from tensorflow.core.protobuf import saved_object_graph_pb2
23from tensorflow.python.util.tf_export import tf_export
24
25
@tf_export("__internal__.saved_model.load.VersionedTypeRegistration", v1=[])
class VersionedTypeRegistration(object):
  """Holds information about one version of a revived type."""

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. When
        making incompatible changes to a wrapper, add a new
        `VersionedTypeRegistration` with an incremented `version`. The most
        recent version will be saved, and all registrations with a matching
        identifier will be searched for the highest compatible version to use
        when loading.
      min_producer_version: The minimum producer version number required to use
        this `VersionedTypeRegistration` when loading a proto.
      min_consumer_version: `VersionedTypeRegistration`s with a version number
        less than `min_consumer_version` will not be used to load a proto saved
        with this object. `min_consumer_version` should be set to the lowest
        version number which can successfully load protos saved by this
        object. If no matching registration is available on load, the object
        will be revived with a generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: An iterable of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). A private
        copy is kept, so later changes to the caller's container do not
        affect this registration.
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self.setter = setter
    # Set after registration (`register_revived_type` assigns it); compared
    # against `proto.identifier` in `should_load` and written by `to_proto`.
    self.identifier = None
    self._object_factory = object_factory
    self.version = version
    self._min_consumer_version = min_consumer_version
    self._min_producer_version = min_producer_version
    # Copy rather than alias the caller's container: otherwise mutating that
    # list after construction would silently change what `to_proto` writes.
    self._bad_consumers = [] if bad_consumers is None else list(bad_consumers)

  def to_proto(self):
    """Create a SavedUserObject proto.

    The proto records this registration's identifier and its version as the
    producer, together with its minimum consumer version and bad consumers.
    """
    # For now wrappers just use dependencies to save their state, so the
    # SavedUserObject doesn't depend on the object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier,
        version=versions_pb2.VersionDef(
            producer=self.version,
            min_consumer=self._min_consumer_version,
            bad_consumers=self._bad_consumers))

  def from_proto(self, proto):
    """Recreate a trackable object from a SavedUserObject proto."""
    return self._object_factory(proto)

  def should_load(self, proto):
    """Checks if this object should load the SavedUserObject `proto`.

    Returns:
      True only if all of the following hold:
        * the proto's identifier matches this registration's identifier;
        * this registration's version is at least the proto's `min_consumer`;
        * the proto's producer version is at least `min_producer_version`;
        * this registration's version is not in the proto's `bad_consumers`.
    """
    if proto.identifier != self.identifier:
      return False
    if self.version < proto.version.min_consumer:
      return False
    if proto.version.producer < self._min_producer_version:
      return False
    return self.version not in proto.version.bad_consumers
103
104
# Maps each string identifier to a (predicate, [VersionedTypeRegistration])
# tuple; the registration list is kept sorted newest version first.
_REVIVED_TYPE_REGISTRY = dict()
# Identifiers in the order they were registered. `serialize` walks this list
# and uses the first identifier whose predicate accepts the object.
_TYPE_IDENTIFIERS = list()
108
109
@tf_export("__internal__.saved_model.load.register_revived_type", v1=[])
def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, this registration may be
      used to save and restore the object.
    versions: A non-empty iterable of `VersionedTypeRegistration` objects with
      distinct `version` numbers. The caller's container is not modified.

  Raises:
    AssertionError: If `versions` is empty, or if two registrations share the
      same version number.
  """
  # Keep registrations in order of version. We always use the highest matching
  # version (respecting the min consumer version and bad consumers). Sort a
  # copy so the caller's list is neither mutated in place nor aliased by the
  # registry.
  versions = sorted(versions, key=lambda reg: reg.version, reverse=True)
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")

  # Validate all registrations before touching any of them, so a rejected call
  # does not leave some registrations stamped with `identifier`.
  version_numbers = set()
  for registration in versions:
    if registration.version in version_numbers:
      raise AssertionError(
          f"Got multiple registrations with version {registration.version} for "
          f"type {identifier}.")
    version_numbers.add(registration.version)

  for registration in versions:
    # Copy over the identifier for use in generating protos
    registration.identifier = identifier

  # Re-registering an identifier replaces its registry entry. Record it in the
  # ordered identifier list only once, so `serialize` does not evaluate the
  # same predicate repeatedly (first-match order is unaffected).
  if identifier not in _REVIVED_TYPE_REGISTRY:
    _TYPE_IDENTIFIERS.append(identifier)
  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)
138
139
def serialize(obj):
  """Create a SavedUserObject from a trackable object.

  Registered types are consulted in registration order. The first one whose
  predicate accepts `obj` is serialized using its newest registered version.

  Args:
    obj: The trackable object to describe.

  Returns:
    A SavedUserObject proto, or None if no registered predicate matches.
  """
  for type_id in _TYPE_IDENTIFIERS:
    accepts, registrations = _REVIVED_TYPE_REGISTRY[type_id]
    if not accepts(obj):
      continue
    # Registrations are stored newest-first, so index 0 is the latest version.
    newest = registrations[0]
    return newest.to_proto()
  return None
148
149
def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  The registrations stored for `proto.identifier` are tried in stored order.
  The first one whose `should_load` accepts the proto builds the object.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None when the identifier is
    not registered or when no registered version accepts the proto.
  """
  _, candidates = _REVIVED_TYPE_REGISTRY.get(proto.identifier, (None, None))
  if candidates is None:
    return None
  for candidate in candidates:
    if not candidate.should_load(proto):
      continue
    revived = candidate.from_proto(proto)
    return (revived, candidate.setter)
  return None
168
169
@tf_export("__internal__.saved_model.load.registered_identifiers", v1=[])
def registered_identifiers():
  """Return the identifiers of every registered revived type.

  Returns:
    A set-like view over the registered identifier strings (the registry's
    keys). The view is live, so identifiers registered afterwards also appear
    in it.
  """
  registry = _REVIVED_TYPE_REGISTRY
  return registry.keys()
178
179
@tf_export("__internal__.saved_model.load.get_setter", v1=[])
def get_setter(proto):
  """Gets the registered setter function for the SavedUserObject proto.

  See VersionedTypeRegistration for info about the setter function. The
  registrations for `proto.identifier` are checked in stored order, and the
  setter of the first one whose `should_load` accepts the proto is returned.

  Args:
    proto: SavedUserObject proto

  Returns:
    setter function, or None if the identifier is unregistered or no
    registered version accepts the proto.
  """
  _, candidates = _REVIVED_TYPE_REGISTRY.get(proto.identifier, (None, None))
  if candidates is None:
    return None
  return next(
      (candidate.setter
       for candidate in candidates
       if candidate.should_load(proto)),
      None)
199