• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Dependency tracking for trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import weakref
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function
25from tensorflow.python.eager import function as defun
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.training.tracking import base
29from tensorflow.python.training.tracking import data_structures
30from tensorflow.python.util import tf_contextlib
31from tensorflow.python.util.tf_export import tf_export
32
33
34# global _RESOURCE_TRACKER_STACK
35_RESOURCE_TRACKER_STACK = []
36
37
class NotTrackable(object):
  """Marker base class for objects that must not be saved via the object API.

  Mix this in to opt a class out of object-based saving when it would
  otherwise look trackable through inheritance (for example via `Layer`).
  Subclassing `NotTrackable` places no restriction on assigning the object to
  attributes; instead, an error is raised when such an object is saved or
  restored.
  """
47
48
class AutoTrackable(base.Trackable):
  """Automatically tracks dependencies on other `Trackable` objects.

  A `Trackable` object may depend on other `Trackable` objects that must be
  saved whenever the depending object is saved. A program is correctly
  saveable when its dependency graph has this property: if changing a global
  variable affects an object (for example by changing the behavior of any of
  its methods), there is a chain of dependencies from the affected object to
  that variable.

  Dependency edges are named. This class creates them implicitly whenever a
  `Trackable` object is assigned to one of its attributes (unless the
  `_self_setattr_tracking` attribute is set to False):

  ```
  obj = AutoTrackable()
  obj.v = ResourceVariable(0.)
  ```

  After the assignment, `obj` has a dependency named "v" on the variable.

  A `Trackable` may also name `Tensor`s to save and restore directly (for
  example a `Variable` describing how to save itself) instead of relying on
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support `self.foo = trackable` syntax by tracking the assigned value."""
    try:
      unchanged = getattr(self, name) is value
    except AttributeError:
      unchanged = False
    if unchanged:
      # Re-assigning the identical object (`self.x = self.x`) is a no-op.
      return

    tracking_enabled = getattr(self, "_self_setattr_tracking", True)
    if tracking_enabled:
      # May wrap `value` and record it as a dependency named `name`.
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Delete attribute `name` and drop any dependency recorded under it."""
    self._maybe_initialize_trackable()
    delete_tracking(self, name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking.

    Wraps `value` in `data_structures.NoDependency`.
    """
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Collect the `Function`/`ConcreteFunction` attributes of this object.

    Returns:
      A dict mapping each attribute name to the function object it holds.
    """
    function_types = (def_function.Function, defun.ConcreteFunction)
    found = {}
    for attr in dir(self):
      try:
        candidate = getattr(self, attr, None)
      except Exception:  # pylint: disable=broad-except
        # A broken attribute accessor on some object should not make
        # serialization fail; treat the attribute as "not a function".
        candidate = None
      if isinstance(candidate, function_types):
        found[attr] = candidate
    return found
114
115
def delete_tracking(obj, name):
  """Stop tracking the dependency called `name` on `obj`.

  Args:
    obj: The object whose dependency bookkeeping is edited.
    name: Name of the dependency to forget. If it is not among the object's
      unconditional dependency names, nothing happens.
  """
  # pylint: disable=protected-access
  names = obj._unconditional_dependency_names
  if name not in names:
    return
  del names[name]
  dependencies = obj._unconditional_checkpoint_dependencies
  # Only the first entry with a matching name is removed. Deleting by index
  # modifies the existing list object instead of replacing the attribute.
  position = next(
      (i for i, (dep_name, _) in enumerate(dependencies) if dep_name == name),
      None)
  if position is not None:
    del dependencies[position]
  # pylint: enable=protected-access
127
128
class ResourceTracker(object):
  """Collects, in insertion order, the resources passed to `add_resource`."""

  def __init__(self):
    self._resources = list()

  @property
  def resources(self):
    """The list of recorded resources (the tracker's own list, not a copy)."""
    return self._resources

  def add_resource(self, resource):
    """Append `resource` to the tracked list."""
    tracked = self._resources
    tracked.append(resource)
141
142
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Every `TrackableResource` constructed while the scope is active is added to
  `resource_tracker`, and also to any enclosing trackers that are still active.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack so that leaving the scope restores exactly the state
  # seen on entry, even if the stack was modified inside the block.
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = old
170
171
class CapturableResourceDeleter(object):
  """Destroys a `CapturableResource` without the resource defining `__del__`.

  The destruction callback, if one was given, runs when `destroy_resource` is
  called. It also runs from this deleter's own `__del__`.
  """

  def __init__(self, destroy_resource_fn=None):
    """Store the destruction callback and the context to run it in.

    Args:
      destroy_resource_fn: Optional zero-argument callable that destroys the
        resource. When falsy, this deleter does nothing.
    """
    if not destroy_resource_fn:
      self._destroy_resource = None
      return
    self._destroy_resource = destroy_resource_fn
    # Record whether we were created eagerly or under a graph, so that
    # destruction from `__del__` happens in the same kind of context.
    if context.executing_eagerly():
      self._destruction_context = context.eager_mode
    else:
      self._destruction_context = ops.get_default_graph().as_default

  def destroy_resource(self):
    """Invoke the destruction callback and return its result (None if unset)."""
    destroy_fn = self._destroy_resource
    if not destroy_fn:
      return None
    return destroy_fn()

  def __del__(self):
    destroy_fn = self._destroy_resource
    if not destroy_fn:
      return
    with self._destruction_context():
      destroy_fn()
192
193
class CapturableResource(base.Trackable):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.
  """

  def __init__(self, device="", deleter=None):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction.
    """
    # Created lazily by the `resource_handle` property.
    self._resource_handle = None
    self._resource_device = device
    # Fall back to a deleter without a callback, which does nothing, so that
    # `_resource_deleter` is always safe to call.
    self._resource_deleter = deleter or CapturableResourceDeleter()

  def _create_resource(self):
    """A function that creates a resource handle. Subclasses must override.

    Raises:
      NotImplementedError: Always; the message names the concrete class that
        did not provide an implementation.
    """
    # Name the concrete class so the error points at the subclass that is
    # missing the override (this method is defined on CapturableResource,
    # not only on TrackableResource).
    raise NotImplementedError("{}._create_resource not implemented.".format(
        type(self).__name__))

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created on first access, under the device given at
    construction, and reused afterwards.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _list_functions_for_serialization(self, unused_functions):
    """Wrap resource creation, initialization and destruction for saving.

    Returns:
      A dict with keys "_create_resource", "_initialize" and
      "_destroy_resource", each mapped to a `tf.function` taking no inputs.
    """
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      resource = self._create_resource()
      return resource

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    @def_function.function(input_signature=[], autograph=False)
    def _destroyer():
      self._resource_deleter.destroy_resource()
      return 1  # Dummy return

    return {
        "_create_resource": _creator,
        "_initialize": _initializer,
        "_destroy_resource": _destroyer,
    }
257
258
class TrackableResource(CapturableResource):
  """A `CapturableResource` that registers itself with active trackers.

  Constructing an instance while one or more `resource_tracker_scope` blocks
  are active records it with every tracker on the module's tracker stack.
  """

  def __init__(self, device="", deleter=None):
    """Initialize the `TrackableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction.
    """
    # The tracker stack is only read here, so no `global` statement is
    # needed. Registration happens before the base initializer runs.
    for tracker in _RESOURCE_TRACKER_STACK:
      tracker.add_resource(self)
    super(TrackableResource, self).__init__(device=device, deleter=deleter)
277
278
@tf_export("saved_model.Asset")
class Asset(base.Trackable):
  """Represents a file asset to hermetically include in a SavedModel.

  A SavedModel can include arbitrary files, called assets, that are needed
  for its use. For example a vocabulary file used to initialize a lookup table.

  When a trackable object is exported via `tf.saved_model.save()`, all the
  `Asset`s reachable from it are copied into the SavedModel assets directory.
  Upon loading, the assets and the serialized functions that depend on them
  will refer to the correct filepaths inside the SavedModel directory.

  Example:

  ```
  filename = tf.saved_model.Asset("file.txt")

  @tf.function(input_signature=[])
  def func():
    return tf.io.read_file(filename)

  trackable_obj = tf.train.Checkpoint()
  trackable_obj.func = func
  trackable_obj.filename = filename
  tf.saved_model.save(trackable_obj, "/tmp/saved_model")

  # The created SavedModel is hermetic, it does not depend on
  # the original file and can be moved to another path.
  tf.io.gfile.remove("file.txt")
  tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")

  reloaded_obj = tf.saved_model.load("/tmp/new_location")
  print(reloaded_obj.func())
  ```

  Attributes:
    asset_path: A 0-D `tf.string` tensor with path to the asset.
  """

  def __init__(self, path):
    """Record the full path to the asset.

    Args:
      path: The path to the asset file; any value `ops.convert_to_tensor` can
        convert to a `tf.string` tensor.
    """
    # The init_scope prevents functions from capturing `path` in an
    # initialization graph, since it is transient and should not end up in a
    # serialized function body.
    with ops.init_scope(), ops.device("CPU"):
      self._path = ops.convert_to_tensor(
          path, dtype=dtypes.string, name="asset_path")

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path
331
332
def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```

  Slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Every result, including `None`, is computed once per instance and then
  served from the cache. Instances are used as keys of a
  `weakref.WeakKeyDictionary`, so they must be hashable and
  weak-referenceable, and cache entries disappear once the instance is
  garbage collected.

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior. The underlying cache is exposed
    as the `cache` attribute of the returned function.
  """

  cache = weakref.WeakKeyDictionary()
  # A private sentinel distinguishes "not cached yet" from a cached result of
  # `None`. Without it, functions that return `None` would be recomputed on
  # every call.
  missing = object()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item, missing)
    if output is missing:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped
426
427
def _asset_to_tensor(asset, **kwargs):
  """Tensor conversion function: converts an `Asset` via its `asset_path`."""
  return ops.convert_to_tensor(asset.asset_path, **kwargs)


# Let `Asset` objects be used where a tensor is expected (e.g. passed to
# `tf.io.read_file`); they convert to their path tensor.
ops.register_tensor_conversion_function(Asset, _asset_to_tensor)
430