• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import copy
22import operator
23import sys
24
25import six
26try:
27  import wrapt
28except ImportError:
29  # Fall back to the build-time dependency if the system package is not available.
30  from .....third_party import wrapt  # pylint: disable=relative-beyond-top-level
31
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function as defun
34from tensorflow.python.ops import variables
35from tensorflow.python.saved_model import revived_types
36from tensorflow.python.training.tracking import base
37from tensorflow.python.training.tracking import layer_utils
38from tensorflow.python.util import lazy_loader
39from tensorflow.python.util.compat import collections_abc
40from tensorflow.python.util.tf_export import tf_export
41
42
# Handle to `tensorflow.python.module.module`, obtained through
# `lazy_loader.LazyLoader` rather than a direct import (presumably to defer
# the import and avoid an import cycle with this file -- TODO confirm). It is
# used below for `isinstance(obj, module.Module)` checks.
module = lazy_loader.LazyLoader(
    "module", globals(), "tensorflow.python.module.module")
45
46
class NoDependency(object):
  """Marks a value so that assigning it to a `Trackable` adds no dependency.

  When a `NoDependency` is assigned to an attribute of a `Trackable`, the
  wrapped value (not the marker) is stored, and no checkpoint dependency is
  created for it.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` depends on the variable "dep" only, and both attributes hold
  plain (un-wrapped) `Variable` objects.

  With `tf.keras.Model` this only affects checkpoint dependencies: a `Layer`
  wrapped in `NoDependency` is assigned unwrapped and without a checkpoint
  dependency, yet the `Model` still tracks it (it shows up in `Model.layers`
  and its variables in `Model.variables`).
  """

  # A single slot: the marker carries nothing but the wrapped value.
  __slots__ = ["value"]

  def __init__(self, value):
    # The object whose dependency should be suppressed.
    self.value = value
72
73
def _should_wrap_tuple(t):
  """Returns whether tuple `t` contains anything that needs a trackable wrapper.

  That is the case when some element -- searching nested tuples recursively --
  is a `NoDependency` marker (which must be stripped out of the tuple), a
  `Trackable`, or exactly a `dict`, `collections.OrderedDict` or `list`.
  Otherwise the tuple holds nothing trackable and, being immutable, needs no
  mutation tracking either.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking to avoid mucking up custom logic in list/dict
  # subclasses, e.g. collections.Counter.
  return any(
      isinstance(element, (NoDependency, base.Trackable))
      or type(element) in (dict, collections.OrderedDict, list)
      or (isinstance(element, tuple) and _should_wrap_tuple(element))
      for element in t)
  # pylint: enable=unidiomatic-typecheck
96
97
@tf_export("__internal__.tracking.wrap", v1=[])
def wrap_or_unwrap(value):
  """Wraps input value into trackable data structures.

  Containers such as lists and dicts may hold trackable objects; once wrapped,
  they are tracked when associated with a `tf.Module`, so saved models and
  checkpoints record the dependency properly. `NoDependency` markers are
  unwrapped instead.

  Args:
    value: the input object to be wrapped.

  Returns:
    `value.value` for a `NoDependency`; `value` itself when it is already
    `Trackable`; a `_DictWrapper` for an exact `dict`/`OrderedDict`; a
    `ListWrapper` for an exact `list`; a `_TupleWrapper` for a tuple holding
    trackable content; otherwise `value` unchanged.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking to avoid mucking up custom logic in list/dict
  # subclasses, e.g. collections.Counter.
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    # Already trackable: no conversion needed.
    return value
  value_type = type(value)
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # There are trackable elements or data structures. Wrap the tuple.
    return _TupleWrapper(value)
  return value
  # pylint: enable=unidiomatic-typecheck
