• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import operator
import sys

try:
  # Python 3.3+ keeps the container ABCs in collections.abc; the aliases on
  # `collections` itself were removed in Python 3.10.
  import collections.abc as collections_abc
except ImportError:
  # Python 2 only exposes the ABCs directly on `collections`.
  import collections as collections_abc

import six

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import layer_utils
33
34
class NoDependency(object):
  """Marks a value that should be assigned without a checkpoint dependency.

  Assigning `NoDependency(x)` to an attribute of a `Trackable` object stores
  the bare `x` (the wrapper itself is discarded) but does not record `x` as a
  checkpoint dependency of that object:

  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` depends only on the variable "dep", and both attributes hold
  plain (un-wrapped) `Variable` objects.

  With `tf.keras.Model` the wrapper only affects checkpoint dependencies: a
  `Layer` wrapped in `NoDependency` is assigned un-wrapped and gets no
  checkpoint dependency, but the `Model` still tracks it, so it appears in
  `Model.layers` and its variables appear in `Model.variables`.
  """

  def __init__(self, value):
    # The wrapped object; `_wrap_or_unwrap` returns it in place of the
    # wrapper.
    self.value = value
58
59
def _wrap_or_unwrap(value):
  """Converts `value` into the form stored by trackable containers.

  `NoDependency` wrappers are stripped, objects that are already `Trackable`
  are returned untouched, plain `dict` / `collections.OrderedDict` / `list`
  instances are wrapped in their trackable counterparts, and anything else is
  returned unchanged.
  """
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    return value  # Already trackable; no conversion needed.
  # Only the exact built-in types are wrapped: subclasses of list/dict (for
  # example collections.Counter) carry their own logic that wrapping would
  # muck up.
  value_type = type(value)
  if value_type in (dict, collections.OrderedDict):
    return _DictWrapper(value)
  if value_type in (list,):
    return _ListWrapper(value)
  return value
  # TODO(allenl): Handle other common data structures. Tuples will require
  # special casing (tuple subclasses are not weak referenceable, so replacement
  # with a wrapper that subclasses tuple on attribute assignment works poorly,
  # and replacement with a wrapper that isn't a tuple is also problematic),
  # probably a tree traversal where the leaves are non-tuples(/namedtuples) to
  # come up with names. Dictionaries should look like lists.
84
85
def sticky_attribute_assignment(trackable, name, value):
  """Records `value` as a dependency of `trackable`; called from __setattr__.

  Trackable and Model share this logic. Plain data structures are swapped for
  their trackable wrappers and trackable values are tracked under `name`,
  unless the caller wrapped `value` in `NoDependency`.

  Args:
    trackable: The object receiving the dependency (usually the object whose
      attribute is being assigned).
    name: The attribute name being assigned.
    value: The value being assigned; it need not be trackable.

  Returns:
    The value to store in the attribute, with any `NoDependency` wrapper
    removed and plain lists/dicts replaced by their trackable wrappers.
  """
  wants_dependency = not isinstance(value, NoDependency)
  value = _wrap_or_unwrap(value)
  if wants_dependency and isinstance(value, base.Trackable):
    # overwrite=True lets the same name be re-pointed at a different
    # Trackable: reassigning a new variable to an attribute has historically
    # been allowed (e.g. Adam did this).
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name, overwrite=True)
  return value
119
120
121class _UntrackableError(ValueError):
122
123  def __init__(self, value):  # pylint: disable=super-init-not-called
124    self._value = value
125
126  def __str__(self):
127    return (("Only trackable objects (such as Layers or Optimizers) may be "
128             "stored in a List object. Got %s, which does not inherit from "
129             "Trackable.") % (self._value,))
130
131
class TrackableDataStructure(base.Trackable):
  """Common base for containers whose elements may be trackable objects.

  Subclasses supply the `_values` iterable. From it this class derives
  Layer-style metadata (layers, weights, variables, updates, losses).
  Hashing and equality are by object identity.
  """

  def __init__(self):
    self.trainable = True
    # Variables stored directly in the container (rather than inside a
    # Layer); they are handed to layer_utils next to the sub-layers.
    self._extra_variables = []

  def _track_value(self, value, name):
    """Adds a dependency on `value` under `name`.

    Args:
      value: The value being stored; may be wrapped in `NoDependency`.
      name: The dependency name to use.

    Returns:
      The value to actually store, as produced by
      `sticky_attribute_assignment`.

    Raises:
      _UntrackableError: If the value to store is not `Trackable`.
    """
    stored = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(stored, variables.Variable):
      self._extra_variables.append(stored)
    if not isinstance(stored, base.Trackable):
      raise _UntrackableError(stored)
    if hasattr(stored, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      stored._use_resource_variables = True  # pylint: disable=protected-access
    return stored

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Recomputed on every access so a wrapper reflects the current contents
    # of whatever it wraps, even if the two have drifted out of sync.
    return [
        candidate for candidate in self._values
        if (isinstance(candidate, TrackableDataStructure)
            or layer_utils.is_layer(candidate)
            or layer_utils.has_weights(candidate))
    ]

  @property
  def layers(self):
    return layer_utils.filter_empty_layer_containers(self._layers)

  @property
  def trainable_weights(self):
    return layer_utils.gather_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._extra_variables)

  @property
  def non_trainable_weights(self):
    return layer_utils.gather_non_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._extra_variables)

  @property
  def weights(self):
    trainable = self.trainable_weights
    non_trainable = self.non_trainable_weights
    return trainable + non_trainable

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded unfiltered: this is only a
    # container and never has inputs to filter them by.
    collected = []
    for layer in self.layers:
      if hasattr(layer, "updates"):
        collected.extend(layer.updates)
    return collected

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    collected = []
    for layer in self.layers:
      if hasattr(layer, "losses"):
        collected.extend(layer.losses)
    return collected

  def __hash__(self):
    # Identity hashing lets these structures serve as set members / dict keys.
    return id(self)

  def __eq__(self, other):
    # Like Tensors, trackable data structures compare by object identity so
    # that set/dict membership works.
    return self is other
235
236
class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super(HasList, self).__init__()
      self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super(List, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    for index, element in enumerate(self._storage):
      self._storage[index] = self._track_value(
          element, name=self._name_element(index))

  def copy(self):
    """Returns a new container of the same type holding the same elements."""
    return type(self)(copy.copy(self._storage))

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    """Dependency name for the element at `index` (its decimal position)."""
    return "%d" % (index,)

  @property
  def _values(self):
    return self

  def append(self, value):
    """Add a new trackable value."""
    value = self._track_value(value, self._name_element(len(self._storage)))
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of trackable values."""
    for value in values:
      self.append(value)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    return self.__class__(self._storage + getattr(other, "_storage", other))

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          "List only supports append, multiplying in place by %d removes "
          "elements." % y)

    # Re-append the original elements (y - 1) times so every copy is tracked
    # under its new position.
    n = len(self._storage)
    for _ in range(y - 1):
      for i in range(n):
        self.append(self._storage[i])

    return self

  def __mul__(self, n):
    return self.__class__(self._storage * n)

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    # `[...] + wrapper` lands here before list.__add__ because _ListWrapper
    # subclasses list. _ListWrapper uses the list it is constructed with as
    # its storage and rewrites its elements while tracking them, so hand it a
    # copy: otherwise the caller's left operand would be modified in place.
    if isinstance(other, list):
      other = list(other)
    return self.__class__(other) + self

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    return super(List, self).__sizeof__() + sys.getsizeof(self._storage)
353
354
355# TODO(tomhennigan) Update to collections.UserList?
class _ListWrapper(List, collections_abc.MutableSequence,
                   # Shadowed, but there for isinstance checks.
                   list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with _ListWrapper. Wrapping a list in a
  `tf.contrib.checkpoint.NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `_ListWrapper`, and if changes
        are detected the `_ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    self._non_append_mutation = False
    self._external_modification = False
    super(_ListWrapper, self).__init__(wrapped_list)
    self._last_wrapped_list_snapshot = list(self._storage)

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super(_ListWrapper, self).__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super(_ListWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper."""
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True
      # The (now stale) snapshot is deliberately kept: the error raised by
      # `_checkpoint_dependencies` reports it as the last in-sync value.
      # `_update_snapshot` never refreshes it again once this flag is set.

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _checkpoint_dependencies(self):
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). A list element was replaced "
           "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
           "or moved (sort). In order to support restoration on object "
           "creation, tracking is exclusively for append-only data structures."
           "\n\nIf you don't need this list checkpointed, wrap it in a "
           "tf.contrib.checkpoint.NoDependency object; it will be "
           "automatically un-wrapped and subsequently ignored." % (self,)))
    if self._external_modification:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "trackable TensorFlow objects). The wrapped list was modified "
           "outside the wrapper (its final value was %s, its value when a "
           "checkpoint dependency was added was %s), which breaks restoration "
           "on object creation.\n\nIf you don't need this list checkpointed, "
           "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
           "automatically un-wrapped and subsequently ignored." % (
               self, self._storage, self._last_wrapped_list_snapshot)))
    return super(_ListWrapper, self)._checkpoint_dependencies

  def __delitem__(self, key):
    self._non_append_mutation = True
    del self._storage[key]

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        # Compare by identity rather than `!=`: elements such as NumPy arrays
        # compare elementwise, and their result cannot be used as a bool. A
        # new object that merely compares equal (e.g. a fresh dict) must
        # still be tracked/wrapped.
        if value_now is not None and value_now is not value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      # Indexing first validates `key` exactly as a built-in list would
      # (raising TypeError/IndexError before anything is modified).
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      index = operator.index(key)
      if index < 0:
        # Name the dependency after the absolute position so it matches the
        # name used for that slot everywhere else (e.g. by append).
        index += len(self._storage)
      self._storage[index] = self._track_value(value,
                                               self._name_element(index))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super(_ListWrapper, self).append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super(_ListWrapper, self).extend(values)
    self._update_snapshot()

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    self._non_append_mutation = True
    self._storage.insert(index, obj)

  def sort(self, *args, **kwargs):
    """Sorts in place, accepting the same arguments as `list.sort`.

    Reordering renames elements, so the wrapper is marked unsaveable.
    """
    self._non_append_mutation = True
    self._storage.sort(*args, **kwargs)

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._non_append_mutation = True
    del self._storage[slice(i, j)]

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super(_ListWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)

  def _list_functions_for_serialization(self):
    return {
        str(key): value for key, value in enumerate(self)
        if _is_function(value)
    }
549
550
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced. If
  names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super(Mapping, self).__init__()
    self._storage = self._make_storage(*args, **kwargs)
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Values ordered deterministically by key, when the keys are orderable."""
    items = list(self.items())
    try:
      ordered = sorted(items, key=lambda item: item[0])
    except TypeError:
      # Subclasses such as _DictWrapper allow non-string keys, which may not
      # be orderable against each other (e.g. a mix of str and int keys on
      # Python 3). Fall back to the mapping's own iteration order instead of
      # failing every layers/weights query.
      ordered = items
    return [value for _, value in ordered]

  def _name_element(self, key):
    if not isinstance(key, six.string_types):
      raise TypeError(
          "Mapping accepts only string keys, but got a key %s."
          % repr(key))
    return str(key)

  def __setitem__(self, key, value):
    name = self._name_element(key)
    if key in self._storage:
      current_value = self._storage[key]
      new_value = _wrap_or_unwrap(value)
      if new_value is not current_value:
        # Reject before tracking: tracking first would re-point the
        # dependency `name` at the rejected value while storage kept the
        # old one.
        raise ValueError(
            ("Mappings are an append-only data structure. Tried to overwrite the "
             "key '%s' with value %s, but it already contains %s")
            % (key, new_value, current_value))
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
619
620
621# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance
622# checks seems infeasible. CPython will not call Python methods/properties on
623# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead
624# collects elements directly from dict_subclass's C structs. So subclassing dict
625# implies that the storage has to be "self" (i.e. the C structs for the object
626# must be updated correctly), but we also need that storage to be the wrapped
627# dictionary to avoid synchronization bugs (un-tracked external modifications
628# should still show up when the dict is accessed through the wrapper). Monkey
629# patching all of the "wrapped" dict's methods instead of creating a wrapper
630# object is an option, but not a very attractive one (replacing methods without
631# creating reference cycles is difficult, and then dicts would need to be
632# special cased everywhere as being trackable).
class _DictWrapper(Mapping, collections_abc.MutableMapping):
  """Wraps built-in dicts to support restore-on-create for variables.

  _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
  keys, reassign values). Like _ListWrapper, these mutations mean that
  _DictWrapper will raise an exception on save.
  """

  def __new__(cls, *args):
    if len(args) == 1 and isinstance(args[0], dict):
      return super(_DictWrapper, cls).__new__(cls)
    else:
      # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
      # this case there's nothing to wrap, so we make a normal dictionary. Also
      # allows constructing empty instances of the _DictWrapper type, as Session
      # is wont to do (and again there's nothing to wrap, so a normal dictionary
      # makes more sense).
      return dict(*args)

  def __init__(self, wrapped_dict):
    # Monotonic flags: once set, the wrapper refuses to be saved.
    self._non_string_key = False
    self._non_append_mutation = False
    self._external_modification = False
    super(_DictWrapper, self).__init__(wrapped_dict)
    self._update_snapshot()

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super(_DictWrapper, self).__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    copied._non_string_key = self._non_string_key
    return copied

  def __deepcopy__(self, memo):
    copied = super(_DictWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    copied._non_string_key = self._non_string_key
    return copied
  # pylint: enable=protected-access

  def _make_storage(self, wrapped_dict):
    """Re-use the wrapped dict for storage (to force them to be in sync)."""
    return wrapped_dict

  @property
  def _checkpoint_dependencies(self):
    """Check that the object is saveable before listing its dependencies."""
    self._check_external_modification()
    if self._non_string_key:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
          "contains a non-string key which maps to a trackable object or "
          "mutable data structure.\n\nIf you don't need this dictionary "
          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
          "object; it will be automatically un-wrapped and subsequently "
          "ignored." % (self,))
    if self._non_append_mutation:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). A key mapping to a "
          "trackable object was overwritten or deleted, which would "
          "cause problems for restoration.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (self,))
    if self._external_modification:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary was "
          "modified outside the wrapper (its final value was %s, its value "
          "when a checkpoint dependency was added was %s), which breaks "
          "restoration on object creation.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (
              self, self, self._last_wrapped_dict_snapshot))
    assert not self._dirty  # Any reason for dirtiness should have an exception.
    return super(_DictWrapper, self)._checkpoint_dependencies

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving."""
    return (self._external_modification
            or self._non_append_mutation
            or self._non_string_key)

  def _check_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper."""
    if self._dirty:
      return
    if self != self._last_wrapped_dict_snapshot:
      self._external_modification = True
      # The (now stale) snapshot is deliberately kept: the error raised by
      # `_checkpoint_dependencies` reports it as the last in-sync value.
      # `_update_snapshot` never refreshes it again once the wrapper is dirty.

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict."""
    if self._dirty:
      return
    self._last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    if isinstance(name, six.string_types):
      string_key = True
    else:
      name = "-non_string_key"
      string_key = False
    try:
      no_dependency = isinstance(value, NoDependency)
      value = super(_DictWrapper, self)._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a trackable value. This data structure
        # is not saveable.
        self._non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Don't throw errors for non-string keys."""
    if isinstance(key, six.string_types):
      return super(_DictWrapper, self)._name_element(key)
    else:
      return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
    self._check_external_modification()
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, six.string_types):
      existing_dependency = self._lookup_dependency(key)
      value = self._track_value(value, name=key)
    else:
      value = _wrap_or_unwrap(value)
      existing_dependency = None
      if not no_dep and isinstance(value, base.Trackable):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # trackable, or because it was wrapped in a NoDependency object).
        self._non_string_key = True
    if key in self._storage:
      previous_value = self._storage[key]
      if previous_value is not value:
        if ((not no_dep and isinstance(value, base.Trackable))
            # We don't want to just check that the existing object is
            # trackable, since it may have been wrapped in a NoDependency
            # object.
            or existing_dependency is not None):
          # A trackable object was replaced under the same key; this means
          # that restoring would be error-prone, so we'll throw an exception on
          # save.
          self._non_append_mutation = True
    self._storage[key] = value

    self._update_snapshot()

  def __delitem__(self, key):
    self._check_external_modification()
    existing_value = self[key]
    if isinstance(existing_value, base.Trackable):
      # Deleting tracked trackable values means restoring is problematic,
      # so we'll throw an exception on save.
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self._storage),)

  def __hash__(self):
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def _list_functions_for_serialization(self):
    return {
        key: value for key, value in self.items()
        if _is_function(value)
    }
824
825
def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or a `ConcreteFunction`."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
828
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # The standard checkpoint dependencies already describe the
            # trackable items of the dictionary, so no extra information is
            # saved: revival starts from an empty wrapper, and `setter` is
            # `operator.setitem` (children presumably re-attached by key).
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: _DictWrapper({}),
            setter=operator.setitem,
        ),
    ])
840
841
842def _set_list_item(list_object, index_string, value):
843  item_index = int(index_string)
844  if len(list_object) <= item_index:
845    list_object.extend([None] * (1 + item_index - len(list_object)))
846  list_object[item_index] = value
847
848
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, _ListWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Revival starts from an empty wrapper; `_set_list_item` pads it
            # with None as needed and places each child at its numeric name.
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: _ListWrapper([]),
            setter=_set_list_item,
        ),
    ])
858