1"""Dependency tracking for trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import copy 21import warnings 22 23from absl import logging 24 25from tensorflow.python.eager import context 26from tensorflow.python.eager import def_function 27from tensorflow.python.eager import function as defun 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.training.tracking import base 31from tensorflow.python.training.tracking import data_structures 32from tensorflow.python.util import tf_contextlib 33from tensorflow.python.util.tf_export import tf_export 34 35 36# global _RESOURCE_TRACKER_STACK 37_RESOURCE_TRACKER_STACK = [] 38 39 40class NotTrackable(object): 41 """Marks instances of child classes as unsaveable using an object-based API. 42 43 Useful for marking objects which would otherwise look trackable because 44 of inheritance (e.g. through `Layer`) as not trackable. Inheriting from 45 `NotTrackable` does not prevent an object from being assigned to any 46 attributes, but will throw an error on save/restore. 47 """ 48 pass 49 50 51@tf_export("__internal__.tracking.AutoTrackable", v1=[]) 52class AutoTrackable(base.Trackable): 53 """Manages dependencies on other objects. 54 55 `Trackable` objects may have dependencies: other `Trackable` objects 56 which should be saved if the object declaring the dependency is saved. A 57 correctly saveable program has a dependency graph such that if changing a 58 global variable affects an object (e.g. changes the behavior of any of its 59 methods) then there is a chain of dependencies from the influenced object to 60 the variable. 61 62 Dependency edges have names, and are created implicitly when a 63 `Trackable` object is assigned to an attribute of another 64 `Trackable` object. For example: 65 66 ``` 67 obj = Trackable() 68 obj.v = ResourceVariable(0.) 69 ``` 70 71 The `Trackable` object `obj` now has a dependency named "v" on a 72 variable. 73 74 `Trackable` objects may specify `Tensor`s to be saved and restored 75 directly (e.g. a `Variable` indicating how to save itself) rather than through 76 dependencies on other objects. See 77 `Trackable._gather_saveables_for_checkpoint` for details. 78 """ 79 80 def __setattr__(self, name, value): 81 """Support self.foo = trackable syntax.""" 82 try: 83 if getattr(self, name) is value: 84 # Short circuit for `self.$x = self.$x`. 85 return 86 except AttributeError: 87 pass 88 89 if getattr(self, "_self_setattr_tracking", True): 90 value = data_structures.sticky_attribute_assignment( 91 trackable=self, value=value, name=name) 92 super(AutoTrackable, self).__setattr__(name, value) 93 94 def __delattr__(self, name): 95 self._maybe_initialize_trackable() 96 delete_tracking(self, name) 97 super(AutoTrackable, self).__delattr__(name) 98 99 def _no_dependency(self, value): 100 """Override to allow TrackableBase to disable dependency tracking.""" 101 return data_structures.NoDependency(value) 102 103 def _list_functions_for_serialization(self, unused_serialization_cache): 104 """Return a dict of `Function`s of a trackable.""" 105 functions = {} 106 for attribute_name in dir(self): 107 # We get the attributes, suppressing warnings and exceptions. 108 logging_verbosity = logging.get_verbosity() 109 try: 110 logging.set_verbosity(logging.FATAL) 111 with warnings.catch_warnings(): 112 warnings.simplefilter("ignore") 113 attribute_value = getattr(self, attribute_name, None) 114 except Exception: # pylint: disable=broad-except 115 # We really don't want to throw an exception just because some object's 116 # attribute accessor is broken. 117 attribute_value = None 118 finally: 119 # We reset the verbosity setting in a `finally` block, to make 120 # sure it always happens, even if we make the exception catching above 121 # be less broad. 122 logging.set_verbosity(logging_verbosity) 123 if isinstance(attribute_value, (def_function.Function, 124 defun.ConcreteFunction)): 125 functions[attribute_name] = attribute_value 126 return functions 127 128 129def delete_tracking(obj, name): 130 """Removes the tracking of name from object.""" 131 # pylint: disable=protected-access 132 if name in obj._unconditional_dependency_names: 133 del obj._unconditional_dependency_names[name] 134 for index, (dep_name, _) in enumerate( 135 obj._unconditional_checkpoint_dependencies): 136 if dep_name == name: 137 del obj._unconditional_checkpoint_dependencies[index] 138 break 139 # pylint: enable=protected-access 140 141 142class ResourceTracker(object): 143 """An object that tracks a list of resources.""" 144 145 __slots__ = ["_resources"] 146 147 def __init__(self): 148 self._resources = [] 149 150 @property 151 def resources(self): 152 return self._resources 153 154 def add_resource(self, resource): 155 self._resources.append(resource) 156 157 158@tf_contextlib.contextmanager 159def resource_tracker_scope(resource_tracker): 160 """A context to manage resource trackers. 161 162 Use this in order to collect up all resources created within a block of code. 163 Example usage: 164 165 ```python 166 resource_tracker = ResourceTracker() 167 with resource_tracker_scope(resource_tracker): 168 resource = TrackableResource() 169 170 assert resource_tracker.resources == [resource] 171 172 Args: 173 resource_tracker: The passed in ResourceTracker object 174 175 Yields: 176 A scope in which the resource_tracker is active. 177 """ 178 global _RESOURCE_TRACKER_STACK 179 old = list(_RESOURCE_TRACKER_STACK) 180 _RESOURCE_TRACKER_STACK.append(resource_tracker) 181 try: 182 yield 183 finally: 184 _RESOURCE_TRACKER_STACK = old 185 186 187class CapturableResourceDeleter(object): 188 """Deleter to destroy CapturableResource without overriding its __del__().""" 189 190 __slots__ = ["_destruction_context", "_destroy_resource"] 191 192 def __init__(self, destroy_resource_fn=None): 193 if destroy_resource_fn: 194 self._destroy_resource = destroy_resource_fn 195 self._destruction_context = ( 196 context.eager_mode if context.executing_eagerly() 197 else ops.get_default_graph().as_default) 198 else: 199 self._destroy_resource = None 200 201 def destroy_resource(self): 202 if self._destroy_resource: 203 return self._destroy_resource() 204 205 def __del__(self): 206 if self._destroy_resource: 207 with self._destruction_context(): 208 self._destroy_resource() 209 210 211class CapturableResource(base.Trackable): 212 """Holds a Tensor which a tf.function can capture. 213 214 `CapturableResource`s are discovered by traversing the graph of object 215 attributes, e.g. during `tf.saved_model.save`. They are excluded from the 216 scope-based tracking of `TrackableResource`; generally things that require 217 initialization should inherit from `TrackableResource` instead of 218 `CapturableResource` directly. 219 """ 220 221 def __init__(self, device="", deleter=None): 222 """Initialize the `CapturableResource`. 223 224 Args: 225 device: A string indicating a required placement for this resource, 226 e.g. "CPU" if this resource must be created on a CPU device. A blank 227 device allows the user to place resource creation, so generally this 228 should be blank unless the resource only makes sense on one device. 229 deleter: A CapturableResourceDeleter that will destroy the created 230 resource during destruction. 231 """ 232 self._resource_handle = None 233 self._resource_device = device 234 self._resource_deleter = deleter or CapturableResourceDeleter() 235 236 def _create_resource(self): 237 """A function that creates a resource handle.""" 238 raise NotImplementedError("TrackableResource._create_resource not " 239 "implemented.") 240 241 def _initialize(self): 242 """A function that initializes the resource. Optional.""" 243 pass 244 245 @property 246 def resource_handle(self): 247 """Returns the resource handle associated with this Resource.""" 248 if self._resource_handle is None: 249 with ops.device(self._resource_device): 250 self._resource_handle = self._create_resource() 251 return self._resource_handle 252 253 def _map_resources(self, _): 254 """For implementing `Trackable`.""" 255 new_obj = copy.copy(self) 256 # pylint: disable=protected-access 257 with ops.device(self._resource_device): 258 new_resource = new_obj._create_resource() 259 new_obj._resource_handle = new_resource 260 # pylint: enable=protected-access 261 obj_map = {self: new_obj} 262 resource_map = {self.resource_handle: new_resource} 263 return obj_map, resource_map 264 265 def _list_functions_for_serialization(self, unused_functions): 266 @def_function.function(input_signature=[], autograph=False) 267 def _creator(): 268 resource = self._create_resource() 269 return resource 270 271 @def_function.function(input_signature=[], autograph=False) 272 def _initializer(): 273 self._initialize() 274 return 1 # Dummy return 275 276 @def_function.function(input_signature=[], autograph=False) 277 def _destroyer(): 278 self._resource_deleter.destroy_resource() 279 return 1 # Dummy return 280 281 return { 282 "_create_resource": _creator, 283 "_initialize": _initializer, 284 "_destroy_resource": _destroyer, 285 } 286 287 288class TrackableResource(CapturableResource): 289 """Adds scope tracking to CapturableResource.""" 290 291 def __init__(self, device="", deleter=None): 292 """Initialize the `TrackableResource`. 293 294 Args: 295 device: A string indicating a required placement for this resource, 296 e.g. "CPU" if this resource must be created on a CPU device. A blank 297 device allows the user to place resource creation, so generally this 298 should be blank unless the resource only makes sense on one device. 299 deleter: A CapturableResourceDeleter that will destroy the created 300 resource during destruction. 301 """ 302 global _RESOURCE_TRACKER_STACK 303 for resource_tracker in _RESOURCE_TRACKER_STACK: 304 resource_tracker.add_resource(self) 305 super(TrackableResource, self).__init__(device=device, deleter=deleter) 306 307 308@tf_export("saved_model.Asset") 309class Asset(base.Trackable): 310 """Represents a file asset to hermetically include in a SavedModel. 311 312 A SavedModel can include arbitrary files, called assets, that are needed 313 for its use. For example a vocabulary file used initialize a lookup table. 314 315 When a trackable object is exported via `tf.saved_model.save()`, all the 316 `Asset`s reachable from it are copied into the SavedModel assets directory. 317 Upon loading, the assets and the serialized functions that depend on them 318 will refer to the correct filepaths inside the SavedModel directory. 319 320 Example: 321 322 ``` 323 filename = tf.saved_model.Asset("file.txt") 324 325 @tf.function(input_signature=[]) 326 def func(): 327 return tf.io.read_file(filename) 328 329 trackable_obj = tf.train.Checkpoint() 330 trackable_obj.func = func 331 trackable_obj.filename = filename 332 tf.saved_model.save(trackable_obj, "/tmp/saved_model") 333 334 # The created SavedModel is hermetic, it does not depend on 335 # the original file and can be moved to another path. 336 tf.io.gfile.remove("file.txt") 337 tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location") 338 339 reloaded_obj = tf.saved_model.load("/tmp/new_location") 340 print(reloaded_obj.func()) 341 ``` 342 343 Attributes: 344 asset_path: A 0-D `tf.string` tensor with path to the asset. 345 """ 346 347 def __init__(self, path): 348 """Record the full path to the asset.""" 349 # The init_scope prevents functions from capturing `path` in an 350 # initialization graph, since it is transient and should not end up in a 351 # serialized function body. 352 with ops.init_scope(), ops.device("CPU"): 353 self._path = ops.convert_to_tensor( 354 path, dtype=dtypes.string, name="asset_path") 355 356 @property 357 def asset_path(self): 358 """Fetch the current asset path.""" 359 return self._path 360 361 362ops.register_tensor_conversion_function( 363 Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw)) 364