"""Dependency tracking for trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import tf_contextlib


# Stack of `ResourceTracker`s made active by `resource_tracker_scope`.
# `TrackableResource.__init__` registers each new resource with every tracker
# on this stack (outer scopes included, not just the innermost one).
_RESOURCE_TRACKER_STACK = []


class NotTrackable(object):
  """Marks instances of child classes as unsaveable using an object-based API.

  Useful for marking objects which would otherwise look trackable because
  of inheritance (e.g. through `Layer`) as not trackable. Inheriting from
  `NotTrackable` does not prevent an object from being assigned to any
  attributes, but will throw an error on save/restore.
  """


class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of an `AutoTrackable`
  object. For example:

  ```
  obj = AutoTrackable()
  obj.v = ResourceVariable(0.)
  ```

  The `AutoTrackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax."""
    # Tracking can be switched off per instance via `_setattr_tracking`; it is
    # on by default.
    if getattr(self, "_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Delete attribute `name`, dropping any dependency recorded under it.

    The name is removed from `_unconditional_dependency_names` and the first
    matching entry is removed from `_unconditional_checkpoint_dependencies`
    before the attribute itself is deleted.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          # Deleting while enumerating is safe only because we stop right away.
          del self._unconditional_checkpoint_dependencies[index]
          break
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self):
    """Return a dict of `Function`s of a trackable.

    Scans every name in `dir(self)` and keeps the attributes that are
    `def_function.Function` or `defun.ConcreteFunction` instances, keyed by
    attribute name.
    """
    functions = dict()
    for attribute_name in dir(self):
      try:
        attribute_value = getattr(self, attribute_name, None)
      except Exception:  # pylint: disable=broad-except
        # We really don't want to throw an exception just because some object's
        # attribute accessor is broken.
        attribute_value = None
      if isinstance(attribute_value, (def_function.Function,
                                      defun.ConcreteFunction)):
        functions[attribute_name] = attribute_value
    return functions


class ResourceTracker(object):
  """An object that tracks a list of resources."""

  def __init__(self):
    self._resources = []

  @property
  def resources(self):
    """The resources added so far, in insertion order."""
    return self._resources

  def add_resource(self, resource):
    """Append `resource` to the tracked list."""
    self._resources.append(resource)


@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Scopes may be nested: a `TrackableResource` created inside nested scopes is
  added to every tracker active at that point, not only the innermost one.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The `ResourceTracker` to activate for the duration of
      the scope.

  Yields:
    A scope in which the resource_tracker is active.
  """
  # `global` is required here because the `finally` clause rebinds the name.
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack so it is restored exactly on exit, even if the body
  # raises.
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = old


class TrackableResource(base.Trackable):
  """Base class for all resources that need to be tracked."""

  def __init__(self):
    # Register with every currently active tracker (see
    # `resource_tracker_scope`). The module-level stack is only read here, so
    # no `global` declaration is needed.
    for resource_tracker in _RESOURCE_TRACKER_STACK:
      resource_tracker.add_resource(self)

    self._resource_handle = None

  def _create_resource(self):
    """A function that creates a resource handle."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created lazily via `_create_resource` on first access and
    cached for subsequent accesses.
    """
    if self._resource_handle is None:
      self._resource_handle = self._create_resource()
    return self._resource_handle

  def _list_functions_for_serialization(self):
    """Return the creator/initializer functions to serialize for this resource.

    Returns:
      A dict with keys "_create_resource" and "_initialize", each mapped to a
      `def_function.function` with an empty input signature and autograph
      disabled.
    """
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      resource = self._create_resource()
      return resource

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    return {
        "_create_resource": _creator,
        "_initialize": _initializer,
    }


class TrackableAsset(base.Trackable):
  """Base class for asset files which need to be tracked."""

  def __init__(self, path):
    """Record the full path to the asset."""
    # The init_scope prevents functions from capturing `path` in an
    # initialization graph, since it is transient and should not end up in a
    # serialized function body.
    with ops.init_scope():
      self._path = ops.internal_convert_to_tensor(path, dtype=dtypes.string,
                                                  name="asset_path")

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path


# Allow a `TrackableAsset` to be used wherever a tensor is expected; it converts
# to its `asset_path` string tensor.
ops.register_tensor_conversion_function(
    TrackableAsset,
    lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))