• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import copy
22import operator
23import sys
24
25import six
26try:
27  import wrapt
28except ImportError:
29  # Fall back to the build-time dependency if the system package is not available.
30  from .....third_party import wrapt
31
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function as defun
34from tensorflow.python.ops import variables
35from tensorflow.python.saved_model import revived_types
36from tensorflow.python.training.tracking import base
37from tensorflow.python.training.tracking import layer_utils
38from tensorflow.python.util import lazy_loader
39from tensorflow.python.util.compat import collections_abc
40
# Bind `module` to a lazy loader for `tensorflow.python.module.module` rather
# than importing it eagerly; it is referenced below as `module.Module`.
# NOTE(review): presumably deferred to avoid a circular import -- confirm.
module = lazy_loader.LazyLoader(
    "module", globals(), "tensorflow.python.module.module")
43
44
class NoDependency(object):
  """Allows attribute assignment to `Trackable` objects with no dependency.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  `obj` in this example has a dependency on the variable "dep", and both
  attributes contain un-wrapped `Variable` objects.

  `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
  dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
  `Layer` to the attribute without a checkpoint dependency, but the `Model` will
  still track the `Layer` (so it will appear in `Model.layers`, and its
  variables will appear in `Model.variables`).
  """

  # Instances hold exactly one attribute, `value`, and get no per-instance
  # `__dict__`.
  __slots__ = ["value"]

  def __init__(self, value):
    """Stores `value` so that it can be unwrapped later.

    Args:
      value: The object to hold. It is read back through the `value` attribute
        (see `wrap_or_unwrap`).
    """
    self.value = value
70
71
def _should_wrap_tuple(t):
  """Decides whether the tuple `t` needs a trackable wrapper.

  Args:
    t: The tuple to inspect.

  Returns:
    True if any element is a `NoDependency` marker (which must be stripped
    from the tuple), is a `Trackable`, or would be replaced by
    `wrap_or_unwrap`. False otherwise: tuples are immutable, so without such
    elements there is nothing to track and mutation is not a concern.
  """
  return any(
      isinstance(item, (NoDependency, base.Trackable))
      or wrap_or_unwrap(item) is not item
      for item in t)
84
85
def wrap_or_unwrap(value):
  """Converts plain containers to trackable wrappers; strips `NoDependency`.

  Args:
    value: Any object.

  Returns:
    The wrapped object if `value` is a `NoDependency`; `value` itself if it is
    already `Trackable`; a `_DictWrapper`, `ListWrapper` or `_TupleWrapper`
    for plain dicts/`OrderedDict`s, plain lists, and tuples that
    `_should_wrap_tuple` accepts; otherwise `value` unchanged.
  """
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    # Already trackable, so there is nothing to convert.
    return value

  # pylint: disable=unidiomatic-typecheck
  # Exact type comparison on purpose: list/dict subclasses with custom logic
  # (e.g. collections.Counter) must be left alone.
  value_type = type(value)
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  # pylint: enable=unidiomatic-typecheck

  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # The tuple holds trackable elements or data structures.
    return _TupleWrapper(value)
  return value
107
108
def sticky_attribute_assignment(trackable, name, value):
  """Registers a dependency for an assigned attribute value.

  Generally called from `__setattr__`; the behavior is shared between
  Trackable and Model. Common data structures are turned into trackable
  objects and tracked under their attribute name, unless the value was marked
  with `NoDependency`.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary).
  """
  # Must be decided before unwrapping, which discards the marker.
  track_dependency = not isinstance(value, NoDependency)
  value = wrap_or_unwrap(value)
  if track_dependency and isinstance(value, base.Trackable):
    # overwrite=True lets the user switch the Trackable tracked under this
    # name, since assigning a new variable to an attribute has historically
    # been fine (e.g. Adam did this).
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name, overwrite=True)
  return value
