• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Dependency tracking for trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import copy
21import warnings
22
23from absl import logging
24
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.eager import function as defun
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.training.tracking import base
31from tensorflow.python.training.tracking import data_structures
32from tensorflow.python.util import tf_contextlib
33from tensorflow.python.util.tf_export import tf_export
34
35
# Module-level stack of the currently active `ResourceTracker` objects.
# `resource_tracker_scope` appends to it on entry and restores the previous
# contents on exit; `TrackableResource.__init__` registers each new resource
# with every tracker found here.
_RESOURCE_TRACKER_STACK = []
38
39
class NotTrackable(object):
  """Marker base class: instances of subclasses cannot be saved object-based.

  Some objects look trackable only because of what they inherit from (for
  example via `Layer`). Deriving from `NotTrackable` as well flags them as not
  trackable. Such an object may still be assigned to attributes of other
  objects; the error is raised when a save or restore is attempted.
  """
49
50
@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of another
  `Trackable` object. For example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  The `Trackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Args:
      name: Name of the attribute being assigned.
      value: Object to assign. Unless tracking is switched off through the
        `_self_setattr_tracking` attribute, it is first passed through
        `data_structures.sticky_attribute_assignment` and the result of that
        call is what gets stored.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet, so this is a genuine assignment.
      pass

    # Tracking defaults to enabled when the flag attribute is absent.
    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Drops any dependency recorded under `name`, then deletes the attribute.

    Args:
      name: Name of the attribute being deleted.
    """
    self._maybe_initialize_trackable()
    delete_tracking(self, name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Return a dict of `Function`s of a trackable.

    Every name reported by `dir(self)` is looked up; attributes whose value is
    a `def_function.Function` or a `defun.ConcreteFunction` are collected.
    Warnings and exceptions raised by attribute accessors are suppressed.

    Args:
      unused_serialization_cache: Ignored by this implementation.

    Returns:
      A dict mapping attribute names to the function objects found.
    """
    functions = {}
    attribute_names = dir(self)
    # Saving and restoring the verbosity does not depend on the attribute
    # being inspected, so it is done once around the whole loop instead of
    # once per attribute.
    logging_verbosity = logging.get_verbosity()
    try:
      logging.set_verbosity(logging.FATAL)
      for attribute_name in attribute_names:
        # We get the attributes, suppressing warnings and exceptions.
        try:
          with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attribute_value = getattr(self, attribute_name, None)
        except Exception:  # pylint: disable=broad-except
          # We really don't want to throw an exception just because some
          # object's attribute accessor is broken.
          attribute_value = None
        if isinstance(attribute_value, (def_function.Function,
                                        defun.ConcreteFunction)):
          functions[attribute_name] = attribute_value
    finally:
      # We reset the verbosity setting in a `finally` block, to make
      # sure it always happens, even if we make the exception catching above
      # be less broad.
      logging.set_verbosity(logging_verbosity)
    return functions
127
128
def delete_tracking(obj, name):
  """Stops `obj` from tracking the dependency called `name`, if there is one.

  Args:
    obj: Object exposing `_unconditional_dependency_names` (a mapping keyed by
      dependency name) and `_unconditional_checkpoint_dependencies` (a list of
      `(name, value)` pairs).
    name: Name of the dependency to forget. Unknown names are ignored.
  """
  # pylint: disable=protected-access
  tracked_names = obj._unconditional_dependency_names
  if name not in tracked_names:
    return
  del tracked_names[name]
  dependencies = obj._unconditional_checkpoint_dependencies
  # Only the first pair recorded under this name is removed.
  for position, dependency in enumerate(dependencies):
    dep_name, _ = dependency
    if dep_name == name:
      dependencies.pop(position)
      break
  # pylint: enable=protected-access
140
141
class ResourceTracker(object):
  """Collects resources in the order in which they are added."""

  __slots__ = ["_resources"]

  def __init__(self):
    """Starts with no resources recorded."""
    self._resources = []

  @property
  def resources(self):
    """The list of recorded resources (the internal list, not a copy)."""
    return self._resources

  def add_resource(self, resource):
    """Appends `resource` to the recorded resources."""
    recorded = self._resources
    recorded.append(resource)
156
157
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """Activates `resource_tracker` for the duration of a `with` block.

  Resources created inside the block are collected by the tracker. Example
  usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The `ResourceTracker` object to activate.

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack first so that leaving the scope puts back exactly what
  # was active on entry.
  previous_stack = _RESOURCE_TRACKER_STACK[:]
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = previous_stack
185
186
class CapturableResourceDeleter(object):
  """Destroys a CapturableResource so the resource needs no `__del__` itself."""

  __slots__ = ["_destruction_context", "_destroy_resource"]

  def __init__(self, destroy_resource_fn=None):
    """Stores the destroy callback and the context to run it in.

    Args:
      destroy_resource_fn: Optional zero-argument callable that destroys the
        resource. When it is omitted (or falsy) the deleter does nothing.
    """
    if not destroy_resource_fn:
      self._destroy_resource = None
      return
    self._destroy_resource = destroy_resource_fn
    # Remember how to re-enter the current mode at destruction time: eager
    # mode when executing eagerly, otherwise the current default graph.
    if context.executing_eagerly():
      self._destruction_context = context.eager_mode
    else:
      self._destruction_context = ops.get_default_graph().as_default

  def destroy_resource(self):
    """Calls the destroy callback, if any, and returns its result."""
    destroy_fn = self._destroy_resource
    return destroy_fn() if destroy_fn else None

  def __del__(self):
    """Runs the destroy callback inside the context recorded at creation."""
    if not self._destroy_resource:
      return
    with self._destruction_context():
      self._destroy_resource()
209
210
class CapturableResource(base.Trackable):
  """Wraps a resource handle `Tensor` that a tf.function can capture.

  Instances are found by walking the graph of object attributes, for example
  while `tf.saved_model.save` runs. They take no part in the scope-based
  tracking that `TrackableResource` provides; anything that needs
  initialization should generally derive from `TrackableResource` rather than
  from `CapturableResource` directly.
  """

  def __init__(self, device="", deleter=None):
    """Initialize the `CapturableResource`.

    Args:
      device: A string naming a required placement for this resource, e.g.
        "CPU" if the resource must be created on a CPU device. A blank device
        leaves placement of resource creation to the user, so this should
        generally stay blank unless the resource only makes sense on one
        device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction. A deleter with no callback is used when
        none is supplied.
    """
    # The handle is created on first access of `resource_handle`.
    self._resource_handle = None
    self._resource_device = device
    if not deleter:
      deleter = CapturableResourceDeleter()
    self._resource_deleter = deleter

  def _create_resource(self):
    """Creates and returns a resource handle; subclasses must override."""
    raise NotImplementedError(
        "TrackableResource._create_resource not implemented.")

  def _initialize(self):
    """Initializes the resource. Optional; does nothing by default."""

  @property
  def resource_handle(self):
    """The resource handle for this object, created on first access."""
    if self._resource_handle is not None:
      return self._resource_handle
    with ops.device(self._resource_device):
      self._resource_handle = self._create_resource()
    return self._resource_handle

  def _map_resources(self, _):
    """For implementing `Trackable`.

    Returns:
      A pair of dicts: one mapping this object to a shallow copy of itself
      that owns a newly created handle, the other mapping this object's
      handle to that new handle.
    """
    duplicate = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      duplicate_handle = duplicate._create_resource()
    duplicate._resource_handle = duplicate_handle
    # pylint: enable=protected-access
    return {self: duplicate}, {self.resource_handle: duplicate_handle}

  def _list_functions_for_serialization(self, unused_functions):
    """Wraps create/initialize/destroy in no-argument functions.

    Args:
      unused_functions: Ignored by this implementation.

    Returns:
      A dict with the keys "_create_resource", "_initialize" and
      "_destroy_resource".
    """
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      return self._create_resource()

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    @def_function.function(input_signature=[], autograph=False)
    def _destroyer():
      self._resource_deleter.destroy_resource()
      return 1  # Dummy return

    serialized_functions = {}
    serialized_functions["_create_resource"] = _creator
    serialized_functions["_initialize"] = _initializer
    serialized_functions["_destroy_resource"] = _destroyer
    return serialized_functions
286
287
class TrackableResource(CapturableResource):
  """A `CapturableResource` that also registers with active trackers."""

  def __init__(self, device="", deleter=None):
    """Initialize the `TrackableResource`.

    Args:
      device: A string naming a required placement for this resource, e.g.
        "CPU" if the resource must be created on a CPU device. A blank device
        leaves placement of resource creation to the user, so this should
        generally stay blank unless the resource only makes sense on one
        device.
      deleter: A CapturableResourceDeleter that will destroy the created
        resource during destruction.
    """
    # Every tracker currently on the module-level stack records this
    # resource; the base-class initializer runs afterwards.
    for active_tracker in _RESOURCE_TRACKER_STACK:
      active_tracker.add_resource(self)
    super(TrackableResource, self).__init__(device=device, deleter=deleter)
306
307
@tf_export("saved_model.Asset")
class Asset(base.Trackable):
  """A file asset that is bundled hermetically into a SavedModel.

  A SavedModel may carry arbitrary files, called assets, which it needs in
  order to be used; a vocabulary file used to initialize a lookup table is a
  typical example.

  When a trackable object is exported with `tf.saved_model.save()`, every
  `Asset` reachable from it is copied into the SavedModel assets directory.
  After loading, the assets and the serialized functions that depend on them
  refer to the correct file paths inside the SavedModel directory.

  Example:

  ```
  filename = tf.saved_model.Asset("file.txt")

  @tf.function(input_signature=[])
  def func():
    return tf.io.read_file(filename)

  trackable_obj = tf.train.Checkpoint()
  trackable_obj.func = func
  trackable_obj.filename = filename
  tf.saved_model.save(trackable_obj, "/tmp/saved_model")

  # The created SavedModel is hermetic, it does not depend on
  # the original file and can be moved to another path.
  tf.io.gfile.remove("file.txt")
  tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")

  reloaded_obj = tf.saved_model.load("/tmp/new_location")
  print(reloaded_obj.func())
  ```

  Attributes:
    asset_path: A 0-D `tf.string` tensor with path to the asset.
  """

  def __init__(self, path):
    """Record the full path to the asset.

    Args:
      path: The asset's path; it is converted to a string tensor.
    """
    # Converting under init_scope keeps functions from capturing `path` in an
    # initialization graph: the value is transient and must not end up in a
    # serialized function body.
    with ops.init_scope():
      with ops.device("CPU"):
        path_tensor = ops.convert_to_tensor(
            path, dtype=dtypes.string, name="asset_path")
    self._path = path_tensor

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path
360
361
# Register a tensor conversion function for `Asset`: an asset is converted by
# converting its `asset_path` tensor, with any keyword arguments forwarded
# unchanged to `ops.convert_to_tensor`.
ops.register_tensor_conversion_function(
    Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw))
364