• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import copy
22import operator
23import sys
24
25import six
26try:
27  import wrapt
28except ImportError:
29  # Fall back to the build-time dependency if the system package is not available.
30  from .....third_party import wrapt
31
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function as defun
34from tensorflow.python.ops import variables
35from tensorflow.python.saved_model import revived_types
36from tensorflow.python.training.tracking import base
37from tensorflow.python.training.tracking import layer_utils
38from tensorflow.python.util.compat import collections_abc
39
40
class NoDependency(object):
  """Marks a value so that assigning it creates no checkpoint dependency.

  Wrap a value in `NoDependency` before assigning it to an attribute of a
  `Trackable` object to store the value without tracking it:

  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` depends only on the variable "dep". Both attributes end up
  holding plain (un-wrapped) `Variable` objects, because the marker is
  stripped when the value is assigned.

  With `tf.keras.Model` the marker only affects checkpoint dependencies. A
  `Layer` wrapped in `NoDependency` is assigned unwrapped and without a
  checkpoint dependency, but the `Model` still tracks it: it appears in
  `Model.layers` and its variables appear in `Model.variables`.
  """

  def __init__(self, value):
    # The wrapped object; `_wrap_or_unwrap` hands this back on assignment.
    self.value = value
64
65
def _should_wrap_tuple(t):
  """Returns True if the tuple `t` must be wrapped to track its contents.

  A tuple needs wrapping when any element is a `NoDependency` marker (which
  has to be stripped out), is already `Trackable`, or is something that
  `_wrap_or_unwrap` would replace with a different object. Otherwise the
  tuple is left alone: tuples are immutable, so there are no mutations to
  watch for.
  """
  # `any` stops at the first element that needs wrapping, and `or` only
  # reaches `_wrap_or_unwrap` for elements that are neither markers nor
  # trackables.
  return any(
      isinstance(item, (NoDependency, base.Trackable))
      or _wrap_or_unwrap(item) is not item
      for item in t)
78
79
def _wrap_or_unwrap(value):
  """Converts `value` into the form stored on a trackable object.

  - `NoDependency` markers are replaced by the object they wrap.
  - Objects that are already `Trackable` are returned unchanged.
  - Plain `dict` and `collections.OrderedDict` instances become
    `_DictWrapper` objects; plain `list` instances become `ListWrapper`.
  - Tuples containing trackable elements or convertible structures become
    `_TupleWrapper` objects.
  - Anything else is returned as-is.
  """
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    return value  # Already trackable; nothing to convert.
  # pylint: disable=unidiomatic-typecheck
  # Exact type comparisons rather than isinstance, so that list/dict
  # subclasses with custom logic (e.g. collections.Counter) are left alone.
  value_type = type(value)
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  # pylint: enable=unidiomatic-typecheck
  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # At least one element needs tracking, so the tuple itself is wrapped.
    return _TupleWrapper(value)
  return value
101
102
def sticky_attribute_assignment(trackable, name, value):
  """Converts `value` for storage and, unless opted out, tracks it by `name`.

  Generally called from `__setattr__`; the behavior is shared between
  `Trackable` and `Model`. Common data structures are turned into trackable
  wrappers via `_wrap_or_unwrap`. `NoDependency` markers are unwrapped and
  no dependency is added for them.

  Args:
    trackable: The object that receives the dependency (generally the one
      whose attribute is being assigned).
    name: The attribute name, used as the dependency name.
    value: The value being assigned. It need not be trackable.

  Returns:
    The object to store in the attribute: `value` after data-structure
    wrapping or `NoDependency` unwrapping.
  """
  wants_dependency = not isinstance(value, NoDependency)
  stored = _wrap_or_unwrap(value)
  if wants_dependency and isinstance(stored, base.Trackable):
    # overwrite=True lets a new object take over this name. Assigning a new
    # variable to an existing attribute has historically been allowed
    # (e.g. Adam did this).
    trackable._track_trackable(  # pylint: disable=protected-access
        stored, name=name, overwrite=True)
  return stored