134
135
@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[])
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  This behavior is shared between Trackable and Model.

  `NoDependency` markers are honored (unwrapped, no dependency added). Any
  other value is first passed through `wrap_or_unwrap`, so common data
  structures become trackable, and a `Trackable` result is tracked under the
  attribute name.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary).
  """
  wants_dependency = not isinstance(value, NoDependency)
  stored = wrap_or_unwrap(value)
  if wants_dependency and isinstance(stored, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        stored, name=name,
        # Allow the user to switch the Trackable which is tracked by this
        # name, since assigning a new variable to an attribute has
        # historically been fine (e.g. Adam did this).
        overwrite=True)
  return stored
170
171
class _UntrackableError(ValueError):
  """Raised when a data structure is asked to track a non-trackable value."""

  def __init__(self, value):
    # Forward `value` to ValueError so `args` is populated. Exceptions are
    # pickled/copied as `type(self)(*self.args)`; with empty `args` that call
    # failed with a TypeError because `value` is required. It also makes
    # `repr()` informative.
    super(_UntrackableError, self).__init__(value)
    self._value = value

  def __str__(self):
    return (("Only trackable objects (such as Layers or Optimizers) may be "
             "stored in a List object. Got %s, which does not inherit from "
             "Trackable.") % (self._value,))
181
182
@tf_export("__internal__.tracking.TrackableDataStructure", v1=[])
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects.

  Subclasses expose their contents through the `_values` property; the layer,
  weight, update and loss accessors below are all derived from it plus the
  variables this structure tracks directly.
  """

  def __init__(self):
    # Attributes prefixed with "_self_" for compatibility with
    # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as
    # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of
    # ways.
    self._self_trainable = True
    # Variables tracked directly by this structure (see `_track_value`).
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    """Whether this structure's variables are reported as trainable."""
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Add a dependency on `value` under `name`.

    Args:
      value: The value to track. It goes through `sticky_attribute_assignment`
        first, so built-in containers get wrapped and `NoDependency` markers
        unwrapped.
      name: The checkpoint dependency name.

    Returns:
      The (possibly wrapped or unwrapped) value which should be stored.

    Raises:
      _UntrackableError: If the resulting value is not `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      # Directly-held variables are reported by the weight properties below.
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      # Register this structure's sentinel as a parent of the child's.
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filter objects on demand so that wrapper objects use values from the thing
    # they're wrapping if out of sync.
    return [
        candidate for candidate in self._values
        if isinstance(candidate, TrackableDataStructure)
        or layer_utils.is_layer(candidate)
        or layer_utils.has_weights(candidate)
    ]

  @property
  def layers(self):
    """`_layers` with empty layer containers filtered out."""
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  @property
  def trainable_weights(self):
    """Nested trainable variables followed by direct trainable ones.

    Empty when this structure is marked non-trainable.
    """
    if not self._self_trainable:
      return []
    nested = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested.extend(child.trainable_variables)
    direct = [v for v in self._self_extra_variables if v.trainable]
    return nested + direct

  @property
  def non_trainable_weights(self):
    """Non-trainable variables; everything when marked non-trainable."""
    direct_trainable = [v for v in self._self_extra_variables if v.trainable]
    direct_frozen = [v for v in self._self_extra_variables if not v.trainable]
    nested_frozen = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested_frozen.extend(child.non_trainable_variables)

    if self._self_trainable:
      return nested_frozen + direct_frozen

    # Return order is all trainable vars, then all non-trainable vars.
    nested_trainable = []
    for child in self._values:
      if isinstance(child, (TrackableDataStructure, module.Module)):
        nested_trainable.extend(child.trainable_variables)
    return (nested_trainable + direct_trainable +
            nested_frozen + direct_frozen)

  @property
  def weights(self):
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than being
    # filtered based on inputs, since this is just a container and won't ever
    # have any inputs.
    collected = []
    for child_layer in self.layers:
      if hasattr(child_layer, "updates"):
        collected.extend(child_layer.updates)
    return collected

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    collected = []
    for child_layer in self.layers:
      if hasattr(child_layer, "losses"):
        collected.extend(child_layer.losses)
    return collected

  def __hash__(self):
    # Support object-identity hashing, so these structures can be used as keys
    # in sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Similar to Tensors, trackable data structures use object-identity
    # equality to support set/dict membership.
    return self is other
335
336
class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super(HasList, self).__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super(List, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Track each initial element, storing back its tracked (wrapped) form.
    for position, item in enumerate(self._storage):
      self._storage[position] = self._track_value(
          item, name=self._name_element(position))

  def copy(self):
    """Returns a new instance built from a shallow copy of the storage."""
    storage_copy = copy.copy(self._storage)
    return type(self)(storage_copy)

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    storage_copy = copy.deepcopy(self._storage, memo)
    return type(self)(storage_copy)

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    """Checkpoint name of the element at `index`: its decimal position."""
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value."""
    next_name = self._name_element(len(self._storage))
    self._storage.append(self._track_value(value, next_name))

  def extend(self, values):
    """Add a sequence of trackable values."""
    for value in values:
      self.append(value)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Concatenation yields a plain list; a List operand contributes its storage.
    return self._storage + getattr(other, "_storage", other)

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          "List only supports append, multiplying in place by %d removes "
          "elements." % y)

    # Re-append the original elements y - 1 more times, one at a time, so each
    # copy gets tracked under its own positional name.
    original_length = len(self._storage)
    for _ in range(y - 1):
      for position in range(original_length):
        self.append(self._storage[position])

    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    # Python 2 slicing hook.
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%r)" % (self._storage,)

  def __sizeof__(self):
    return super(List, self).__sizeof__() + sys.getsizeof(self._storage)
454
455
456# TODO(tomhennigan) Update to collections.UserList?
457# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
458# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super(ListWrapper, self).__init__(wrapped_list)
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point; when
    # attempting to save it checks whether a mutation occurred and the object is
    # in a "dirty" state but otherwise the specifics of how it got to that state
    # are ignored. By contrast, the attribute cache needs to signal the mutation
    # immediately since a caller could query the value of an attribute (And
    # should not hit the cached value since the mutation may have affected the
    # result.)
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super(ListWrapper, self).__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super(ListWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper.

    The first detected change sets `_external_modification`. The last snapshot
    is deliberately kept: once the flag is set, nothing updates or compares
    the snapshot again, and the save-time error reports it as the value the
    list had when the dependency was last known to be consistent. (Clearing
    it made that message always print `None`.)
    """
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""

    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _checkpoint_dependencies(self):
    """Lists dependencies; raises ValueError if the list can't be restored."""
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). A list element was replaced "
           "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
           "or moved (sort). In order to support restoration on object "
           "creation, tracking is exclusively for append-only data structures."
           "\n\nIf you don't need this list checkpointed, wrap it in a "
           "non-trackable object; it will be subsequently ignored." % (self,)))
    if self._external_modification:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). The wrapped list was modified "
           "outside the wrapper (its final value was %s, its value when a "
           "checkpoint dependency was added was %s), which breaks restoration "
           "on object creation.\n\nIf you don't need this list checkpointed, "
           "wrap it in a NoDependency object; it will be "
           "subsequently ignored." % (
               self, self._storage, self._last_wrapped_list_snapshot)))
    return super(ListWrapper, self)._checkpoint_dependencies

  def _has_mutation_or_trackable(self):
    """Short-circuits a check for trackables if there's already a mutation."""
    if self._non_append_mutation:
      return True
    return any(isinstance(element, base.Trackable) for element in self._storage)

  def __delitem__(self, key):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        if value_now is not None and value_now != value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      # Normalize negative indices so the element is tracked under the same
      # positional name ("0", "1", ...) that `append` would give it; tracking
      # under "-1" could never be matched on restore. Non-integer keys raise
      # TypeError here, as they would when indexing the list.
      index = operator.index(key)
      if index < 0:
        index += len(self._storage)
        if index < 0:
          raise IndexError("list assignment index out of range")
      if isinstance(self._storage[index], base.Trackable):
        self._non_append_mutation = True
      self._storage[index] = self._track_value(value,
                                               self._name_element(index))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super(ListWrapper, self).append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super(ListWrapper, self).extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      self._check_external_modification()
      if self._has_mutation_or_trackable():
        self._non_append_mutation = True
      self._storage *= y
      self._update_snapshot()
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super(ListWrapper, self).__imul__(y)

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    self._check_external_modification()
    if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)):
      self._non_append_mutation = True
    self._storage.insert(index, obj)
    self._update_snapshot()

  def sort(self, *args, **kwargs):
    """Sorts in place; accepts the same arguments as `list.sort`.

    Moving trackable elements counts as a non-append mutation.
    """
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    self._storage.sort(*args, **kwargs)
    self._update_snapshot()

  def __setslice__(self, i, j, y):
    # Python 2 slicing hook.
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    # Python 2 slicing hook.
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[slice(i, j)]
    self._update_snapshot()

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super(ListWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)

  def _list_functions_for_serialization(self, unused_functions):
    return {
        str(key): value for key, value in enumerate(self)
        if _is_function(value)
    }
717
718
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super(Mapping, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Replace every initial value with its tracked (possibly wrapped) form.
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage: a plain `dict`."""
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  def _name_element(self, key):
    """Returns the checkpoint name for `key`.

    Raises:
      TypeError: If `key` is not a string.
    """
    if not isinstance(key, six.string_types):
      raise TypeError(
          "Mapping accepts only string keys, but got a key %s."
          % repr(key))
    return str(key)

  def __setitem__(self, key, value):
    """Adds `key`; an existing key may only be re-assigned its current value.

    Raises:
      TypeError: If `key` is not a string.
      ValueError: If `key` already maps to a different value, or `value` is
        not trackable.
    """
    name = self._name_element(key)
    if key in self._storage:
      # Detect the overwrite *before* tracking. `_track_value` re-points the
      # checkpoint dependency `name` (overwrite=True) and records variables,
      # which must not happen for an assignment that is going to be rejected
      # (nor be repeated when the same object is assigned again).
      current_value = self._storage[key]
      value = wrap_or_unwrap(value)
      if current_value is not value:
        raise ValueError(
            ("Mappings are an append-only data structure. Tried to overwrite the "
             "key '%s' with value %s, but it already contains %s")
            % (key, value, current_value))
      return
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    """Adds each key/value pair through `__setitem__`."""
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
787
788
789class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
790  """Wraps built-in dicts to support restore-on-create for variables.
791
792  _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
793  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
794  keys, reassign values). Like ListWrapper, these mutations mean that
795  _DictWrapper will raise an exception on save.
796  """
797
  def __init__(self, wrapped_dict=None):
    """Wraps `wrapped_dict` and tracks each of its values.

    Args:
      wrapped_dict: The dict to wrap. `None` means a new empty dict. A value
        that is not a `Mapping` (e.g. a sequence of key/value pairs) is first
        converted with `dict()`.
    """
    if wrapped_dict is None:
      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
      wrapped_dict = {}
    if not isinstance(wrapped_dict, collections_abc.Mapping):
      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
      wrapped_dict = dict(wrapped_dict)
    # Initialize the proxy first, then the trackable state. The attributes set
    # below use the "_self_" prefix required for wrapt.ObjectProxy subclasses.
    wrapt.ObjectProxy.__init__(self, wrapped_dict)
    TrackableDataStructure.__init__(self)
    self._self_non_string_key = False
    self._self_external_modification = False
    # Replace every value in the wrapped dict itself with its tracked
    # (possibly wrapped) form.
    self.__wrapped__.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self.__wrapped__.items()})
    self._update_snapshot()
814
815  def __reduce_ex__(self, protocol):
816    return (self.__class__,
817            (self.__wrapped__,))
818
  def __getattribute__(self, name):
    """Looks up `name`, reading class-level properties from the wrapper itself.

    Names that are `property` objects on `type(self)` are fetched with
    `object.__getattribute__`, i.e. from this wrapper. Every other name goes
    through the inherited `__getattribute__`.
    """
    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_DictWrapper, self).__getattribute__(name)
829
830  def copy(self):
831    return copy.copy(self)
832
833  # pylint: disable=protected-access
834  def __copy__(self):
835    copied = _DictWrapper(copy.copy(self.__wrapped__))
836    copied._self_external_modification = self._self_external_modification
837    copied._self_non_string_key = self._self_non_string_key
838    return copied
839
840  def __deepcopy__(self, memo):
841    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
842    copied._self_external_modification = self._self_external_modification
843    copied._self_non_string_key = self._self_non_string_key
844    return copied
845  # pylint: enable=protected-access
846
847  @property
848  def _values(self):
849    """Collect values for TrackableDataStructure."""
850    # Sort items deterministically by key
851    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
852    if ordered:
853      return ordered[1]
854    return []
855
856  @property
857  def _checkpoint_dependencies(self):
858    """Check that the object is saveable before listing its dependencies."""
859    self._check_self_external_modification()
860    if self._self_non_string_key:
861      raise ValueError(
862          "Unable to save the object %s (a dictionary wrapper constructed "
863          "automatically on attribute assignment). The wrapped dictionary "
864          "contains a non-string key which maps to a trackable object or "
865          "mutable data structure.\n\nIf you don't need this dictionary "
866          "checkpointed, wrap it in a non-trackable "
867          "object; it will be subsequently ignored." % (self,))
868    if self._self_external_modification:
869      raise ValueError(
870          "Unable to save the object %s (a dictionary wrapper constructed "
871          "automatically on attribute assignment). The wrapped dictionary was "
872          "modified outside the wrapper (its final value was %s, its value "
873          "when a checkpoint dependency was added was %s), which breaks "
874          "restoration on object creation.\n\nIf you don't need this "
875          "dictionary checkpointed, wrap it in a "
876          "non-trackable object; it will be subsequently ignored." % (
877              self, self, self._self_last_wrapped_dict_snapshot))
878    assert not self._dirty  # Any reason for dirtiness should have an exception.
879    return super(_DictWrapper, self)._checkpoint_dependencies
880
881  @property
882  def _dirty(self):
883    """Check if there has already been a mutation which prevents saving."""
884    return (self._self_external_modification
885            or self._self_non_string_key)
886
887  def _check_self_external_modification(self):
888    """Checks for any changes to the wrapped dict not through the wrapper."""
889    if self._dirty:
890      return
891    if self != self._self_last_wrapped_dict_snapshot:
892      self._self_external_modification = True
893      self._self_last_wrapped_dict_snapshot = None
894
895  def _update_snapshot(self):
896    """Acknowledges tracked changes to the wrapped dict."""
897    self._attribute_sentinel.invalidate_all()
898    if self._dirty:
899      return
900    self._self_last_wrapped_dict_snapshot = dict(self)
901
902  def _track_value(self, value, name):
903    """Allows storage of non-trackable objects."""
904    if isinstance(name, six.string_types):
905      string_key = True
906    else:
907      name = "-non_string_key"
908      string_key = False
909    try:
910      no_dependency = isinstance(value, NoDependency)
911      value = super(_DictWrapper, self)._track_value(value=value, name=name)
912      if not (string_key or no_dependency):
913        # A non-string key maps to a trackable value. This data structure
914        # is not saveable.
915        self._self_non_string_key = True
916      return value
917    except ValueError:
918      # Even if this value isn't trackable, we need to make sure
919      # NoDependency objects get unwrapped.
920      return sticky_attribute_assignment(
921          trackable=self, value=value, name=name)
922
923  def _name_element(self, key):
924    """Tells TrackableDataStructure to use keys as names as-is."""
925    return key
926
927  def __setitem__(self, key, value):
928    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
929    self._check_self_external_modification()
930    self._maybe_initialize_trackable()
931    no_dep = isinstance(value, NoDependency)
932    if isinstance(key, six.string_types):
933      value = self._track_value(value, name=key)
934    else:
935      value = wrap_or_unwrap(value)
936      if not no_dep and isinstance(value, base.Trackable):
937        # Non-string keys are OK as long as we have no reason to add a
938        # dependency on the value (either because the value is not
939        # trackable, or because it was wrapped in a NoDependency object).
940        self._self_non_string_key = True
941    self.__wrapped__[key] = value
942
943    self._update_snapshot()
944
945  def __delitem__(self, key):
946    self._check_self_external_modification()
947    del self.__wrapped__[key]
948    self._update_snapshot()
949
950  def __repr__(self):
951    return "DictWrapper(%s)" % (repr(self.__wrapped__),)
952
953  def __hash__(self):
954    raise TypeError("unhashable type: 'DictWrapper'")
955
956  def __eq__(self, other):
957    # Override the TrackableDataStructure "== -> is" forwarding and go back to
958    # the wrapt implementation.
959    return self.__wrapped__ == other
960
961  def update(self, *args, **kwargs):
962    for key, value in six.iteritems(dict(*args, **kwargs)):
963      self[key] = value
964
965  def _list_functions_for_serialization(self, unused_serialization_cache):
966    return {
967        key: value for key, value in self.items()
968        if _is_function(value)
969    }
970
971
class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Elements are tracked by index ("0", "1", ...) and, for namedtuples, also by
  field name. Elements that were `NoDependency` instances get no dependency.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple` and tracks its elements.

    Args:
      original_wrapped_tuple: A tuple or namedtuple to wrap. Each element is
        passed through `wrap_or_unwrap`; elements that were `NoDependency`
        instances are not tracked.
    """
    # add_dependency[i] records whether element i should get a dependency.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # Fall back to wrapping the original, unsubstituted tuple without
        # tracking any element; `_checkpoint_dependencies` then refuses to
        # save it.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      # Rebuild the same (possibly subclassed) tuple type from the
      # substituted elements.
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # The proxy itself is returned; iterating it presumably iterates the
    # wrapped tuple (ObjectProxy forwarding). NOTE(review): confirm.
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Returns:
      The tracked value, or the result of `sticky_attribute_assignment` when
      the parent class rejects the value with a ValueError.
    """
    try:
      value = super(_TupleWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    """Returns "_TupleWrapper(...)" around the wrapped tuple's repr."""
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    """Hashes like the wrapped tuple."""
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    """Compares the wrapped tuple with `other`."""
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    """Returns a new wrapper around a shallow copy of the wrapped tuple."""
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    """Returns a new wrapper around a deep copy of the wrapped tuple."""
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    """Reduces to `self.__class__` called on the wrapped tuple."""
    # NOTE(review): instance attribute lookups go through `__getattribute__`
    # below, which prefers the wrapped tuple's attributes. So `self.__class__`
    # here is the wrapped tuple's type, and `obj.__reduce_ex__` on an instance
    # may resolve to the tuple's own method. Confirm whether pickle reaches
    # this method at all.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    return self.__wrapped__ + y

  @property
  def _checkpoint_dependencies(self):
    """Returns the parent's dependencies if this tuple can be rebuilt.

    Raises:
      ValueError: If `__init__` could not rebuild the namedtuple from its
        `_fields` (see `_self_tuple_is_constructable`).
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          ("Unable to save because the namedtuple {} is not constructable from "
           "its _fields (i.e. __new__ is overridden). Expected keyword "
           "arguments {}. If you do not need to save this object, consider "
           "wrapping it in a custom object that does not inherit from tuple.")
          .format(self.__wrapped__, self.__wrapped__._fields))
    return super(_TupleWrapper, self)._checkpoint_dependencies

  def __getattribute__(self, name):
    """Resolves attributes, preferring the wrapped tuple's on conflicts."""
    # `__wrapped__` is excluded so that evaluating `self.__wrapped__` in this
    # very condition does not recurse.
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_TupleWrapper, self).__getattribute__(name)
1099
1100
def _is_function(x):
  """Returns whether `x` is a `def_function.Function` or a `ConcreteFunction`."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
1103
1104
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra
            # information: revive an empty wrapper and fill it via `setter`.
            object_factory=lambda proto: _DictWrapper({}),
            setter=operator.setitem),
    ])
1116
1117
1118def _set_list_item(list_object, index_string, value):
1119  item_index = int(index_string)
1120  if len(list_object) <= item_index:
1121    list_object.extend([None] * (1 + item_index - len(list_object)))
1122  list_object[item_index] = value
1123
1124
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Revive an empty list; elements are placed by index via `setter`.
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_list_item),
    ])
1134
1135
1136def _set_tuple_item(list_object, index_string, value):
1137  try:
1138    item_index = int(index_string)
1139  except ValueError:
1140    # Ignore namedtuple fields.
1141    return
1142  if len(list_object) <= item_index:
1143    list_object.extend([None] * (1 + item_index - len(list_object)))
1144  list_object[item_index] = value
1145
1146
# Revive tuples as lists so we can append any dependencies during loading.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: ListWrapper([]),
            # Only index-named dependencies are placed; the setter skips
            # namedtuple field names.
            setter=_set_tuple_item),
    ])
1157