# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles type registrations for tf.saved_model.load."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2


class VersionedTypeRegistration(object):
  """Holds information about one version of a revived type."""

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. When
        making incompatible changes to a wrapper, add a new
        `VersionedTypeRegistration` with an incremented `version`. The most
        recent version will be saved, and all registrations with a matching
        identifier will be searched for the highest compatible version to use
        when loading.
      min_producer_version: The minimum producer version number required to use
        this `VersionedTypeRegistration` when loading a proto.
      min_consumer_version: `VersionedTypeRegistration`s with a version number
        less than `min_consumer_version` will not be used to load a proto saved
        with this object. `min_consumer_version` should be set to the lowest
        version number which can successfully load protos saved by this
        object. If no matching registration is available on load, the object
        will be revived with a generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: An iterable of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). It is copied,
        so later changes to the caller's container have no effect.
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self.setter = setter
    self.identifier = None  # Filled in by `register_revived_type`.
    self._object_factory = object_factory
    self.version = version
    self._min_consumer_version = min_consumer_version
    self._min_producer_version = min_producer_version
    # Take a private copy. Otherwise a caller mutating its list, or passing a
    # one-shot iterable, would silently change what `to_proto` emits.
    self._bad_consumers = [] if bad_consumers is None else list(bad_consumers)

  def to_proto(self):
    """Create a SavedUserObject proto.

    The proto records this registration's identifier, plus a VersionDef built
    from its producer `version`, `min_consumer_version` and `bad_consumers`.
    """
    # For now wrappers just use dependencies to save their state, so the
    # SavedUserObject doesn't depend on the object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier,
        version=versions_pb2.VersionDef(
            producer=self.version,
            min_consumer=self._min_consumer_version,
            bad_consumers=self._bad_consumers))

  def from_proto(self, proto):
    """Recreate a trackable object from a SavedUserObject proto.

    Delegates to the `object_factory` passed to the constructor.
    """
    return self._object_factory(proto)

  def should_load(self, proto):
    """Checks if this object should load the SavedUserObject `proto`.

    Returns False if any of these hold:
      * `proto` was saved under a different type identifier.
      * This registration's version is below the proto's `min_consumer`.
      * The proto's producer version is below this registration's
        `min_producer_version`.
      * This registration's version is listed in the proto's `bad_consumers`.
    Otherwise returns True.
    """
    if proto.identifier != self.identifier:
      return False
    if self.version < proto.version.min_consumer:
      return False
    if proto.version.producer < self._min_producer_version:
      return False
    return self.version not in proto.version.bad_consumers


# string identifier -> (predicate, [VersionedTypeRegistration])
_REVIVED_TYPE_REGISTRY = {}
# Identifiers in registration order; `serialize` tries predicates in this order.
_TYPE_IDENTIFIERS = []


def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, the registrations in
      `versions` may be used to save and restore the object.
    versions: A non-empty list (or other iterable) of
      `VersionedTypeRegistration` objects with distinct `version` numbers.
      Each registration's `identifier` is set to `identifier` once all
      validation passes. The caller's container is not modified.

  Raises:
    AssertionError: If `versions` is empty, if two registrations share a
      version number, or if `identifier` is already registered.
  """
  # Keep registrations in order of version, highest first. We always use the
  # highest matching version (respecting the min consumer version and bad
  # consumers). Sort a copy so the caller's sequence is left untouched and
  # tuples and other iterables are accepted.
  versions = sorted(versions, key=lambda reg: reg.version, reverse=True)
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")
  version_numbers = set()
  for registration in versions:
    if registration.version in version_numbers:
      raise AssertionError(
          "Got multiple registrations with version {} for type {}".format(
              registration.version, identifier))
    version_numbers.add(registration.version)
  if identifier in _REVIVED_TYPE_REGISTRY:
    raise AssertionError(
        "Duplicate registrations for type {}".format(identifier))

  # Copy over the identifier for use in generating protos. This happens only
  # after validation, so a rejected call leaves the registrations untouched.
  for registration in versions:
    registration.identifier = identifier
  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)
  _TYPE_IDENTIFIERS.append(identifier)


def serialize(obj):
  """Create a SavedUserObject from a trackable object.

  Predicates are tried in registration order; the first matching type is used.

  Returns:
    A SavedUserObject proto from the newest registered version of the first
    matching type, or None if no registered predicate accepts `obj`.
  """
  for identifier in _TYPE_IDENTIFIERS:
    predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
    if predicate(obj):
      # Always uses the most recent version to serialize.
      return versions[0].to_proto()
  return None


def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None if `proto.identifier`
    is not registered, or if no registered version accepts `proto`.
  """
  entry = _REVIVED_TYPE_REGISTRY.get(proto.identifier)
  if entry is None:
    return None
  _, type_registrations = entry
  # Registrations are stored highest-version first, so the first compatible
  # one is the newest version able to load this proto.
  for type_registration in type_registrations:
    if type_registration.should_load(proto):
      return (type_registration.from_proto(proto), type_registration.setter)
  return None