136
137
138class _UntrackableError(ValueError):
139
140  def __init__(self, value):  # pylint: disable=super-init-not-called
141    self._value = value
142
143  def __str__(self):
144    return (("Only trackable objects (such as Layers or Optimizers) may be "
145             "stored in a List object. Got %s, which does not inherit from "
146             "Trackable.") % (self._value,))
147
148
class TrackableDataStructure(base.Trackable):
  """Shared base for containers whose elements may be trackable objects."""

  def __init__(self):
    # Every attribute uses the "_self_" prefix for compatibility with
    # wrapt.ObjectProxy (a base of some subclasses). New attributes MUST follow
    # the same pattern: extending `__slots__` on a subclass of ObjectProxy
    # breaks in a variety of ways.
    self._self_trainable = True
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Adds a dependency on `value` under `name` and returns the stored form.

    Raises:
      _UntrackableError: if `value`, after wrapping/unwrapping, is not
        `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      # Variables held directly are reported through the weight properties.
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # Legacy layers (tf.layers) in subclassed models must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence of contained values; may include trackables."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """Every contained Layer or Layer container, empty containers included."""
    # Filtered on every access, so wrappers always reflect the current
    # contents of the object they wrap even if it changed behind their back.
    return [
        item for item in self._values
        if isinstance(item, TrackableDataStructure)
        or layer_utils.is_layer(item)
        or layer_utils.has_weights(item)
    ]

  @property
  def layers(self):
    """`_layers` after `layer_utils.filter_empty_layer_containers`."""
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  @property
  def trainable_weights(self):
    return layer_utils.gather_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._self_extra_variables)

  @property
  def non_trainable_weights(self):
    return layer_utils.gather_non_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._self_extra_variables)

  @property
  def weights(self):
    """Trainable weights followed by non-trainable weights."""
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Updates gathered from every contained layer that has an `updates`."""
    # Updates (and conditional losses) are forwarded unfiltered: a container
    # never receives inputs, so there is nothing to filter them against.
    collected = []
    for layer in self.layers:
      if hasattr(layer, "updates"):
        collected.extend(layer.updates)
    return collected

  @property
  def losses(self):
    """Losses gathered from every contained layer that has a `losses`."""
    collected = []
    for layer in self.layers:
      if hasattr(layer, "losses"):
        collected.extend(layer.losses)
    return collected

  def __hash__(self):
    # Identity hashing lets these structures serve as set members/dict keys.
    return id(self)

  def __eq__(self, other):
    # Like Tensors, trackable data structures compare by identity, which keeps
    # set/dict membership consistent with `__hash__`.
    return self is other
272
273
class List(TrackableDataStructure, collections_abc.Sequence):
  """A trackable sequence that can only grow by appending.

  Every element must itself be trackable. The `List` keeps a checkpoint
  dependency on each element, named by its index, and forwards `Layer`
  metadata such as updates and losses.

  `List` is only a container. It tells a `tf.keras.Model` (or any other
  trackable object) about its contents but never calls the `Layer` instances
  it holds. For layers that should be applied one after another, use
  `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super(HasList, self).__init__()
      self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  Plain Python lists assigned as attributes of trackable objects are wrapped
  automatically in the mutable `ListWrapper` (see `_wrap_or_unwrap`). `List`
  is the stricter variant: adding a non-trackable element raises
  `_UntrackableError` immediately.
  """

  def __init__(self, *args, **kwargs):
    """Builds the sequence; arguments are forwarded to `list()`."""
    super(List, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    for position, item in enumerate(self._storage):
      self._storage[position] = self._track_value(
          item, name=self._name_element(position))

  def copy(self):
    return type(self)(copy.copy(self._storage))

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Creates the backing storage; subclasses may override."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    # Elements are tracked under their decimal position.
    return "%d" % (index,)

  @property
  def _values(self):
    """The sequence itself supplies the values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Tracks `value` and adds it at the end."""
    next_name = self._name_element(len(self._storage))
    self._storage.append(self._track_value(value, next_name))

  def extend(self, values):
    """Appends each element of `values` in order."""
    for item in values:
      self.append(item)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    other_items = getattr(other, "_storage", other)
    return self._storage + other_items

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          "List only supports append, multiplying in place by %d removes "
          "elements." % y)
    original_length = len(self._storage)
    # Append y - 1 further copies of the original contents.
    for _ in range(y - 1):
      for position in range(original_length):
        self.append(self._storage[position])
    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    own_size = super(List, self).__sizeof__()
    return own_size + sys.getsizeof(self._storage)
391
392
393# TODO(tomhennigan) Update to collections.UserList?
394# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
395# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. It does not throw an error immediately like `List`. Instead it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, so both elements get the same name at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `tf.contrib.checkpoint.NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags indicating that this object would not be restored
    # properly. It should then throw an error on save rather than give the
    # impression that restoring it will work.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super(ListWrapper, self).__init__(wrapped_list)
    # Last contents acknowledged through the wrapper. Compared against the
    # storage to detect changes made directly to `wrapped_list`.
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point: at save
    # time it checks whether the object is in a "dirty" state and ignores how
    # it got there. The attribute cache, by contrast, must be told about the
    # mutation immediately, because a caller could query an attribute whose
    # cached value the mutation has made stale.
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`.
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    # Build the copy with `List.copy` directly. `List.__copy__` would call
    # `self.copy()`, which this class overrides to come back here.
    copied = super(ListWrapper, self).copy()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def copy(self):
    """Shallow copy that, like `copy.copy`, keeps the unsaveable flags."""
    return self.__copy__()

  def __deepcopy__(self, memo):
    copied = super(ListWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not made via the wrapper."""
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True
      # Keep the snapshot so `_checkpoint_dependencies` can report it.
      # `_update_snapshot` stops refreshing it once the wrapper is dirty, so
      # it stays frozen at the last acknowledged contents.

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""
    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _checkpoint_dependencies(self):
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). A list element was replaced "
           "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
           "or moved (sort). In order to support restoration on object "
           "creation, tracking is exclusively for append-only data structures."
           "\n\nIf you don't need this list checkpointed, wrap it in a "
           "tf.contrib.checkpoint.NoDependency object; it will be "
           "automatically un-wrapped and subsequently ignored." % (self,)))
    if self._external_modification:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). The wrapped list was modified "
           "outside the wrapper (its final value was %s, its value when a "
           "checkpoint dependency was added was %s), which breaks restoration "
           "on object creation.\n\nIf you don't need this list checkpointed, "
           "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
           "automatically un-wrapped and subsequently ignored." % (
               self, self._storage, self._last_wrapped_list_snapshot)))
    return super(ListWrapper, self)._checkpoint_dependencies

  def __delitem__(self, key):
    self._non_append_mutation = True
    del self._storage[key]

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        if value_now is not None and value_now != value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      self._storage[key] = self._track_value(value, self._name_element(key))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super(ListWrapper, self).append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super(ListWrapper, self).extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      if self._storage:
        # Multiplying by a non-positive number removes every element. Set the
        # flag through the property so it is recorded and the attribute cache
        # is invalidated.
        self._non_append_mutation = True
      self._storage *= y
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super(ListWrapper, self).__imul__(y)

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    self._non_append_mutation = True
    self._storage.insert(index, obj)

  def sort(self, key=None, reverse=False):
    """Sorts in place like `list.sort`; marks the wrapper as unsaveable."""
    self._non_append_mutation = True
    self._storage.sort(key=key, reverse=reverse)

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._non_append_mutation = True
    del self._storage[slice(i, j)]

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super(ListWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)

  def _list_functions_for_serialization(self, unused_functions):
    return {
        str(key): value for key, value in enumerate(self)
        if _is_function(value)
    }
