# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles types registrations for tf.saved_model.load."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.saved_model.load.VersionedTypeRegistration", v1=[])
class VersionedTypeRegistration(object):
  """Holds information about one version of a revived type."""

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. When
        making incompatible changes to a wrapper, add a new
        `VersionedTypeRegistration` with an incremented `version`. The most
        recent version will be saved, and all registrations with a matching
        identifier will be searched for the highest compatible version to use
        when loading.
      min_producer_version: The minimum producer version number required to use
        this `VersionedTypeRegistration` when loading a proto.
      min_consumer_version: `VersionedTypeRegistration`s with a version number
        less than `min_consumer_version` will not be used to load a proto saved
        with this object. `min_consumer_version` should be set to the lowest
        version number which can successfully load protos saved by this
        object. If no matching registration is available on load, the object
        will be revived with a generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: A list of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). The list is
        copied, so later changes to the caller's list have no effect.
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self.setter = setter
    self.identifier = None  # Set by `register_revived_type`.
    self._object_factory = object_factory
    self.version = version
    self._min_consumer_version = min_consumer_version
    self._min_producer_version = min_producer_version
    if bad_consumers is None:
      bad_consumers = []
    # Store a copy so mutating the caller's list afterwards cannot change which
    # consumer versions this registration reports as incompatible.
    self._bad_consumers = list(bad_consumers)

  def to_proto(self):
    """Create a SavedUserObject proto describing this registration.

    Returns:
      A SavedUserObject whose `identifier` is this registration's identifier and
      whose `version` records this registration's `version` as the producer,
      plus its `min_consumer_version` and `bad_consumers`.
    """
    # For now wrappers just use dependencies to save their state, so the
    # SavedUserObject doesn't depend on the object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier,
        version=versions_pb2.VersionDef(
            producer=self.version,
            min_consumer=self._min_consumer_version,
            bad_consumers=self._bad_consumers))

  def from_proto(self, proto):
    """Recreate a trackable object from a SavedUserObject proto.

    Args:
      proto: A SavedUserObject proto.

    Returns:
      Whatever `object_factory` (passed to the constructor) returns for `proto`.
    """
    return self._object_factory(proto)

  def should_load(self, proto):
    """Checks if this object should load the SavedUserObject `proto`.

    Args:
      proto: A SavedUserObject proto.

    Returns:
      True when all of these hold, False otherwise:
        * `proto` carries this registration's identifier;
        * this registration's `version` is at least the proto's `min_consumer`;
        * the proto's producer version is at least this registration's
          `min_producer_version`;
        * this registration's `version` is not among the proto's
          `bad_consumers`.
    """
    if proto.identifier != self.identifier:
      return False
    if self.version < proto.version.min_consumer:
      return False
    if proto.version.producer < self._min_producer_version:
      return False
    return self.version not in proto.version.bad_consumers


# string identifier -> (predicate, [VersionedTypeRegistration]); each list is
# sorted by descending `version`.
_REVIVED_TYPE_REGISTRY = {}
# Registered identifiers in first-registration order. `serialize` checks their
# predicates in this order.
_TYPE_IDENTIFIERS = []


@tf_export("__internal__.saved_model.load.register_revived_type", v1=[])
def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Registering an identifier that is already registered replaces its predicate
  and versions but keeps its original position in the order `serialize` uses.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, `type_registration` may be
      used to save and restore the object.
    versions: A list of `VersionedTypeRegistration` objects. The list itself is
      not modified; each registration's `identifier` is set to `identifier`.

  Raises:
    AssertionError: If `versions` is empty or contains more than one
      registration with the same `version`.
  """
  # Keep registrations in order of version. We always use the highest matching
  # version (respecting the min consumer version and bad consumers). Sort a
  # copy so the caller's list is neither reordered nor aliased by the registry.
  versions = sorted(versions, key=lambda reg: reg.version, reverse=True)
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")
  version_numbers = set()
  for registration in versions:
    if registration.version in version_numbers:
      raise AssertionError(
          f"Got multiple registrations with version {registration.version} for "
          f"type {identifier}.")
    version_numbers.add(registration.version)
  # Copy over the identifier for use in generating protos. This happens only
  # after validation, so a rejected call leaves the registrations untouched.
  for registration in versions:
    registration.identifier = identifier

  # Avoid duplicate entries in the search order when re-registering.
  if identifier not in _REVIVED_TYPE_REGISTRY:
    _TYPE_IDENTIFIERS.append(identifier)
  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)


def serialize(obj):
  """Create a SavedUserObject from a trackable object.

  Args:
    obj: The object to serialize.

  Returns:
    The proto produced by the newest version of the first registered type (in
    registration order) whose predicate accepts `obj`, or None if no
    registered predicate accepts it.
  """
  for identifier in _TYPE_IDENTIFIERS:
    predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
    if predicate(obj):
      # Always uses the most recent version to serialize.
      return versions[0].to_proto()
  return None


def _find_registration(proto):
  """Returns the highest-version registration able to load `proto`, or None."""
  entry = _REVIVED_TYPE_REGISTRY.get(proto.identifier)
  if entry is None:
    return None
  _, type_registrations = entry
  # Lists are stored newest-first, so the first match is the highest version.
  for type_registration in type_registrations:
    if type_registration.should_load(proto):
      return type_registration
  return None


def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None if no registered version
    can load `proto`.
  """
  registration = _find_registration(proto)
  if registration is None:
    return None
  return (registration.from_proto(proto), registration.setter)


@tf_export("__internal__.saved_model.load.registered_identifiers", v1=[])
def registered_identifiers():
  """Return all the current registered revived object identifiers.

  Returns:
    A set-like, read-only view (`dict.keys()`) of the registered identifier
    strings, in registration order. It is a live view of the registry, so it
    also reflects identifiers registered after this call.
  """
  return _REVIVED_TYPE_REGISTRY.keys()


@tf_export("__internal__.saved_model.load.get_setter", v1=[])
def get_setter(proto):
  """Gets the registered setter function for the SavedUserObject proto.

  See VersionedTypeRegistration for info about the setter function.

  Args:
    proto: SavedUserObject proto

  Returns:
    The setter function of the highest-version registration that can load
    `proto`, or None if no registered version can load it.
  """
  registration = _find_registration(proto)
  return None if registration is None else registration.setter