142
143
144class _UntrackableError(ValueError):
145
146  def __init__(self, value):  # pylint: disable=super-init-not-called
147    self._value = value
148
149  def __str__(self):
150    return (("Only trackable objects (such as Layers or Optimizers) may be "
151             "stored in a List object. Got %s, which does not inherit from "
152             "Trackable.") % (self._value,))
153
154
class TrackableDataStructure(base.Trackable):
  """Shared base for containers whose elements may be trackable objects.

  Subclasses supply their contents through the `_values` property. The layer,
  weight, update and loss collections below are recomputed from `_values`
  (plus any variables recorded by `_track_value`) on every access.
  """

  def __init__(self):
    # All instance attributes carry a "_self_" prefix for compatibility with
    # wrapt.ObjectProxy. Any attribute added later MUST follow the same
    # pattern: extending `__slots__` on a subclass of ObjectProxy breaks in a
    # variety of ways.
    self._self_trainable = True
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    """The `AttributeSentinel` created for this container in `__init__`."""
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    """Flag consulted by `trainable_weights` and `non_trainable_weights`."""
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Tracks `value` under `name` and returns the object to store.

    Args:
      value: The object being added to the container.
      name: The dependency name to track it under.

    Returns:
      The (possibly wrapped or unwrapped) object which should be stored.

    Raises:
      _UntrackableError: If the resulting object is not a `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      # Variables stored directly are reported by the weight properties.
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filtered on demand, so wrapper objects reflect the thing they wrap even
    # if the two have gone out of sync.
    return [
        candidate for candidate in self._values
        if (isinstance(candidate, TrackableDataStructure)
            or layer_utils.is_layer(candidate)
            or layer_utils.has_weights(candidate))
    ]

  @property
  def layers(self):
    """`_layers` passed through `filter_empty_layer_containers`, as a list."""
    non_empty = layer_utils.filter_empty_layer_containers(self._layers)
    return list(non_empty)

  @property
  def trainable_weights(self):
    """Trainable variables of nested containers/modules, then of this one.

    Returns an empty list when this container is not trainable.
    """
    if not self._self_trainable:
      return []
    nested = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested.extend(child.trainable_variables)
    own = [var for var in self._self_extra_variables if var.trainable]
    return nested + own

  @property
  def non_trainable_weights(self):
    """Non-trainable variables; includes everything if not `trainable`."""
    own_trainable = [
        var for var in self._self_extra_variables if var.trainable
    ]
    own_frozen = [
        var for var in self._self_extra_variables if not var.trainable
    ]
    nested_frozen = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested_frozen.extend(child.non_trainable_variables)

    if self._self_trainable:
      return nested_frozen + own_frozen

    # Return order is all trainable vars, then all non-trainable vars.
    nested_trainable = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested_trainable.extend(child.trainable_variables)
    return nested_trainable + own_trainable + nested_frozen + own_frozen

  @property
  def weights(self):
    """`trainable_weights` followed by `non_trainable_weights`."""
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    """Alias for `trainable_weights`."""
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    """Alias for `non_trainable_weights`."""
    return self.non_trainable_weights

  @property
  def variables(self):
    """Alias for `weights`."""
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than being
    # filtered based on inputs: this is just a container and never has any
    # inputs of its own.
    collected = []
    for layer in self.layers:
      if hasattr(layer, "updates"):
        collected.extend(layer.updates)
    return collected

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    collected = []
    for layer in self.layers:
      if hasattr(layer, "losses"):
        collected.extend(layer.losses)
    return collected

  def __hash__(self):
    # Object-identity hashing, so these structures can be used as keys in
    # sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Like Tensors, trackable data structures compare by object identity to
    # support set/dict membership.
    return self is other
306
307
class List(TrackableDataStructure, collections_abc.Sequence):
  """A trackable sequence which only supports appending.

  Every element must itself be trackable. The list keeps a checkpoint
  dependency on each element and forwards `Layer` metadata such as updates
  and losses.

  `List` is purely a container: it lets a `tf.keras.Model` or other trackable
  object know about its contents, but it never calls the `Layer` instances
  added to it. To express a sequence of `Layer` instances which should be
  called one after another, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super(HasList, self).__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super(List, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Replace each initial element with its tracked counterpart, named after
    # its position.
    for position, item in enumerate(self._storage):
      tracked = self._track_value(item, name=self._name_element(position))
      self._storage[position] = tracked

  def copy(self):
    """Returns a new instance of this type built from a shallow storage copy."""
    storage_copy = copy.copy(self._storage)
    return type(self)(storage_copy)

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    storage_copy = copy.deepcopy(self._storage, memo)
    return type(self)(storage_copy)

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    """Returns the dependency name used for the element at `index`."""
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value."""
    next_name = self._name_element(len(self._storage))
    tracked = self._track_value(value, next_name)
    self._storage.append(tracked)

  def extend(self, values):
    """Add a sequence of trackable values."""
    for item in values:
      self.append(item)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Accepts another wrapper (its `_storage` is used) or a plain sequence.
    other_items = getattr(other, "_storage", other)
    return self._storage + other_items

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          "List only supports append, multiplying in place by %d removes "
          "elements." % y)

    # Re-append the original elements (y - 1) more times, so that every
    # repetition goes through `append` and is tracked.
    original_length = len(self._storage)
    for _ in range(y - 1):
      for item in self._storage[:original_length]:
        self.append(item)

    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    return self._storage[i:j]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    own_size = super(List, self).__sizeof__()
    return own_size + sys.getsizeof(self._storage)
425
426
427# TODO(tomhennigan) Update to collections.UserList?
428# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
429# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super(ListWrapper, self).__init__(wrapped_list)
    # Shallow copy used to detect changes made directly to the wrapped list.
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    """Whether a mutation other than an append has been recorded."""
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point; when
    # attempting to save it checks whether a mutation occurred and the object is
    # in a "dirty" state but otherwise the specifics of how it got to that state
    # are ignored. By contrast, the attribute cache needs to signal the mutation
    # immediately since a caller could query the value of an attribute (And
    # should not hit the cached value since the mutation may have affected the
    # result.)
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    """Whether the wrapped list was found changed outside the wrapper."""
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    # The copy inherits both "dirty" flags of the original.
    copied = super(ListWrapper, self).__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    # The copy inherits both "dirty" flags of the original.
    copied = super(ListWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    # Reconstruct from the wrapped list alone; `protocol` is not consulted.
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper."""
    if self._external_modification or self._non_append_mutation:
      # Already unsaveable; nothing further to detect.
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True
      self._last_wrapped_list_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""

    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _checkpoint_dependencies(self):
    """Checks that the list is saveable, then lists its dependencies.

    Raises:
      ValueError: If a non-append mutation was recorded, or the wrapped list
        was modified outside the wrapper.
    """
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). A list element was replaced "
           "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
           "or moved (sort). In order to support restoration on object "
           "creation, tracking is exclusively for append-only data structures."
           "\n\nIf you don't need this list checkpointed, wrap it in a "
           "non-trackable object; it will be subsequently ignored." % (self,)))
    if self._external_modification:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). The wrapped list was modified "
           "outside the wrapper (its final value was %s, its value when a "
           "checkpoint dependency was added was %s), which breaks restoration "
           "on object creation.\n\nIf you don't need this list checkpointed, "
           "wrap it in a NoDependency object; it will be "
           "subsequently ignored." % (
               self, self._storage, self._last_wrapped_list_snapshot)))
    return super(ListWrapper, self)._checkpoint_dependencies

  def _has_mutation_or_trackable(self):
    """Short-circuits a check for trackables if there's already a mutation."""
    if self._non_append_mutation:
      return True
    return any(isinstance(element, base.Trackable) for element in self._storage)

  def __delitem__(self, key):
    self._check_external_modification()
    # Deleting only matters for saving if the list holds trackable elements.
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      # Compare the list before and after, position by position.
      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        # Track whatever now sits at a position whose contents changed.
        if value_now is not None and value_now != value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      # Replacing a trackable element reuses its name for a different object.
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      self._storage[key] = self._track_value(value, self._name_element(key))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super(ListWrapper, self).append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super(ListWrapper, self).extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      # Unlike `List`, emptying the list is allowed; it is recorded as a
      # non-append mutation when the list holds trackable elements.
      self._check_external_modification()
      if self._has_mutation_or_trackable():
        self._non_append_mutation = True
      self._storage *= y
      self._update_snapshot()
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super(ListWrapper, self).__imul__(y)

  # Comparisons delegate to the wrapped list; another wrapper is compared
  # through its own `_storage`.
  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    """Inserts `obj` before `index` in the wrapped list.

    Recorded as a non-append mutation if `obj` is trackable or the list
    already holds trackable elements.
    """
    self._check_external_modification()
    if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)):
      self._non_append_mutation = True
    self._storage.insert(index, obj)
    self._update_snapshot()

  def sort(self, *args, **kwargs):
    """Sorts the wrapped list in place.

    Recorded as a non-append mutation if the list holds trackable elements.

    Args:
      *args: Positional arguments forwarded unchanged to `list.sort`.
      **kwargs: Keyword arguments forwarded unchanged to `list.sort` (for
        example `key` and `reverse`).
    """
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    self._storage.sort(*args, **kwargs)
    self._update_snapshot()

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[slice(i, j)]
    self._update_snapshot()

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super(ListWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)

  def _list_functions_for_serialization(self, unused_functions):
    # Maps the stringified index to each element accepted by `_is_function`.
    return {
        str(key): value for key, value in enumerate(self)
        if _is_function(value)
    }
688
689
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super(Mapping, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Replace every initial value with its tracked counterpart, named after
    # its key.
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Builds the backing `dict` from the constructor arguments."""
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  def _name_element(self, key):
    """Returns the dependency name for `key`.

    Raises:
      TypeError: If `key` is not a string.
    """
    if not isinstance(key, six.string_types):
      raise TypeError(
          "Mapping accepts only string keys, but got a key %s."
          % repr(key))
    return str(key)

  def __setitem__(self, key, value):
    """Adds `value` under `key`; existing keys may not be replaced.

    Re-assigning the object already stored under `key` is a no-op.

    Args:
      key: A string key.
      value: A trackable object, optionally wrapped in `NoDependency`.

    Raises:
      TypeError: If `key` is not a string.
      ValueError: If `key` already maps to a different object, or if `value`
        is not trackable.
    """
    name = self._name_element(key)
    if key in self._storage:
      # Check for an overwrite BEFORE tracking: tracking replaces the
      # dependency registered under `name`, so tracking first would leave a
      # rejected value as the tracked dependency for an unchanged entry.
      current_value = self._storage[key]
      new_value = value.value if isinstance(value, NoDependency) else value
      if current_value is not new_value:
        raise ValueError(
            ("Mappings are an append-only data structure. Tried to overwrite the "
             "key '%s' with value %s, but it already contains %s")
            % (key, new_value, current_value))
      return
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    """Adds each item of `dict(*args, **kwargs)` through `__setitem__`."""
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
758
759
760class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
761  """Wraps built-in dicts to support restore-on-create for variables.
762
763  _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
764  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
765  keys, reassign values). Like ListWrapper, these mutations mean that
766  _DictWrapper will raise an exception on save.
767  """
768
769  def __init__(self, wrapped_dict=None):
770    if wrapped_dict is None:
771      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
772      wrapped_dict = {}
773    if not isinstance(wrapped_dict, collections_abc.Mapping):
774      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
775      wrapped_dict = dict(wrapped_dict)
776    wrapt.ObjectProxy.__init__(self, wrapped_dict)
777    TrackableDataStructure.__init__(self)
778    self._self_non_string_key = False
779    self._self_external_modification = False
780    self.__wrapped__.update(
781        {key: self._track_value(
782            value, name=self._name_element(key))
783         for key, value in self.__wrapped__.items()})
784    self._update_snapshot()
785
786  def __reduce_ex__(self, protocol):
787    return (self.__class__,
788            (self.__wrapped__,))
789
790  def __getattribute__(self, name):
791    if (hasattr(type(self), name)
792        and isinstance(getattr(type(self), name), property)):
793      # Bypass ObjectProxy for properties. Whether this workaround is necessary
794      # appears to depend on the Python version but not the wrapt version: 3.4
795      # in particular seems to look up properties on the wrapped object instead
796      # of the wrapper without this logic.
797      return object.__getattribute__(self, name)
798    else:
799      return super(_DictWrapper, self).__getattribute__(name)
800
  def copy(self):
    """Returns a shallow copy of this wrapper, produced by `copy.copy`."""
    return copy.copy(self)
803
804  # pylint: disable=protected-access
805  def __copy__(self):
806    copied = _DictWrapper(copy.copy(self.__wrapped__))
807    copied._self_external_modification = self._self_external_modification
808    copied._self_non_string_key = self._self_non_string_key
809    return copied
810
811  def __deepcopy__(self, memo):
812    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
813    copied._self_external_modification = self._self_external_modification
814    copied._self_non_string_key = self._self_non_string_key
815    return copied
816  # pylint: enable=protected-access
817
818  @property
819  def _values(self):
820    """Collect values for TrackableDataStructure."""
821    # Sort items deterministically by key
822    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
823    if ordered:
824      return ordered[1]
825    return []
826
827  @property
828  def _checkpoint_dependencies(self):
829    """Check that the object is saveable before listing its dependencies."""
830    self._check_self_external_modification()
831    if self._self_non_string_key:
832      raise ValueError(
833          "Unable to save the object %s (a dictionary wrapper constructed "
834          "automatically on attribute assignment). The wrapped dictionary "
835          "contains a non-string key which maps to a trackable object or "
836          "mutable data structure.\n\nIf you don't need this dictionary "
837          "checkpointed, wrap it in a non-trackable "
838          "object; it will be subsequently ignored." % (self,))
839    if self._self_external_modification:
840      raise ValueError(
841          "Unable to save the object %s (a dictionary wrapper constructed "
842          "automatically on attribute assignment). The wrapped dictionary was "
843          "modified outside the wrapper (its final value was %s, its value "
844          "when a checkpoint dependency was added was %s), which breaks "
845          "restoration on object creation.\n\nIf you don't need this "
846          "dictionary checkpointed, wrap it in a "
847          "non-trackable object; it will be subsequently ignored." % (
848              self, self, self._self_last_wrapped_dict_snapshot))
849    assert not self._dirty  # Any reason for dirtiness should have an exception.
850    return super(_DictWrapper, self)._checkpoint_dependencies
851
852  @property
853  def _dirty(self):
854    """Check if there has already been a mutation which prevents saving."""
855    return (self._self_external_modification
856            or self._self_non_string_key)
857
858  def _check_self_external_modification(self):
859    """Checks for any changes to the wrapped dict not through the wrapper."""
860    if self._dirty:
861      return
862    if self != self._self_last_wrapped_dict_snapshot:
863      self._self_external_modification = True
864      self._self_last_wrapped_dict_snapshot = None
865
866  def _update_snapshot(self):
867    """Acknowledges tracked changes to the wrapped dict."""
868    self._attribute_sentinel.invalidate_all()
869    if self._dirty:
870      return
871    self._self_last_wrapped_dict_snapshot = dict(self)
872
873  def _track_value(self, value, name):
874    """Allows storage of non-trackable objects."""
875    if isinstance(name, six.string_types):
876      string_key = True
877    else:
878      name = "-non_string_key"
879      string_key = False
880    try:
881      no_dependency = isinstance(value, NoDependency)
882      value = super(_DictWrapper, self)._track_value(value=value, name=name)
883      if not (string_key or no_dependency):
884        # A non-string key maps to a trackable value. This data structure
885        # is not saveable.
886        self._self_non_string_key = True
887      return value
888    except ValueError:
889      # Even if this value isn't trackable, we need to make sure
890      # NoDependency objects get unwrapped.
891      return sticky_attribute_assignment(
892          trackable=self, value=value, name=name)
893
894  def _name_element(self, key):
895    """Tells TrackableDataStructure to use keys as names as-is."""
896    return key
897
898  def __setitem__(self, key, value):
899    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
900    self._check_self_external_modification()
901    self._maybe_initialize_trackable()
902    no_dep = isinstance(value, NoDependency)
903    if isinstance(key, six.string_types):
904      value = self._track_value(value, name=key)
905    else:
906      value = wrap_or_unwrap(value)
907      if not no_dep and isinstance(value, base.Trackable):
908        # Non-string keys are OK as long as we have no reason to add a
909        # dependency on the value (either because the value is not
910        # trackable, or because it was wrapped in a NoDependency object).
911        self._self_non_string_key = True
912    self.__wrapped__[key] = value
913
914    self._update_snapshot()
915
916  def __delitem__(self, key):
917    self._check_self_external_modification()
918    del self.__wrapped__[key]
919    self._update_snapshot()
920
921  def __repr__(self):
922    return "DictWrapper(%s)" % (repr(self.__wrapped__),)
923
924  def __hash__(self):
925    raise TypeError("unhashable type: 'DictWrapper'")
926
927  def __eq__(self, other):
928    # Override the TrackableDataStructure "== -> is" forwarding and go back to
929    # the wrapt implementation.
930    return self.__wrapped__ == other
931
932  def update(self, *args, **kwargs):
933    for key, value in six.iteritems(dict(*args, **kwargs)):
934      self[key] = value
935
936  def _list_functions_for_serialization(self, unused_serialization_cache):
937    return {
938        key: value for key, value in self.items()
939        if _is_function(value)
940    }
941
942
class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples."""

  def __init__(self, original_wrapped_tuple=()):
    """Wraps a tuple or namedtuple and tracks its elements.

    Every element is passed through `wrap_or_unwrap`. Elements that were not
    `NoDependency` instances are then tracked by index, and for namedtuples
    also by field name.

    Args:
      original_wrapped_tuple: The tuple or namedtuple to wrap.
    """
    # Parallel lists: whether element i should be tracked, and its
    # substituted value.
    depends_flags = []
    converted_elements = []
    for element in original_wrapped_tuple:
      depends_flags.append(not isinstance(element, NoDependency))
      converted_elements.append(wrap_or_unwrap(element))

    try:
      fields = original_wrapped_tuple._fields
      is_namedtuple = True
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False

    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if not is_namedtuple:
      rebuilt_tuple = original_type(converted_elements)
    else:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        rebuilt_tuple = original_type(**dict(zip(fields, converted_elements)))
      except TypeError:
        # Fall back to wrapping the original object; nothing is tracked and
        # the wrapper is flagged as not re-constructable.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    wrapt.ObjectProxy.__init__(self, rebuilt_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      named_elements = zip(fields, depends_flags, rebuilt_tuple)
      for name, should_depend, element in named_elements:
        if not should_depend:
          continue
        self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    indexed_elements = enumerate(zip(depends_flags, rebuilt_tuple))
    for index, (should_depend, element) in indexed_elements:
      if not should_depend:
        continue
      self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Args:
      value: The element to track.
      name: The name to track the element under.

    Returns:
      The result of the parent class's `_track_value`, or, when that raises
      `ValueError`, the result of `sticky_attribute_assignment`.
    """
    try:
      return super(_TupleWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def __repr__(self):
    """Returns "_TupleWrapper(...)" around the repr of the wrapped tuple."""
    wrapped_repr = repr(self.__wrapped__)
    return "_TupleWrapper(%s)" % (wrapped_repr,)

  def __hash__(self):
    """Returns the hash of the wrapped tuple."""
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    wrapped_tuple = self.__wrapped__
    return hash(wrapped_tuple)

  def __eq__(self, other):
    """Compares the wrapped tuple with `other` using `==`."""
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    wrapped_tuple = self.__wrapped__
    return wrapped_tuple == other

  def __copy__(self):
    """Returns a new `_TupleWrapper` around a shallow copy of the tuple."""
    shallow_tuple = copy.copy(self.__wrapped__)
    return _TupleWrapper(shallow_tuple)

  def __deepcopy__(self, memo):
    """Returns a new `_TupleWrapper` around a deep copy of the tuple.

    Args:
      memo: The memo dictionary, forwarded unchanged to `copy.deepcopy`.
    """
    deep_tuple = copy.deepcopy(self.__wrapped__, memo)
    return _TupleWrapper(deep_tuple)

  def __reduce_ex__(self, protocol):
    """Returns a `(callable, args)` pair: `self.__class__` and the tuple.

    Args:
      protocol: Not used.
    """
    constructor_args = (self.__wrapped__,)
    return (self.__class__, constructor_args)

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    wrapped_tuple = self.__wrapped__
    return wrapped_tuple * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    wrapped_tuple = self.__wrapped__
    return wrapped_tuple + y

  @property
  def _checkpoint_dependencies(self):
    """Lists dependencies, refusing if the namedtuple cannot be rebuilt.

    Returns:
      The `_checkpoint_dependencies` value of the parent class.

    Raises:
      ValueError: If construction flagged the wrapped namedtuple as not
        constructable from its `_fields`.
    """
    if self._self_tuple_is_constructable:
      return super(_TupleWrapper, self)._checkpoint_dependencies
    raise ValueError(
        ("Unable to save because the namedtuple {} is not constructable from "
         "its _fields (i.e. __new__ is overridden). Expected keyword "
         "arguments {}. If you do not need to save this object, consider "
         "wrapping it in a custom object that does not inherit from tuple.")
        .format(self.__wrapped__, self.__wrapped__._fields))

  def __getattribute__(self, name):
    """Looks up `name`, reading properties of this class from the wrapper.

    Args:
      name: The attribute name.

    Returns:
      The attribute value.
    """
    wrapper_type = type(self)
    if (hasattr(wrapper_type, name)
        and isinstance(getattr(wrapper_type, name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    return super(_TupleWrapper, self).__getattribute__(name)
1065
1066
def _is_function(x):
  """Returns whether `x` is a `def_function.Function` or `ConcreteFunction`."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
1069
1070
# Register `_DictWrapper` as a revived type under the identifier
# "trackable_dict_wrapper".
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra
            # information.
            object_factory=lambda proto: _DictWrapper({}),
            setter=operator.setitem,
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
    ])
1082
1083
1084def _set_list_item(list_object, index_string, value):
1085  item_index = int(index_string)
1086  if len(list_object) <= item_index:
1087    list_object.extend([None] * (1 + item_index - len(list_object)))
1088  list_object[item_index] = value
1089
1090
# Register `ListWrapper` as a revived type under the identifier
# "trackable_list_wrapper". Revived objects start as an empty `ListWrapper`
# and use `_set_list_item` as their setter.
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_list_item,
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
    ])
1100
1101
1102def _set_tuple_item(list_object, index_string, value):
1103  try:
1104    item_index = int(index_string)
1105  except ValueError:
1106    # Ignore namedtuple fields.
1107    return
1108  if len(list_object) <= item_index:
1109    list_object.extend([None] * (1 + item_index - len(list_object)))
1110  list_object[item_index] = value
1111
1112
# Register `_TupleWrapper` as a revived type under the identifier
# "trackable_tuple_wrapper". Tuples are revived as lists so that dependencies
# can be appended during loading.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_tuple_item,
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
    ])
1123