634
635
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named after their keys.

  Once a key has been added, it may not be deleted or replaced. If names may
  not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super(Mapping, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key (keys are always strings here).
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  def _name_element(self, key):
    if not isinstance(key, six.string_types):
      raise TypeError(
          "Mapping accepts only string keys, but got a key %s."
          % repr(key))
    return str(key)

  def __setitem__(self, key, value):
    """Adds `value` under the new string `key`.

    Re-assigning the object already stored under `key` is a no-op.

    Raises:
      TypeError: if `key` is not a string.
      ValueError: if `key` already holds a different value, or if `value` is
        not trackable.
    """
    name = self._name_element(key)
    if key in self._storage:
      current_value = self._storage[key]
      candidate = value.value if isinstance(value, NoDependency) else value
      if candidate is current_value:
        return
      # Reject before calling `_track_value`. Tracking first would replace the
      # dependency registered under `name` (tracking uses overwrite=True) and
      # record any Variable, even though the assignment is refused.
      raise ValueError(
          ("Mappings are an append-only data structure. Tried to overwrite the "
           "key '%s' with value %s, but it already contains %s")
          % (key, candidate, current_value))
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
705
706
707class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
708  """Wraps built-in dicts to support restore-on-create for variables.
709
710  _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
711  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
712  keys, reassign values). Like ListWrapper, these mutations mean that
713  _DictWrapper will raise an exception on save.
714  """
715
716  def __init__(self, wrapped_dict=None):
717    if wrapped_dict is None:
718      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
719      wrapped_dict = {}
720    if not isinstance(wrapped_dict, collections.Mapping):
721      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
722      wrapped_dict = dict(wrapped_dict)
723    wrapt.ObjectProxy.__init__(self, wrapped_dict)
724    TrackableDataStructure.__init__(self)
725    self._self_non_string_key = False
726    self._self_external_modification = False
727    self.__wrapped__.update(
728        {key: self._track_value(
729            value, name=self._name_element(key))
730         for key, value in self.__wrapped__.items()})
731    self._update_snapshot()
732
733  def __reduce_ex__(self, protocol):
734    return (self.__class__,
735            (self.__wrapped__,))
736
737  def __getattribute__(self, name):
738    if (hasattr(type(self), name)
739        and isinstance(getattr(type(self), name), property)):
740      # Bypass ObjectProxy for properties. Whether this workaround is necessary
741      # appears to depend on the Python version but not the wrapt version: 3.4
742      # in particular seems to look up properties on the wrapped object instead
743      # of the wrapper without this logic.
744      return object.__getattribute__(self, name)
745    else:
746      return super(_DictWrapper, self).__getattribute__(name)
747
748  def copy(self):
749    return copy.copy(self)
750
751  # pylint: disable=protected-access
752  def __copy__(self):
753    copied = _DictWrapper(copy.copy(self.__wrapped__))
754    copied._self_external_modification = self._self_external_modification
755    copied._self_non_string_key = self._self_non_string_key
756    return copied
757
758  def __deepcopy__(self, memo):
759    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
760    copied._self_external_modification = self._self_external_modification
761    copied._self_non_string_key = self._self_non_string_key
762    return copied
763  # pylint: enable=protected-access
764
765  @property
766  def _values(self):
767    """Collect values for TrackableDataStructure."""
768    # Sort items deterministically by key
769    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
770    if ordered:
771      return ordered[1]
772    return []
773
774  @property
775  def _checkpoint_dependencies(self):
776    """Check that the object is saveable before listing its dependencies."""
777    self._check_self_external_modification()
778    if self._self_non_string_key:
779      raise ValueError(
780          "Unable to save the object %s (a dictionary wrapper constructed "
781          "automatically on attribute assignment). The wrapped dictionary "
782          "contains a non-string key which maps to a trackable object or "
783          "mutable data structure.\n\nIf you don't need this dictionary "
784          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
785          "object; it will be automatically un-wrapped and subsequently "
786          "ignored." % (self,))
787    if self._self_external_modification:
788      raise ValueError(
789          "Unable to save the object %s (a dictionary wrapper constructed "
790          "automatically on attribute assignment). The wrapped dictionary was "
791          "modified outside the wrapper (its final value was %s, its value "
792          "when a checkpoint dependency was added was %s), which breaks "
793          "restoration on object creation.\n\nIf you don't need this "
794          "dictionary checkpointed, wrap it in a "
795          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
796          "un-wrapped and subsequently ignored." % (
797              self, self, self._self_last_wrapped_dict_snapshot))
798    assert not self._dirty  # Any reason for dirtiness should have an exception.
799    return super(_DictWrapper, self)._checkpoint_dependencies
800
801  @property
802  def _dirty(self):
803    """Check if there has already been a mutation which prevents saving."""
804    return (self._self_external_modification
805            or self._self_non_string_key)
806
807  def _check_self_external_modification(self):
808    """Checks for any changes to the wrapped dict not through the wrapper."""
809    if self._dirty:
810      return
811    if self != self._self_last_wrapped_dict_snapshot:
812      self._self_external_modification = True
813      self._self_last_wrapped_dict_snapshot = None
814
815  def _update_snapshot(self):
816    """Acknowledges tracked changes to the wrapped dict."""
817    self._attribute_sentinel.invalidate_all()
818    if self._dirty:
819      return
820    self._self_last_wrapped_dict_snapshot = dict(self)
821
822  def _track_value(self, value, name):
823    """Allows storage of non-trackable objects."""
824    if isinstance(name, six.string_types):
825      string_key = True
826    else:
827      name = "-non_string_key"
828      string_key = False
829    try:
830      no_dependency = isinstance(value, NoDependency)
831      value = super(_DictWrapper, self)._track_value(value=value, name=name)
832      if not (string_key or no_dependency):
833        # A non-string key maps to a trackable value. This data structure
834        # is not saveable.
835        self._self_non_string_key = True
836      return value
837    except ValueError:
838      # Even if this value isn't trackable, we need to make sure
839      # NoDependency objects get unwrapped.
840      return sticky_attribute_assignment(
841          trackable=self, value=value, name=name)
842
843  def _name_element(self, key):
844    """Tells TrackableDataStructure to use keys as names as-is."""
845    return key
846
847  def __setitem__(self, key, value):
848    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
849    self._check_self_external_modification()
850    self._maybe_initialize_trackable()
851    no_dep = isinstance(value, NoDependency)
852    if isinstance(key, six.string_types):
853      value = self._track_value(value, name=key)
854    else:
855      value = _wrap_or_unwrap(value)
856      if not no_dep and isinstance(value, base.Trackable):
857        # Non-string keys are OK as long as we have no reason to add a
858        # dependency on the value (either because the value is not
859        # trackable, or because it was wrapped in a NoDependency object).
860        self._self_non_string_key = True
861    self.__wrapped__[key] = value
862
863    self._update_snapshot()
864
865  def __delitem__(self, key):
866    self._check_self_external_modification()
867    del self.__wrapped__[key]
868    self._update_snapshot()
869
870  def __repr__(self):
871    return "DictWrapper(%s)" % (repr(self.__wrapped__),)
872
873  def __hash__(self):
874    raise TypeError("unhashable type: 'DictWrapper'")
875
876  def __eq__(self, other):
877    # Override the TrackableDataStructure "== -> is" forwarding and go back to
878    # the wrapt implementation.
879    return self.__wrapped__ == other
880
881  def update(self, *args, **kwargs):
882    for key, value in six.iteritems(dict(*args, **kwargs)):
883      self[key] = value
884
885  def _list_functions_for_serialization(self, unused_serialization_cache):
886    return {
887        key: value for key, value in self.items()
888        if _is_function(value)
889    }
890
891
class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  On construction, each element is passed through `_wrap_or_unwrap`. Elements
  that were not wrapped in `NoDependency` are then tracked by index (name
  "0", "1", ...). Namedtuple elements are additionally tracked by field name.
  Because tuples are immutable, in-place `+=` / `*=` return the result of the
  ordinary operation instead of mutating the proxy.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple` and tracks its elements.

    Args:
      original_wrapped_tuple: A tuple or namedtuple. Elements that are
        `NoDependency` instances are stored without a dependency.

    If a namedtuple cannot be rebuilt from its `_fields` (TypeError from its
    constructor), the original tuple is wrapped as-is and no elements are
    tracked. Saving then raises (see `_checkpoint_dependencies`).
    """
    # add_dependency[i] records whether element i should become a dependency.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(_wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # Fall back to proxying the untouched original and skip all tracking.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure.

    The proxy itself iterates over the stored elements, so it is returned
    directly.
    """
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Returns:
      The parent class's tracked value. If the parent rejects `value` with
      ValueError, returns the result of `sticky_attribute_assignment`
      instead.
    """
    try:
      value = super(_TupleWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    """Shows the wrapped tuple inside a `_TupleWrapper(...)` label."""
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    """Returns a new wrapper around `copy.copy` of the wrapped tuple."""
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    """Deep-copies the wrapped tuple (sharing `memo`) and re-wraps it."""
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    # Pickling rebuilds the wrapper by calling the class on the wrapped tuple;
    # `protocol` is not used.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    # Returns the result of the ordinary `*`, so the name is rebound instead.
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    # Returns the result of the ordinary `+`, so the name is rebound instead.
    return self.__wrapped__ + y

  @property
  def _checkpoint_dependencies(self):
    """Dependencies for saving.

    Raises:
      ValueError: if __init__ could not rebuild a namedtuple from its
        _fields.
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          ("Unable to save because the namedtuple {} is not constructable from "
           "its _fields (i.e. __new__ is overridden). Expected keyword "
           "arguments {}. If you do not need to save this object, consider "
           "wrapping it in a custom object that does not inherit from tuple.")
          .format(self.__wrapped__, self.__wrapped__._fields))
    return super(_TupleWrapper, self)._checkpoint_dependencies

  def __getattribute__(self, name):
    """Attribute lookup that reads class-level properties from the wrapper."""
    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_TupleWrapper, self).__getattribute__(name)
1014
1015
def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or a `ConcreteFunction`."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
1018
1019
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra
            # information. Revived dicts start empty and are filled with
            # operator.setitem.
            object_factory=lambda proto: _DictWrapper({}),
            setter=operator.setitem,
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
        )
    ])
1031
1032
1033def _set_list_item(list_object, index_string, value):
1034  item_index = int(index_string)
1035  if len(list_object) <= item_index:
1036    list_object.extend([None] * (1 + item_index - len(list_object)))
1037  list_object[item_index] = value
1038
1039
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Revived lists start empty; _set_list_item places each
            # dependency at its numeric index.
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_list_item,
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
        )
    ])
1049
1050
1051def _set_tuple_item(list_object, index_string, value):
1052  try:
1053    item_index = int(index_string)
1054  except ValueError:
1055    # Ignore namedtuple fields.
1056    return
1057  if len(list_object) <= item_index:
1058    list_object.extend([None] * (1 + item_index - len(list_object)))
1059  list_object[item_index] = value
1060
1061
# Tuples are revived as (mutable) ListWrappers so that dependencies can be
# added one at a time during loading; _set_tuple_item places each by index.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_tuple_item,
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
        )
    ])
1072