"""Dependency tracking for trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import weakref

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export


# Module-level stack of the `ResourceTracker`s whose scopes are currently
# active. `resource_tracker_scope` pushes onto it and restores it on exit;
# `TrackableResource.__init__` reads it.
_RESOURCE_TRACKER_STACK = []


class NotTrackable(object):
  """Marks instances of child classes as unsaveable using an object-based API.

  Useful for marking objects which would otherwise look trackable because
  of inheritance (e.g. through `Layer`) as not trackable. Inheriting from
  `NotTrackable` does not prevent an object from being assigned to any
  attributes, but will throw an error on save/restore.
  """


class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of another
  `Trackable` object. For example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  The `Trackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Args:
      name: Name of the attribute being assigned.
      value: The object to assign. Unless tracking is switched off through a
        falsy `_self_setattr_tracking` attribute, it is first passed through
        `data_structures.sticky_attribute_assignment`, and whatever that
        returns is what actually gets stored.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet, so this is a first assignment.
      pass

    # Tracking defaults to enabled when the flag has never been set.
    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Delete attribute `name` along with any dependency tracked under it."""
    self._maybe_initialize_trackable()
    delete_tracking(self, name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Return a dict of `Function`s of a trackable.

    Args:
      unused_serialization_cache: Ignored by this implementation.

    Returns:
      A dict mapping attribute name to attribute value for every name in
      `dir(self)` whose value is a `def_function.Function` or a
      `defun.ConcreteFunction`.
    """
    functions = {}
    for attribute_name in dir(self):
      try:
        attribute_value = getattr(self, attribute_name, None)
      except Exception:  # pylint: disable=broad-except
        # We really don't want to throw an exception just because some object's
        # attribute accessor is broken.
        attribute_value = None
      if isinstance(attribute_value, (def_function.Function,
                                      defun.ConcreteFunction)):
        functions[attribute_name] = attribute_value
    return functions


def delete_tracking(obj, name):
  """Removes the tracking of name from object.

  Args:
    obj: Object whose `_unconditional_dependency_names` mapping and
      `_unconditional_checkpoint_dependencies` sequence are updated in place.
    name: Dependency name to stop tracking. Nothing happens when `name` is not
      a key of `_unconditional_dependency_names`.
  """
  # pylint: disable=protected-access
  if name in obj._unconditional_dependency_names:
    del obj._unconditional_dependency_names[name]
    for index, (dep_name, _) in enumerate(
        obj._unconditional_checkpoint_dependencies):
      if dep_name == name:
        # Only the first matching entry is removed; stopping here also avoids
        # continuing to iterate a list that was just mutated.
        del obj._unconditional_checkpoint_dependencies[index]
        break
  # pylint: enable=protected-access


class ResourceTracker(object):
  """An object that tracks a list of resources."""

  def __init__(self):
    """Start with no tracked resources."""
    self._resources = []

  @property
  def resources(self):
    """The list of tracked resources (the internal list itself, not a copy)."""
    return self._resources

  def add_resource(self, resource):
    """Append `resource` to the tracked list."""
    self._resources.append(resource)


@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack so that the exact pre-entry contents can be restored on
  # exit, even if the body raises.
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = old


class CapturableResourceDeleter(object):
  """Deleter to destroy CapturableResource without overriding its __del__()."""

  def __init__(self, destroy_resource_fn=None):
    """Remember how to destroy a resource and in which context to do it.

    Args:
      destroy_resource_fn: Optional zero-argument callable that destroys the
        resource. When it is omitted (or falsy) this deleter does nothing.
    """
    if destroy_resource_fn:
      self._destroy_resource = destroy_resource_fn
      # Decided once, at construction time: `__del__` re-enters eager mode if
      # we are executing eagerly now, otherwise the graph that is the default
      # right now.
      self._destruction_context = (
          context.eager_mode if context.executing_eagerly()
          else ops.get_default_graph().as_default)
    else:
      self._destroy_resource = None

  def destroy_resource(self):
    """Call the destroy function, if any, and return its result.

    Returns:
      Whatever the destroy function returns, or `None` when no destroy function
      was supplied.
    """
    if self._destroy_resource:
      return self._destroy_resource()

  def __del__(self):
    """Destroy the resource inside the context captured at construction."""
    if self._destroy_resource:
      with self._destruction_context():
        self._destroy_resource()


class CapturableResource(base.Trackable):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.
  """

  def __init__(self, device="", deleter=None):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction.
    """
    # The handle is created lazily on first access to `resource_handle`.
    self._resource_handle = None
    self._resource_device = device
    self._resource_deleter = deleter or CapturableResourceDeleter()

  def _create_resource(self):
    """A function that creates a resource handle."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource."""
    if self._resource_handle is None:
      # First access: create the handle under the requested device and keep it
      # for every later access.
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _list_functions_for_serialization(self, unused_functions):
    """Return the functions that create, initialize and destroy the resource.

    Args:
      unused_functions: Ignored by this implementation.

    Returns:
      A dict with the keys "_create_resource", "_initialize" and
      "_destroy_resource", each mapped to a zero-argument function built with
      `def_function.function` that wraps the corresponding operation.
    """
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      resource = self._create_resource()
      return resource

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    @def_function.function(input_signature=[], autograph=False)
    def _destroyer():
      self._resource_deleter.destroy_resource()
      return 1  # Dummy return

    return {
        "_create_resource": _creator,
        "_initialize": _initializer,
        "_destroy_resource": _destroyer,
    }


class TrackableResource(CapturableResource):
  """Adds scope tracking to CapturableResource."""

  def __init__(self, device="", deleter=None):
    """Initialize the `TrackableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction.
    """
    # Register with every tracker on the stack, so enclosing scopes record the
    # resource as well as the innermost one.
    for resource_tracker in _RESOURCE_TRACKER_STACK:
      resource_tracker.add_resource(self)
    super(TrackableResource, self).__init__(device=device, deleter=deleter)


@tf_export("saved_model.Asset")
class Asset(base.Trackable):
  """Represents a file asset to hermetically include in a SavedModel.

  A SavedModel can include arbitrary files, called assets, that are needed
  for its use. For example a vocabulary file used to initialize a lookup table.

  When a trackable object is exported via `tf.saved_model.save()`, all the
  `Asset`s reachable from it are copied into the SavedModel assets directory.
  Upon loading, the assets and the serialized functions that depend on them
  will refer to the correct filepaths inside the SavedModel directory.

  Example:

  ```
  filename = tf.saved_model.Asset("file.txt")

  @tf.function(input_signature=[])
  def func():
    return tf.io.read_file(filename)

  trackable_obj = tf.train.Checkpoint()
  trackable_obj.func = func
  trackable_obj.filename = filename
  tf.saved_model.save(trackable_obj, "/tmp/saved_model")

  # The created SavedModel is hermetic, it does not depend on
  # the original file and can be moved to another path.
  tf.io.gfile.remove("file.txt")
  tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")

  reloaded_obj = tf.saved_model.load("/tmp/new_location")
  print(reloaded_obj.func())
  ```

  Attributes:
    asset_path: A 0-D `tf.string` tensor with path to the asset.
  """

  def __init__(self, path):
    """Record the full path to the asset."""
    # The init_scope prevents functions from capturing `path` in an
    # initialization graph, since it is transient and should not end up in a
    # serialized function body.
    with ops.init_scope(), ops.device("CPU"):
      self._path = ops.convert_to_tensor(
          path, dtype=dtypes.string, name="asset_path")

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```

  Slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakref's as keys raises
  the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache. It is called with the instance as its only
      argument; the instance must be hashable and weakly referenceable.

  Returns:
    f decorated with simple caching behavior. The backing
    `weakref.WeakKeyDictionary` is exposed as the `cache` attribute of the
    returned function.
  """

  cache = weakref.WeakKeyDictionary()
  # A private sentinel marks "not cached yet" so that a legitimate `None`
  # result is stored and reused instead of being recomputed on every access.
  missing = object()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item, missing)
    if output is missing:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


# Let an `Asset` be used wherever a tensor is expected: it converts to the
# tensor holding its path.
ops.register_tensor_conversion_function(
    Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw))