1"""Dependency tracking for trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21import copy 22import warnings 23 24from absl import logging 25import six 26 27from tensorflow.python.eager import context 28from tensorflow.python.eager import def_function 29from tensorflow.python.eager import function as defun 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.training.tracking import base 33from tensorflow.python.training.tracking import data_structures 34from tensorflow.python.util import tf_contextlib 35from tensorflow.python.util.tf_export import tf_export 36 37 38# global _RESOURCE_TRACKER_STACK 39_RESOURCE_TRACKER_STACK = [] 40 41 42@tf_export("__internal__.tracking.AutoTrackable", v1=[]) 43class AutoTrackable(base.Trackable): 44 """Manages dependencies on other objects. 45 46 `Trackable` objects may have dependencies: other `Trackable` objects 47 which should be saved if the object declaring the dependency is saved. A 48 correctly saveable program has a dependency graph such that if changing a 49 global variable affects an object (e.g. changes the behavior of any of its 50 methods) then there is a chain of dependencies from the influenced object to 51 the variable. 52 53 Dependency edges have names, and are created implicitly when a 54 `Trackable` object is assigned to an attribute of another 55 `Trackable` object. For example: 56 57 ``` 58 obj = Trackable() 59 obj.v = ResourceVariable(0.) 60 ``` 61 62 The `Trackable` object `obj` now has a dependency named "v" on a 63 variable. 64 65 `Trackable` objects may specify `Tensor`s to be saved and restored 66 directly (e.g. a `Variable` indicating how to save itself) rather than through 67 dependencies on other objects. See 68 `Trackable._gather_saveables_for_checkpoint` for details. 69 """ 70 71 def __setattr__(self, name, value): 72 """Support self.foo = trackable syntax.""" 73 try: 74 if getattr(self, name) is value: 75 # Short circuit for `self.$x = self.$x`. 76 return 77 except AttributeError: 78 pass 79 80 if getattr(self, "_self_setattr_tracking", True): 81 value = data_structures.sticky_attribute_assignment( 82 trackable=self, value=value, name=name) 83 super(AutoTrackable, self).__setattr__(name, value) 84 85 def __delattr__(self, name): 86 self._delete_tracking(name) 87 super(AutoTrackable, self).__delattr__(name) 88 89 def _no_dependency(self, value): 90 """Override to allow TrackableBase to disable dependency tracking.""" 91 return data_structures.NoDependency(value) 92 93 def _list_functions_for_serialization(self, unused_serialization_cache): 94 """Return a dict of `Function`s of a trackable.""" 95 functions = {} 96 for attribute_name in dir(self): 97 # We get the attributes, suppressing warnings and exceptions. 98 logging_verbosity = logging.get_verbosity() 99 try: 100 logging.set_verbosity(logging.FATAL) 101 with warnings.catch_warnings(): 102 warnings.simplefilter("ignore") 103 attribute_value = getattr(self, attribute_name, None) 104 except Exception: # pylint: disable=broad-except 105 # We really don't want to throw an exception just because some object's 106 # attribute accessor is broken. 107 attribute_value = None 108 finally: 109 # We reset the verbosity setting in a `finally` block, to make 110 # sure it always happens, even if we make the exception catching above 111 # be less broad. 112 logging.set_verbosity(logging_verbosity) 113 if isinstance(attribute_value, (def_function.Function, 114 defun.ConcreteFunction)): 115 functions[attribute_name] = attribute_value 116 return functions 117 118 def _delete_tracking(self, name): 119 """Removes the tracking of name.""" 120 self._maybe_initialize_trackable() 121 if name in self._unconditional_dependency_names: 122 del self._unconditional_dependency_names[name] 123 for index, (dep_name, _) in enumerate( 124 self._unconditional_checkpoint_dependencies): 125 if dep_name == name: 126 del self._unconditional_checkpoint_dependencies[index] 127 break 128 129 130class ResourceTracker(object): 131 """An object that tracks a list of resources.""" 132 133 __slots__ = ["_resources"] 134 135 def __init__(self): 136 self._resources = [] 137 138 @property 139 def resources(self): 140 return self._resources 141 142 def add_resource(self, resource): 143 self._resources.append(resource) 144 145 146@tf_contextlib.contextmanager 147def resource_tracker_scope(resource_tracker): 148 """A context to manage resource trackers. 149 150 Use this in order to collect up all resources created within a block of code. 151 Example usage: 152 153 ```python 154 resource_tracker = ResourceTracker() 155 with resource_tracker_scope(resource_tracker): 156 resource = TrackableResource() 157 158 assert resource_tracker.resources == [resource] 159 160 Args: 161 resource_tracker: The passed in ResourceTracker object 162 163 Yields: 164 A scope in which the resource_tracker is active. 165 """ 166 global _RESOURCE_TRACKER_STACK 167 old = list(_RESOURCE_TRACKER_STACK) 168 _RESOURCE_TRACKER_STACK.append(resource_tracker) 169 try: 170 yield 171 finally: 172 _RESOURCE_TRACKER_STACK = old 173 174 175def _make_getter(captured_getter, captured_previous): 176 """To avoid capturing loop variables.""" 177 178 def getter(*args, **kwargs): 179 return captured_getter(captured_previous, *args, **kwargs) 180 181 return getter 182 183 184class ResourceMetaclass(type): 185 """Metaclass for CapturableResource.""" 186 187 def __call__(cls, *args, **kwargs): 188 189 def default_resource_creator(next_creator, *a, **kw): 190 assert next_creator is None 191 obj = cls.__new__(cls, *a, **kw) 192 obj.__init__(*a, **kw) 193 return obj 194 195 previous_getter = lambda *a, **kw: default_resource_creator(None, *a, **kw) 196 resource_creator_stack = ops.get_default_graph()._resource_creator_stack 197 for getter in resource_creator_stack[cls._resource_type()]: 198 previous_getter = _make_getter(getter, previous_getter) 199 200 return previous_getter(*args, **kwargs) 201 202 203class CapturableResource(six.with_metaclass(ResourceMetaclass, base.Trackable)): 204 """Holds a Tensor which a tf.function can capture. 205 206 `CapturableResource`s are discovered by traversing the graph of object 207 attributes, e.g. during `tf.saved_model.save`. They are excluded from the 208 scope-based tracking of `TrackableResource`; generally things that require 209 initialization should inherit from `TrackableResource` instead of 210 `CapturableResource` directly. 211 """ 212 213 def __init__(self, device=""): 214 """Initialize the `CapturableResource`. 215 216 Args: 217 device: A string indicating a required placement for this resource, 218 e.g. "CPU" if this resource must be created on a CPU device. A blank 219 device allows the user to place resource creation, so generally this 220 should be blank unless the resource only makes sense on one device. 221 """ 222 self._resource_handle = None 223 self._resource_device = device 224 self._self_destruction_context = ( 225 context.eager_mode if context.executing_eagerly() 226 else ops.get_default_graph().as_default) 227 228 @classmethod 229 def _resource_type(cls): 230 return cls.__name__ 231 232 @property 233 def _destruction_context(self): 234 return getattr(self, "_self_destruction_context", 235 # no-op context 236 contextlib.suppress) 237 238 @_destruction_context.setter 239 def _destruction_context(self, destruction_context): 240 self._self_destruction_context = destruction_context 241 242 def _create_resource(self): 243 """A function that creates a resource handle.""" 244 raise NotImplementedError("TrackableResource._create_resource not " 245 "implemented.") 246 247 def _initialize(self): 248 """A function that initializes the resource. Optional.""" 249 pass 250 251 def _destroy_resource(self): 252 """A function that destroys the resource. Optional.""" 253 pass 254 255 @property 256 def resource_handle(self): 257 """Returns the resource handle associated with this Resource.""" 258 if self._resource_handle is None: 259 with ops.device(self._resource_device): 260 self._resource_handle = self._create_resource() 261 return self._resource_handle 262 263 def _map_resources(self, _): 264 """For implementing `Trackable`.""" 265 new_obj = copy.copy(self) 266 # pylint: disable=protected-access 267 with ops.device(self._resource_device): 268 new_resource = new_obj._create_resource() 269 new_obj._resource_handle = new_resource 270 # pylint: enable=protected-access 271 obj_map = {self: new_obj} 272 resource_map = {self.resource_handle: new_resource} 273 return obj_map, resource_map 274 275 def _list_functions_for_serialization(self, unused_functions): 276 @def_function.function(input_signature=[], autograph=False) 277 def _creator(): 278 resource = self._create_resource() 279 return resource 280 281 @def_function.function(input_signature=[], autograph=False) 282 def _initializer(): 283 self._initialize() 284 return 1 # Dummy return 285 286 @def_function.function(input_signature=[], autograph=False) 287 def _destroyer(): 288 self._destroy_resource() 289 return 1 # Dummy return 290 291 return { 292 "_create_resource": _creator, 293 "_initialize": _initializer, 294 "_destroy_resource": _destroyer, 295 } 296 297 def __del__(self): 298 try: 299 # Outer race condition: on program exit, the destruction context may be 300 # deleted before this __del__ is called. At this point we can safely 301 # exit without calling _destroy_resource() and let Python handle things. 302 with self._destruction_context(): 303 # Inner race condition: possible between this and `ScopedTFFunction` 304 # whereby if an entire garbage collection chain containing both 305 # objects is moved to unreachable during the same garbage collection 306 # cycle, the __del__ for `ScopedTFFunction` can be collected before 307 # this method is called. In that case, we can't do much but 308 # continue. 309 self._destroy_resource() 310 except Exception: # pylint: disable=broad-except 311 # Silence all error logs that occur when attempting to destroy this 312 # resource. 313 pass 314 315 316@tf_export("saved_model.experimental.TrackableResource") 317class TrackableResource(CapturableResource): 318 """Holds a Tensor which a tf.function can capture. 319 320 A TrackableResource is most useful for stateful Tensors that require 321 initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s 322 are discovered by traversing the graph of object attributes, e.g. during 323 `tf.saved_model.save`. 324 325 A TrackableResource has three methods to override: 326 327 * `_create_resource` should create the resource tensor handle. 328 * `_initialize` should initialize the resource held at `self.resource_handle`. 329 * `_destroy_resource` is called upon a `TrackableResource`'s destruction 330 and should decrement the resource's ref count. For most resources, this 331 should be done with a call to `tf.raw_ops.DestroyResourceOp`. 332 333 Example usage: 334 335 >>> class DemoResource(tf.saved_model.experimental.TrackableResource): 336 ... def __init__(self): 337 ... super().__init__() 338 ... self._initialize() 339 ... def _create_resource(self): 340 ... return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2]) 341 ... def _initialize(self): 342 ... tf.raw_ops.AssignVariableOp( 343 ... resource=self.resource_handle, value=tf.ones([2])) 344 ... def _destroy_resource(self): 345 ... tf.raw_ops.DestroyResourceOp(resource=self.resource_handle) 346 >>> class DemoModule(tf.Module): 347 ... def __init__(self): 348 ... self.resource = DemoResource() 349 ... def increment(self, tensor): 350 ... return tensor + tf.raw_ops.ReadVariableOp( 351 ... resource=self.resource.resource_handle, dtype=tf.float32) 352 >>> demo = DemoModule() 353 >>> demo.increment([5, 1]) 354 <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)> 355 """ 356 357 def __init__(self, device=""): 358 """Initialize the `TrackableResource`. 359 360 Args: 361 device: A string indicating a required placement for this resource, 362 e.g. "CPU" if this resource must be created on a CPU device. A blank 363 device allows the user to place resource creation, so generally this 364 should be blank unless the resource only makes sense on one device. 365 """ 366 global _RESOURCE_TRACKER_STACK 367 for resource_tracker in _RESOURCE_TRACKER_STACK: 368 resource_tracker.add_resource(self) 369 super(TrackableResource, self).__init__(device=device) 370 371 372@tf_export("saved_model.Asset") 373class Asset(base.Trackable): 374 """Represents a file asset to hermetically include in a SavedModel. 375 376 A SavedModel can include arbitrary files, called assets, that are needed 377 for its use. For example a vocabulary file used initialize a lookup table. 378 379 When a trackable object is exported via `tf.saved_model.save()`, all the 380 `Asset`s reachable from it are copied into the SavedModel assets directory. 381 Upon loading, the assets and the serialized functions that depend on them 382 will refer to the correct filepaths inside the SavedModel directory. 383 384 Example: 385 386 ``` 387 filename = tf.saved_model.Asset("file.txt") 388 389 @tf.function(input_signature=[]) 390 def func(): 391 return tf.io.read_file(filename) 392 393 trackable_obj = tf.train.Checkpoint() 394 trackable_obj.func = func 395 trackable_obj.filename = filename 396 tf.saved_model.save(trackable_obj, "/tmp/saved_model") 397 398 # The created SavedModel is hermetic, it does not depend on 399 # the original file and can be moved to another path. 400 tf.io.gfile.remove("file.txt") 401 tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location") 402 403 reloaded_obj = tf.saved_model.load("/tmp/new_location") 404 print(reloaded_obj.func()) 405 ``` 406 407 Attributes: 408 asset_path: A 0-D `tf.string` tensor with path to the asset. 409 """ 410 411 def __init__(self, path): 412 """Record the full path to the asset.""" 413 # The init_scope prevents functions from capturing `path` in an 414 # initialization graph, since it is transient and should not end up in a 415 # serialized function body. 416 with ops.init_scope(), ops.device("CPU"): 417 self._path = ops.convert_to_tensor( 418 path, dtype=dtypes.string, name="asset_path") 419 420 @property 421 def asset_path(self): 422 """Fetch the current asset path.""" 423 return self._path 424 425 426ops.register_tensor_conversion_function( 427 Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw)) 428