• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Dependency tracking for trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import def_function
21from tensorflow.python.eager import function as defun
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.training.tracking import base
25from tensorflow.python.training.tracking import data_structures
26from tensorflow.python.util import tf_contextlib
27
28
# Module-level stack of active `ResourceTracker` objects.
# `resource_tracker_scope` appends a tracker to it while its block runs, and
# `TrackableResource.__init__` adds each new resource to every tracker on it.
_RESOURCE_TRACKER_STACK = []
31
32
class NotTrackable(object):
  """Marker base class: instances of subclasses are unsaveable.

  Some objects look trackable only because of what they inherit from (e.g.
  through `Layer`). Deriving from `NotTrackable` flags them as not trackable
  for the object-based API. The marker does not stop such an object from being
  assigned to an attribute; it will instead throw an error on save/restore.
  """
42
43
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  A `Trackable` object may depend on other `Trackable` objects: those are the
  objects which should be saved whenever the object declaring the dependency
  is saved. In a correctly saveable program, the dependency graph is such that
  if changing a global variable affects an object (e.g. changes the behavior
  of any of its methods), a chain of dependencies leads from the influenced
  object to that variable.

  Dependency edges are named. They are created implicitly when a `Trackable`
  object is assigned to an attribute of another `Trackable` object. For
  example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  After this, the `Trackable` object `obj` has a dependency named "v" on a
  variable.

  Instead of going through dependencies on other objects, a `Trackable` object
  may also specify `Tensor`s to be saved and restored directly (e.g. a
  `Variable` indicating how to save itself). See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Unless the instance's `_setattr_tracking` attribute is falsy (it is treated
    as `True` when absent), the value is first passed through
    `data_structures.sticky_attribute_assignment` and the result of that call
    is what gets stored.
    """
    tracking_enabled = getattr(self, "_setattr_tracking", True)
    if tracking_enabled:
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Delete attribute `name`, dropping any dependency recorded under it."""
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      # Only the first entry with a matching name is removed from the list.
      dependencies = self._unconditional_checkpoint_dependencies
      for position, (dependency_name, _) in enumerate(dependencies):
        if dependency_name == name:
          del dependencies[position]
          break
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _list_functions_for_serialization(self):
    """Return a dict of `Function`s of a trackable.

    Returns:
      A dict mapping attribute name to attribute value for every name in
      `dir(self)` whose value is a `def_function.Function` or a
      `defun.ConcreteFunction`.
    """
    function_types = (def_function.Function, defun.ConcreteFunction)
    found = {}
    for attr_name in dir(self):
      try:
        attr = getattr(self, attr_name, None)
      except Exception:  # pylint: disable=broad-except
        # A broken attribute accessor on some object must not make this
        # method raise; treat the attribute as absent instead.
        attr = None
      if isinstance(attr, function_types):
        found[attr_name] = attr
    return found
108
109
class ResourceTracker(object):
  """Collects resources in the order in which they are added."""

  def __init__(self):
    self._resources = list()

  def add_resource(self, resource):
    """Append `resource` to the tracked list."""
    self._resources.append(resource)

  @property
  def resources(self):
    """The list of resources added so far (the internal list, not a copy)."""
    return self._resources
122
123
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this to collect all resources created within a block of code. While the
  scope is active, `resource_tracker` is on the module-level tracker stack.
  When the scope exits, including through an exception, the stack is reset to
  the contents it had just before the scope was entered.

  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  stack_before_entry = _RESOURCE_TRACKER_STACK[:]
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    # Restore by rebinding the global to the snapshot taken on entry, rather
    # than by popping the tracker off the list that was appended to.
    _RESOURCE_TRACKER_STACK = stack_before_entry
151
152
class TrackableResource(base.Trackable):
  """Base class for all resources that need to be tracked.

  On construction, an instance adds itself to every tracker currently on the
  module-level tracker stack. Subclasses must implement `_create_resource`;
  overriding `_initialize` is optional.
  """

  def __init__(self):
    # The stack is only read here, so no `global` declaration is required.
    for active_tracker in _RESOURCE_TRACKER_STACK:
      active_tracker.add_resource(self)

    # Filled in on first access to `resource_handle`.
    self._resource_handle = None

  def _create_resource(self):
    """A function that creates a resource handle."""
    raise NotImplementedError(
        "TrackableResource._create_resource not implemented.")

  def _initialize(self):
    """A function that initializes the resource. Optional."""

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created with `_create_resource` on first access and the
    stored value is returned on later accesses.
    """
    if self._resource_handle is not None:
      return self._resource_handle
    self._resource_handle = self._create_resource()
    return self._resource_handle

  def _list_functions_for_serialization(self):
    """Return functions wrapping `_create_resource` and `_initialize`.

    Returns:
      A dict with the keys "_create_resource" and "_initialize", each mapped
      to a newly built `def_function.function` taking no arguments.
    """

    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      return self._create_resource()

    @def_function.function(input_signature=[], autograph=False)
    def _initializer():
      self._initialize()
      return 1  # Dummy return

    serialized_functions = {}
    serialized_functions["_create_resource"] = _creator
    serialized_functions["_initialize"] = _initializer
    return serialized_functions
194
195
class TrackableAsset(base.Trackable):
  """Base class for asset files which need to be tracked."""

  def __init__(self, path):
    """Record the full path to the asset.

    Args:
      path: The asset path; it is converted with dtype `dtypes.string` and the
        result is what `asset_path` returns.
    """
    # Convert under `init_scope` so that functions do not capture `path` in an
    # initialization graph: the path is transient and should not end up in a
    # serialized function body.
    with ops.init_scope():
      path_tensor = ops.internal_convert_to_tensor(
          path, dtype=dtypes.string, name="asset_path")
      self._path = path_tensor

  @property
  def asset_path(self):
    """Fetch the current asset path."""
    return self._path

212
# Register a tensor conversion function for `TrackableAsset`: converting an
# asset converts its `asset_path`, forwarding any keyword arguments unchanged.
ops.register_tensor_conversion_function(
    TrackableAsset,
    lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
216