"""Dependency tracking for trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21import copy
22import warnings
23
24from absl import logging
25import six
26
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.eager import function as defun
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.training.tracking import base
33from tensorflow.python.training.tracking import data_structures
34from tensorflow.python.util import tf_contextlib
35from tensorflow.python.util.tf_export import tf_export
36
37
# Module-level stack of the `ResourceTracker` objects that are currently
# active. `resource_tracker_scope` pushes a tracker onto it for the duration of
# a `with` block, and `TrackableResource.__init__` registers each new resource
# with every tracker found here.
_RESOURCE_TRACKER_STACK = []
40
41
@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of another
  `Trackable` object. For example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  The `Trackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Args:
      name: Name of the attribute being assigned.
      value: Object being assigned to `name`.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet; fall through to a normal assignment.
      pass

    # Tracking is applied unless `_self_setattr_tracking` has been set to a
    # falsy value on this object (a missing attribute counts as enabled).
    if getattr(self, "_self_setattr_tracking", True):
      # Store the value handed back by `sticky_attribute_assignment` rather
      # than the raw argument.
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Drop the tracking entry for `name`, then delete the attribute."""
    self._delete_tracking(name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Return a dict of `Function`s of a trackable.

    Args:
      unused_serialization_cache: Unused.

    Returns:
      A dict mapping attribute name to attribute value for every name in
      `dir(self)` whose value is a `def_function.Function` or a
      `defun.ConcreteFunction`.
    """
    functions = {}
    for attribute_name in dir(self):
      # We get the attributes, suppressing warnings and exceptions.
      logging_verbosity = logging.get_verbosity()
      try:
        logging.set_verbosity(logging.FATAL)
        with warnings.catch_warnings():
          warnings.simplefilter("ignore")
          attribute_value = getattr(self, attribute_name, None)
      except Exception:  # pylint: disable=broad-except
        # We really don't want to throw an exception just because some object's
        # attribute accessor is broken.
        attribute_value = None
      finally:
        # We reset the verbosity setting in a `finally` block, to make
        # sure it always happens, even if we make the exception catching above
        # be less broad.
        logging.set_verbosity(logging_verbosity)
      if isinstance(attribute_value, (def_function.Function,
                                      defun.ConcreteFunction)):
        functions[attribute_name] = attribute_value
    return functions

  def _delete_tracking(self, name):
    """Removes the tracking of name.

    Args:
      name: Dependency name to forget. Names that are not tracked are ignored.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      # Remove the first matching entry from the parallel list as well; the
      # `break` makes deleting while enumerating safe.
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          del self._unconditional_checkpoint_dependencies[index]
          break
128
129
class ResourceTracker(object):
  """An object that tracks a list of resources."""

  # Only one attribute is ever stored, so instances carry no `__dict__`.
  __slots__ = ["_resources"]

  def __init__(self):
    """Start with no resources recorded."""
    self._resources = []

  @property
  def resources(self):
    """The list of recorded resources, in the order they were added.

    This is the tracker's own list object, not a copy.
    """
    return self._resources

  def add_resource(self, resource):
    """Append `resource` to the list of recorded resources."""
    self._resources.append(resource)
144
145
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    Nothing. `resource_tracker` is on the tracker stack for the duration of the
    `with` block.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack before pushing so it can be restored on exit.
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    # Rebind the module-level name to the snapshot taken on entry, which does
    # not contain `resource_tracker`. This runs even if the block raised.
    _RESOURCE_TRACKER_STACK = old
173
174
175def _make_getter(captured_getter, captured_previous):
176  """To avoid capturing loop variables."""
177
178  def getter(*args, **kwargs):
179    return captured_getter(captured_previous, *args, **kwargs)
180
181  return getter
182
183
class ResourceMetaclass(type):
  """Metaclass for CapturableResource."""

  def __call__(cls, *args, **kwargs):
    """Construct an instance of `cls` through the graph's creator stack.

    The creators registered on the default graph under `cls._resource_type()`
    are chained so that each one receives the next creator inward as its first
    argument. The innermost creator builds the object with `__new__` and
    `__init__`.

    Args:
      *args: Positional constructor arguments.
      **kwargs: Keyword constructor arguments.

    Returns:
      Whatever the outermost creator returns.
    """

    def default_resource_creator(next_creator, *a, **kw):
      # This creator is the end of the chain, so nothing follows it.
      assert next_creator is None
      obj = cls.__new__(cls, *a, **kw)
      obj.__init__(*a, **kw)
      return obj

    previous_getter = lambda *a, **kw: default_resource_creator(None, *a, **kw)
    resource_creator_stack = ops.get_default_graph()._resource_creator_stack
    # Wrap from the inside out: the last creator in the list ends up outermost
    # and is the one actually called below.
    for getter in resource_creator_stack[cls._resource_type()]:
      previous_getter = _make_getter(getter, previous_getter)

    return previous_getter(*args, **kwargs)
201
202
class CapturableResource(six.with_metaclass(ResourceMetaclass, base.Trackable)):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The handle is created lazily on first access to `resource_handle`.
    self._resource_handle = None
    self._resource_device = device
    # Remember how to re-enter the execution context this object was created
    # in (eager mode, or the graph that is the default right now) so that
    # `__del__` can destroy the resource there.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Return the class name, used as the key into the resource creator stack."""
    return cls.__name__

  @property
  def _destruction_context(self):
    """Callable returning the context manager used while destroying the resource.

    Falls back to `contextlib.suppress` when `_self_destruction_context` has
    not been set on this object.
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context: `contextlib.suppress()` called with no
                   # exception types suppresses nothing.
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created on first access, under `ops.device` for the device
    given at construction, and cached for later accesses.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _map_resources(self, _):
    """For implementing `Trackable`.

    Makes a shallow copy of this object and gives the copy a freshly created
    resource handle.

    Args:
      _: Unused.

    Returns:
      A tuple `(obj_map, resource_map)`: `obj_map` maps this object to its
      copy, and `resource_map` maps this object's resource handle to the
      copy's new handle.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    obj_map = {self: new_obj}
    # Reading `self.resource_handle` creates this object's own handle if it
    # has not been created yet.
    resource_map = {self.resource_handle: new_resource}
    return obj_map, resource_map

  def _list_functions_for_serialization(self, unused_functions):
    """Wrap the create/initialize/destroy hooks as `def_function.function`s.

    Args:
      unused_functions: Unused.

    Returns:
      A dict with keys "_create_resource", "_initialize" and
      "_destroy_resource", each mapping to a zero-argument function that calls
      the method of the same name on this object.
    """
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      resource = self._create_resource()
      return resource

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    @def_function.function(input_signature=[], autograph=False)
    def _destroyer():
      self._destroy_resource()
      return 1  # Dummy return

    return {
        "_create_resource": _creator,
        "_initialize": _initializer,
        "_destroy_resource": _destroyer,
    }

  def __del__(self):
    """Best-effort destruction of the resource; never raises."""
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Silence all error logs that occur when attempting to destroy this
      # resource.
      pass
314
315
@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  A TrackableResource is most useful for stateful Tensors that require
  initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s
  are discovered by traversing the graph of object attributes, e.g. during
  `tf.saved_model.save`.

  A TrackableResource has three methods to override:

  * `_create_resource` should create the resource tensor handle.
  * `_initialize` should initialize the resource held at `self.resource_handle`.
  * `_destroy_resource` is called upon a `TrackableResource`'s destruction
    and should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Initialize the `TrackableResource`.

    Before the base-class initializer runs, the new object is added to every
    `ResourceTracker` currently on the module-level tracker stack.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    global _RESOURCE_TRACKER_STACK
    for resource_tracker in _RESOURCE_TRACKER_STACK:
      resource_tracker.add_resource(self)
    super(TrackableResource, self).__init__(device=device)
370
371
@tf_export("saved_model.Asset")
class Asset(base.Trackable):
  """Represents a file asset to hermetically include in a SavedModel.

  A SavedModel can include arbitrary files, called assets, that are needed
  for its use. For example a vocabulary file used to initialize a lookup table.

  When a trackable object is exported via `tf.saved_model.save()`, all the
  `Asset`s reachable from it are copied into the SavedModel assets directory.
  Upon loading, the assets and the serialized functions that depend on them
  will refer to the correct filepaths inside the SavedModel directory.

  Example:

  ```
  filename = tf.saved_model.Asset("file.txt")

  @tf.function(input_signature=[])
  def func():
    return tf.io.read_file(filename)

  trackable_obj = tf.train.Checkpoint()
  trackable_obj.func = func
  trackable_obj.filename = filename
  tf.saved_model.save(trackable_obj, "/tmp/saved_model")

  # The created SavedModel is hermetic, it does not depend on
  # the original file and can be moved to another path.
  tf.io.gfile.remove("file.txt")
  tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")

  reloaded_obj = tf.saved_model.load("/tmp/new_location")
  print(reloaded_obj.func())
  ```

  Attributes:
    asset_path: A 0-D `tf.string` tensor with path to the asset.
  """

  def __init__(self, path):
    """Record the full path to the asset.

    Args:
      path: The asset's path; it is converted to a `dtypes.string` tensor named
        "asset_path".
    """
    # The init_scope prevents functions from capturing `path` in an
    # initialization graph, since it is transient and should not end up in a
    # serialized function body.
    with ops.init_scope(), ops.device("CPU"):
      self._path = ops.convert_to_tensor(
          path, dtype=dtypes.string, name="asset_path")

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path
424
425
# Register a conversion so that an `Asset` can be passed where a tensor is
# expected: the conversion forwards to the asset's `asset_path` tensor.
ops.register_tensor_conversion_function(
    Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw))