1"""Trackable data structures.""" 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import operator 23import sys 24 25import six 26 27from tensorflow.python.eager import def_function 28from tensorflow.python.eager import function as defun 29from tensorflow.python.ops import variables 30from tensorflow.python.saved_model import revived_types 31from tensorflow.python.training.tracking import base 32from tensorflow.python.training.tracking import layer_utils 33 34 35class NoDependency(object): 36 """Allows attribute assignment to `Trackable` objects with no dependency. 37 38 Example usage: 39 ```python 40 obj = Trackable() 41 obj.has_dependency = tf.Variable(0., name="dep") 42 obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) 43 assert obj.no_dependency.name == "nodep:0" 44 ``` 45 46 `obj` in this example has a dependency on the variable "dep", and both 47 attributes contain un-wrapped `Variable` objects. 48 49 `NoDependency` also works with `tf.keras.Model`, but only for checkpoint 50 dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) 51 `Layer` to the attribute without a checkpoint dependency, but the `Model` will 52 still track the `Layer` (so it will appear in `Model.layers`, and its 53 variables will appear in `Model.variables`). 54 """ 55 56 def __init__(self, value): 57 self.value = value 58 59 60def _wrap_or_unwrap(value): 61 """Wraps basic data structures, unwraps NoDependency objects.""" 62 # pylint: disable=unidiomatic-typecheck 63 # Exact type checking to avoid mucking up custom logic in list/dict 64 # subclasses, e.g. collections.Counter. 65 if isinstance(value, NoDependency): 66 return value.value 67 if isinstance(value, base.Trackable): 68 return value # Skip conversion for already trackable objects. 69 elif type(value) == dict: 70 return _DictWrapper(value) 71 elif type(value) == collections.OrderedDict: 72 return _DictWrapper(value) 73 elif type(value) == list: 74 return _ListWrapper(value) 75 else: 76 return value 77 # pylint: enable=unidiomatic-typecheck 78 # TODO(allenl): Handle other common data structures. Tuples will require 79 # special casing (tuple subclasses are not weak referenceable, so replacement 80 # with a wrapper that subclasses tuple on attribute assignment works poorly, 81 # and replacement with a wrapper that isn't a tuple is also problematic), 82 # probably a tree traversal where the leaves are non-tuples(/namedtuples) to 83 # come up with names. Dictionaries should look like lists. 84 85 86def sticky_attribute_assignment(trackable, name, value): 87 """Adds dependencies, generally called from __setattr__. 88 89 This behavior is shared between Trackable and Model. 90 91 Respects NoDependency indicators, but otherwise makes trackable objects 92 out of common data structures and tracks objects by their attribute names. 93 94 Args: 95 trackable: The object to add dependencies to (generally the one having 96 an attribute assigned). 97 name: The attribute name being assigned. 98 value: The value being assigned. Not necessarily a trackable object. 99 100 Returns: 101 The value which should be stored in the attribute (unwrapped from a 102 NoDependency object if necessary). 103 """ 104 if isinstance(value, NoDependency): 105 add_dependency = False 106 else: 107 add_dependency = True 108 value = _wrap_or_unwrap(value) 109 if not add_dependency: 110 return value 111 if isinstance(value, base.Trackable): 112 trackable._track_trackable( # pylint: disable=protected-access 113 value, name=name, 114 # Allow the user to switch the Trackable which is tracked by this 115 # name, since assigning a new variable to an attribute has 116 # historically been fine (e.g. Adam did this). 117 overwrite=True) 118 return value 119 120 121class _UntrackableError(ValueError): 122 123 def __init__(self, value): # pylint: disable=super-init-not-called 124 self._value = value 125 126 def __str__(self): 127 return (("Only trackable objects (such as Layers or Optimizers) may be " 128 "stored in a List object. Got %s, which does not inherit from " 129 "Trackable.") % (self._value,)) 130 131 132class TrackableDataStructure(base.Trackable): 133 """Base class for data structures which contain trackable objects.""" 134 135 def __init__(self): 136 self.trainable = True 137 self._extra_variables = [] 138 139 def _track_value(self, value, name): 140 """Add a dependency on `value`.""" 141 value = sticky_attribute_assignment( 142 trackable=self, value=value, name=name) 143 if isinstance(value, variables.Variable): 144 self._extra_variables.append(value) 145 if not isinstance(value, base.Trackable): 146 raise _UntrackableError(value) 147 if hasattr(value, "_use_resource_variables"): 148 # In subclassed models, legacy layers (tf.layers) must always use 149 # resource variables. 150 value._use_resource_variables = True # pylint: disable=protected-access 151 return value 152 153 @property 154 def _values(self): 155 """An iterable/sequence which may contain trackable objects.""" 156 raise NotImplementedError("Abstract method") 157 158 @property 159 def _layers(self): 160 """All Layers and Layer containers, including empty containers.""" 161 # Filter objects on demand so that wrapper objects use values from the thing 162 # they're wrapping if out of sync. 163 collected = [] 164 for obj in self._values: 165 if (isinstance(obj, TrackableDataStructure) 166 or layer_utils.is_layer(obj) 167 or layer_utils.has_weights(obj)): 168 collected.append(obj) 169 return collected 170 171 @property 172 def layers(self): 173 return layer_utils.filter_empty_layer_containers(self._layers) 174 175 @property 176 def trainable_weights(self): 177 return layer_utils.gather_trainable_weights( 178 trainable=self.trainable, 179 sub_layers=self._layers, 180 extra_variables=self._extra_variables) 181 182 @property 183 def non_trainable_weights(self): 184 return layer_utils.gather_non_trainable_weights( 185 trainable=self.trainable, 186 sub_layers=self._layers, 187 extra_variables=self._extra_variables) 188 189 @property 190 def weights(self): 191 return self.trainable_weights + self.non_trainable_weights 192 193 @property 194 def trainable_variables(self): 195 return self.trainable_weights 196 197 @property 198 def non_trainable_variables(self): 199 return self.non_trainable_weights 200 201 @property 202 def variables(self): 203 return self.weights 204 205 @property 206 def updates(self): 207 """Aggregate updates from any `Layer` instances.""" 208 # Updates and conditional losses are forwarded as-is rather than being 209 # filtered based on inputs, since this is just a container and won't ever 210 # have any inputs. 211 aggregated = [] 212 for layer in self.layers: 213 if hasattr(layer, "updates"): 214 aggregated += layer.updates 215 return aggregated 216 217 @property 218 def losses(self): 219 """Aggregate losses from any `Layer` instances.""" 220 aggregated = [] 221 for layer in self.layers: 222 if hasattr(layer, "losses"): 223 aggregated += layer.losses 224 return aggregated 225 226 def __hash__(self): 227 # Support object-identity hashing, so these structures can be used as keys 228 # in sets/dicts. 229 return id(self) 230 231 def __eq__(self, other): 232 # Similar to Tensors, trackable data structures use object-identity 233 # equality to support set/dict membership. 234 return self is other 235 236 237class List(TrackableDataStructure, collections.Sequence): 238 """An append-only sequence type which is trackable. 239 240 Maintains checkpoint dependencies on its contents (which must also be 241 trackable), and forwards any `Layer` metadata such as updates and losses. 242 243 Note that `List` is purely a container. It lets a `tf.keras.Model` or 244 other trackable object know about its contents, but does not call any 245 `Layer` instances which are added to it. To indicate a sequence of `Layer` 246 instances which should be called sequentially, use `tf.keras.Sequential`. 247 248 Example usage: 249 ```python 250 class HasList(tf.keras.Model): 251 252 def __init__(self): 253 super(HasList, self).__init__() 254 self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)]) 255 self.layer_list.append(layers.Dense(4)) 256 257 def call(self, x): 258 aggregation = 0. 259 for l in self.layer_list: 260 x = l(x) 261 aggregation += tf.reduce_sum(x) 262 return aggregation 263 ``` 264 265 This kind of wrapping is necessary because `Trackable` objects do not 266 (yet) deeply inspect regular Python data structures, so for example assigning 267 a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a 268 checkpoint dependency and does not add the `Layer` instance's weights to its 269 parent `Model`. 270 """ 271 272 def __init__(self, *args, **kwargs): 273 """Construct a new sequence. Arguments are passed to `list()`.""" 274 super(List, self).__init__() 275 self._storage = self._make_storage(*args, **kwargs) 276 for index, element in enumerate(self._storage): 277 self._storage[index] = self._track_value( 278 element, name=self._name_element(index)) 279 280 def copy(self): 281 return type(self)(copy.copy(self._storage)) 282 283 def __copy__(self): 284 return self.copy() 285 286 def __deepcopy__(self, memo): 287 return type(self)(copy.deepcopy(self._storage, memo)) 288 289 def _make_storage(self, *args, **kwargs): 290 """Determines the backing storage (overridden in subclasses).""" 291 return list(*args, **kwargs) 292 293 def _name_element(self, index): 294 return "%d" % (index,) 295 296 @property 297 def _values(self): 298 return self 299 300 def append(self, value): 301 """Add a new trackable value.""" 302 value = self._track_value(value, self._name_element(len(self._storage))) 303 self._storage.append(value) 304 305 def extend(self, values): 306 """Add a sequence of trackable values.""" 307 for value in values: 308 self.append(value) 309 310 def __iadd__(self, values): 311 self.extend(values) 312 return self 313 314 def __add__(self, other): 315 return self.__class__(self._storage + getattr(other, "_storage", other)) 316 317 def __imul__(self, y): 318 if y <= 0: 319 raise ValueError( 320 "List only supports append, multiplying in place by %d removes " 321 "elements." % y) 322 323 n = len(self._storage) 324 for _ in range(y - 1): 325 for i in range(n): 326 self.append(self._storage[i]) 327 328 return self 329 330 def __mul__(self, n): 331 return self.__class__(self._storage * n) 332 333 def __rmul__(self, n): 334 return self * n 335 336 def __radd__(self, other): 337 return self.__class__(other) + self 338 339 def __getitem__(self, key): 340 return self._storage[key] 341 342 def __getslice__(self, i, j): 343 return self._storage[slice(i, j)] 344 345 def __len__(self): 346 return len(self._storage) 347 348 def __repr__(self): 349 return "List(%s)" % (repr(self._storage),) 350 351 def __sizeof__(self): 352 return super(List, self).__sizeof__() + sys.getsizeof(self._storage) 353 354 355# TODO(tomhennigan) Update to collections.UserList? 356class _ListWrapper(List, collections.MutableSequence, 357 # Shadowed, but there for isinstance checks. 358 list): 359 """Wraps the built-in `list` to support restore-on-create for variables. 360 361 Unlike `List`, this sequence type is mutable in the same ways built-in lists 362 are. Instead of throwing an error immediately like `List`, it records 363 problematic mutations (e.g. assigning a new element to a position already 364 occupied, meaning both elements get the same names at different times) and 365 refuses to save. 366 367 On assignment to an attribute of a Model or Trackable object, Python 368 lists are replaced with _ListWrapper. Wrapping a list in a 369 `tf.contrib.checkpoint.NoDependency` object prevents this. 370 """ 371 372 def __init__(self, wrapped_list): 373 """Construct a new list wrapper. 374 375 Args: 376 wrapped_list: The initial value of the data structure. A shallow copy may 377 be maintained for error checking. `wrapped_list` itself should not be 378 modified directly after constructing the `_ListWrapper`, and if changes 379 are detected the `_ListWrapper` will throw an exception on save. 380 """ 381 # Monotonic flags which indicate this object would not be restored properly, 382 # and therefore should throw an error on save to avoid giving the impression 383 # that restoring it will work. 384 self._non_append_mutation = False 385 self._external_modification = False 386 super(_ListWrapper, self).__init__(wrapped_list) 387 self._last_wrapped_list_snapshot = list(self._storage) 388 389 # pylint: disable=protected-access 390 def __copy__(self): 391 copied = super(_ListWrapper, self).__copy__() 392 copied._non_append_mutation = self._non_append_mutation 393 copied._external_modification = self._external_modification 394 return copied 395 396 def __deepcopy__(self, memo): 397 copied = super(_ListWrapper, self).__deepcopy__(memo) 398 copied._non_append_mutation = self._non_append_mutation 399 copied._external_modification = self._external_modification 400 return copied 401 # pylint: enable=protected-access 402 403 def _make_storage(self, wrapped_list): 404 """Use the user's original list for storage.""" 405 return wrapped_list 406 407 def _check_external_modification(self): 408 """Checks for any changes to the wrapped list not through the wrapper.""" 409 if self._external_modification or self._non_append_mutation: 410 return 411 if self._storage != self._last_wrapped_list_snapshot: 412 self._external_modification = True 413 self._last_wrapped_list_snapshot = None 414 415 def _update_snapshot(self): 416 """Acknowledges tracked changes to the wrapped list.""" 417 if self._external_modification or self._non_append_mutation: 418 return 419 self._last_wrapped_list_snapshot = list(self._storage) 420 421 @property 422 def _checkpoint_dependencies(self): 423 self._check_external_modification() 424 if self._non_append_mutation: 425 raise ValueError( 426 ("Unable to save the object %s (a list wrapper constructed to track " 427 "trackable TensorFlow objects). A list element was replaced " 428 "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), " 429 "or moved (sort). In order to support restoration on object " 430 "creation, tracking is exclusively for append-only data structures." 431 "\n\nIf you don't need this list checkpointed, wrap it in a " 432 "tf.contrib.checkpoint.NoDependency object; it will be " 433 "automatically un-wrapped and subsequently ignored." % (self,))) 434 if self._external_modification: 435 raise ValueError( 436 ("Unable to save the object %s (a list wrapper constructed to track " 437 "trackable TensorFlow objects). The wrapped list was modified " 438 "outside the wrapper (its final value was %s, its value when a " 439 "checkpoint dependency was added was %s), which breaks restoration " 440 "on object creation.\n\nIf you don't need this list checkpointed, " 441 "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be " 442 "automatically un-wrapped and subsequently ignored." % ( 443 self, self._storage, self._last_wrapped_list_snapshot))) 444 return super(_ListWrapper, self)._checkpoint_dependencies 445 446 def __delitem__(self, key): 447 self._non_append_mutation = True 448 del self._storage[key] 449 450 def __setitem__(self, key, value): 451 self._check_external_modification() 452 453 if isinstance(key, slice): 454 # Note: this is quite inefficient, but the list API supports a broad range 455 # of slice setters (e.g. truncate, extend, replace) and immitating this 456 # for a range of Python versions is non-trivial. 457 storage_copy = list(self._storage) 458 self._storage[key] = value 459 460 len_before = len(storage_copy) 461 len_now = len(self._storage) 462 for i in range(max(len_before, len_now)): 463 value_now = self._storage[i] if i < len_now else None 464 value_before = storage_copy[i] if i < len_before else None 465 466 if isinstance(value_before, base.Trackable): 467 self._non_append_mutation = True 468 469 if value_now is not None and value_now != value_before: 470 self._storage[i] = self._track_value(self._storage[i], 471 self._name_element(i)) 472 473 else: 474 if isinstance(self._storage[key], base.Trackable): 475 self._non_append_mutation = True 476 self._storage[key] = self._track_value(value, self._name_element(key)) 477 478 self._update_snapshot() 479 480 def append(self, value): 481 """Add a new trackable value.""" 482 self._check_external_modification() 483 super(_ListWrapper, self).append(value) 484 self._update_snapshot() 485 486 def extend(self, values): 487 """Add a sequence of trackable values.""" 488 self._check_external_modification() 489 super(_ListWrapper, self).extend(values) 490 self._update_snapshot() 491 492 def __eq__(self, other): 493 return self._storage == getattr(other, "_storage", other) 494 495 def __ne__(self, other): 496 return self._storage != getattr(other, "_storage", other) 497 498 def __lt__(self, other): 499 return self._storage < getattr(other, "_storage", other) 500 501 def __le__(self, other): 502 return self._storage <= getattr(other, "_storage", other) 503 504 def __gt__(self, other): 505 return self._storage > getattr(other, "_storage", other) 506 507 def __ge__(self, other): 508 return self._storage >= getattr(other, "_storage", other) 509 510 def __hash__(self): 511 # List wrappers need to compare like regular lists, and so like regular 512 # lists they don't belong in hash tables. 513 raise TypeError("unhashable type: 'ListWrapper'") 514 515 def insert(self, index, obj): 516 self._non_append_mutation = True 517 self._storage.insert(index, obj) 518 519 def sort(self): 520 self._non_append_mutation = True 521 self._storage.sort() 522 523 def __setslice__(self, i, j, y): 524 self.__setitem__(slice(i, j), y) 525 526 def __delslice__(self, i, j): 527 self._non_append_mutation = True 528 del self._storage[slice(i, j)] 529 530 def _track_value(self, value, name): 531 """Allows storage of non-trackable objects.""" 532 try: 533 value = super(_ListWrapper, self)._track_value(value=value, name=name) 534 except ValueError: 535 # Even if this value isn't trackable, we need to make sure 536 # NoDependency objects get unwrapped. 537 value = sticky_attribute_assignment( 538 trackable=self, value=value, name=name) 539 return value 540 541 def __repr__(self): 542 return "ListWrapper(%s)" % (repr(self._storage),) 543 544 def _list_functions_for_serialization(self): 545 return { 546 str(key): value for key, value in enumerate(self) 547 if _is_function(value) 548 } 549 550 551class Mapping(TrackableDataStructure, collections.Mapping): 552 """An append-only trackable mapping data structure with string keys. 553 554 Maintains checkpoint dependencies on its contents (which must also be 555 trackable), named based on its keys. 556 557 Note that once a key has been added, it may not be deleted or replaced. If 558 names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`. 559 """ 560 561 def __init__(self, *args, **kwargs): 562 """Construct a new sequence. Arguments are passed to `dict()`.""" 563 super(Mapping, self).__init__() 564 self._storage = self._make_storage(*args, **kwargs) 565 self._storage.update( 566 {key: self._track_value( 567 value, name=self._name_element(key)) 568 for key, value in self._storage.items()}) 569 570 def __copy__(self): 571 return type(self)(copy.copy(self._storage)) 572 573 def __deepcopy__(self, memo): 574 return type(self)(copy.deepcopy(self._storage, memo)) 575 576 def _make_storage(self, *args, **kwargs): 577 return dict(*args, **kwargs) 578 579 @property 580 def _values(self): 581 # Sort items deterministically by key 582 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 583 if ordered: 584 return ordered[1] 585 return [] 586 587 def _name_element(self, key): 588 if not isinstance(key, six.string_types): 589 raise TypeError( 590 "Mapping accepts only string keys, but got a key %s." 591 % repr(key)) 592 return str(key) 593 594 def __setitem__(self, key, value): 595 name = self._name_element(key) 596 value = self._track_value(value, name=name) 597 current_value = self._storage.setdefault(key, value) 598 if current_value is not value: 599 raise ValueError( 600 ("Mappings are an append-only data structure. Tried to overwrite the " 601 "key '%s' with value %s, but it already contains %s") 602 % (key, value, current_value)) 603 604 def update(self, *args, **kwargs): 605 for key, value in dict(*args, **kwargs).items(): 606 self[key] = value 607 608 def __getitem__(self, key): 609 return self._storage[key] 610 611 def __len__(self): 612 return len(self._storage) 613 614 def __repr__(self): 615 return "Mapping(%s)" % (repr(self._storage),) 616 617 def __iter__(self): 618 return iter(self._storage) 619 620 621# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance 622# checks seems infeasible. CPython will not call Python methods/properties on 623# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead 624# collects elements directly from dict_subclass's C structs. So subclassing dict 625# implies that the storage has to be "self" (i.e. the C structs for the object 626# must be updated correctly), but we also need that storage to be the wrapped 627# dictionary to avoid synchronization bugs (un-tracked external modifications 628# should still show up when the dict is accessed through the wrapper). Monkey 629# patching all of the "wrapped" dict's methods instead of creating a wrapper 630# object is an option, but not a very attractive one (replacing methods without 631# creating reference cycles is difficult, and then dicts would need to be 632# special cased everywhere as being trackable). 633class _DictWrapper(Mapping, collections.MutableMapping): 634 """Wraps built-in dicts to support restore-on-create for variables. 635 636 _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping, 637 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 638 keys, reassign values). Like _ListWrapper, these mutations mean that 639 _DictWrapper will raise an exception on save. 640 """ 641 642 def __new__(cls, *args): 643 if len(args) == 1 and isinstance(args[0], dict): 644 return super(_DictWrapper, cls).__new__(cls) 645 else: 646 # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In 647 # this case there's nothing to wrap, so we make a normal dictionary. Also 648 # allows constructing empty instances of the _DictWrapper type, as Session 649 # is wont to do (and again there's nothing to wrap, so a normal dictionary 650 # makes more sense). 651 return dict(*args) 652 653 def __init__(self, wrapped_dict): 654 self._non_string_key = False 655 self._non_append_mutation = False 656 self._external_modification = False 657 super(_DictWrapper, self).__init__(wrapped_dict) 658 self._update_snapshot() 659 660 # pylint: disable=protected-access 661 def __copy__(self): 662 copied = super(_DictWrapper, self).__copy__() 663 copied._non_append_mutation = self._non_append_mutation 664 copied._external_modification = self._external_modification 665 copied._non_string_key = self._non_string_key 666 return copied 667 668 def __deepcopy__(self, memo): 669 copied = super(_DictWrapper, self).__deepcopy__(memo) 670 copied._non_append_mutation = self._non_append_mutation 671 copied._external_modification = self._external_modification 672 copied._non_string_key = self._non_string_key 673 return copied 674 # pylint: enable=protected-access 675 676 def _make_storage(self, wrapped_dict): 677 """Re-use the wrapped dict for storage (to force them to be in sync).""" 678 return wrapped_dict 679 680 @property 681 def _checkpoint_dependencies(self): 682 """Check that the object is saveable before listing its dependencies.""" 683 self._check_external_modification() 684 if self._non_string_key: 685 raise ValueError( 686 "Unable to save the object %s (a dictionary wrapper constructed " 687 "automatically on attribute assignment). The wrapped dictionary " 688 "contains a non-string key which maps to a trackable object or " 689 "mutable data structure.\n\nIf you don't need this dictionary " 690 "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " 691 "object; it will be automatically un-wrapped and subsequently " 692 "ignored." % (self,)) 693 if self._non_append_mutation: 694 raise ValueError( 695 "Unable to save the object %s (a dictionary wrapper constructed " 696 "automatically on attribute assignment). A key mapping to a " 697 "trackable object was overwritten or deleted, which would " 698 "cause problems for restoration.\n\nIf you don't need this " 699 "dictionary checkpointed, wrap it in a " 700 "tf.contrib.checkpoint.NoDependency object; it will be automatically " 701 "un-wrapped and subsequently ignored." % (self,)) 702 if self._external_modification: 703 raise ValueError( 704 "Unable to save the object %s (a dictionary wrapper constructed " 705 "automatically on attribute assignment). The wrapped dictionary was " 706 "modified outside the wrapper (its final value was %s, its value " 707 "when a checkpoint dependency was added was %s), which breaks " 708 "restoration on object creation.\n\nIf you don't need this " 709 "dictionary checkpointed, wrap it in a " 710 "tf.contrib.checkpoint.NoDependency object; it will be automatically " 711 "un-wrapped and subsequently ignored." % ( 712 self, self, self._last_wrapped_dict_snapshot)) 713 assert not self._dirty # Any reason for dirtiness should have an exception. 714 return super(_DictWrapper, self)._checkpoint_dependencies 715 716 @property 717 def _dirty(self): 718 """Check if there has already been a mutation which prevents saving.""" 719 return (self._external_modification 720 or self._non_append_mutation 721 or self._non_string_key) 722 723 def _check_external_modification(self): 724 """Checks for any changes to the wrapped dict not through the wrapper.""" 725 if self._dirty: 726 return 727 if self != self._last_wrapped_dict_snapshot: 728 self._external_modification = True 729 self._last_wrapped_dict_snapshot = None 730 731 def _update_snapshot(self): 732 """Acknowledges tracked changes to the wrapped dict.""" 733 if self._dirty: 734 return 735 self._last_wrapped_dict_snapshot = dict(self) 736 737 def _track_value(self, value, name): 738 """Allows storage of non-trackable objects.""" 739 if isinstance(name, six.string_types): 740 string_key = True 741 else: 742 name = "-non_string_key" 743 string_key = False 744 try: 745 no_dependency = isinstance(value, NoDependency) 746 value = super(_DictWrapper, self)._track_value(value=value, name=name) 747 if not (string_key or no_dependency): 748 # A non-string key maps to a trackable value. This data structure 749 # is not saveable. 750 self._non_string_key = True 751 return value 752 except ValueError: 753 # Even if this value isn't trackable, we need to make sure 754 # NoDependency objects get unwrapped. 755 return sticky_attribute_assignment( 756 trackable=self, value=value, name=name) 757 758 def _name_element(self, key): 759 """Don't throw errors for non-string keys.""" 760 if isinstance(key, six.string_types): 761 return super(_DictWrapper, self)._name_element(key) 762 else: 763 return key 764 765 def __setitem__(self, key, value): 766 """Allow any modifications, but possibly mark the wrapper as unsaveable.""" 767 self._check_external_modification() 768 no_dep = isinstance(value, NoDependency) 769 if isinstance(key, six.string_types): 770 existing_dependency = self._lookup_dependency(key) 771 value = self._track_value(value, name=key) 772 else: 773 value = _wrap_or_unwrap(value) 774 existing_dependency = None 775 if not no_dep and isinstance(value, base.Trackable): 776 # Non-string keys are OK as long as we have no reason to add a 777 # dependency on the value (either because the value is not 778 # trackable, or because it was wrapped in a NoDependency object). 779 self._non_string_key = True 780 if key in self._storage: 781 previous_value = self._storage[key] 782 if previous_value is not value: 783 if ((not no_dep and isinstance(value, base.Trackable)) 784 # We don't want to just check that the existing object is 785 # trackable, since it may have been wrapped in a NoDependency 786 # object. 787 or existing_dependency is not None): 788 # A trackable object was replaced under the same key; this means 789 # that restoring would be error-prone, so we'll throw an exception on 790 # save. 791 self._non_append_mutation = True 792 self._storage[key] = value 793 794 self._update_snapshot() 795 796 def __delitem__(self, key): 797 self._check_external_modification() 798 existing_value = self[key] 799 if isinstance(existing_value, base.Trackable): 800 # Deleting tracked trackable values means restoring is problematic, 801 # so we'll throw an exception on save. 802 self._non_append_mutation = True 803 del self._storage[key] 804 self._update_snapshot() 805 806 def __repr__(self): 807 return "DictWrapper(%s)" % (repr(self._storage),) 808 809 def __hash__(self): 810 raise TypeError("unhashable type: 'DictWrapper'") 811 812 def __eq__(self, other): 813 return self._storage == getattr(other, "_storage", other) 814 815 def update(self, *args, **kwargs): 816 for key, value in dict(*args, **kwargs).items(): 817 self[key] = value 818 819 def _list_functions_for_serialization(self): 820 return { 821 key: value for key, value in self.items() 822 if _is_function(value) 823 } 824 825 826def _is_function(x): 827 return isinstance(x, (def_function.Function, defun.ConcreteFunction)) 828 829revived_types.register_revived_type( 830 "trackable_dict_wrapper", 831 lambda obj: isinstance(obj, _DictWrapper), 832 versions=[revived_types.VersionedTypeRegistration( 833 # Standard dependencies are enough to reconstruct the trackable 834 # items in dictionaries, so we don't need to save any extra information. 835 object_factory=lambda proto: _DictWrapper({}), 836 version=1, 837 min_producer_version=1, 838 min_consumer_version=1, 839 setter=operator.setitem)]) 840 841 842def _set_list_item(list_object, index_string, value): 843 item_index = int(index_string) 844 if len(list_object) <= item_index: 845 list_object.extend([None] * (1 + item_index - len(list_object))) 846 list_object[item_index] = value 847 848 849revived_types.register_revived_type( 850 "trackable_list_wrapper", 851 lambda obj: isinstance(obj, _ListWrapper), 852 versions=[revived_types.VersionedTypeRegistration( 853 object_factory=lambda proto: _ListWrapper([]), 854 version=1, 855 min_producer_version=1, 856 min_consumer_version=1, 857 setter=_set_list_item)]) 858