1"""Trackable data structures.""" 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import operator 23import sys 24 25import six 26try: 27 import wrapt 28except ImportError: 29 # Fall back to the build-time dependency if the system package is not available. 30 from .....third_party import wrapt 31 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function as defun 34from tensorflow.python.ops import variables 35from tensorflow.python.saved_model import revived_types 36from tensorflow.python.training.tracking import base 37from tensorflow.python.training.tracking import layer_utils 38from tensorflow.python.util import lazy_loader 39from tensorflow.python.util.compat import collections_abc 40 41module = lazy_loader.LazyLoader( 42 "module", globals(), "tensorflow.python.module.module") 43 44 45class NoDependency(object): 46 """Allows attribute assignment to `Trackable` objects with no dependency. 47 48 Example usage: 49 ```python 50 obj = Trackable() 51 obj.has_dependency = tf.Variable(0., name="dep") 52 obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) 53 assert obj.no_dependency.name == "nodep:0" 54 ``` 55 56 `obj` in this example has a dependency on the variable "dep", and both 57 attributes contain un-wrapped `Variable` objects. 58 59 `NoDependency` also works with `tf.keras.Model`, but only for checkpoint 60 dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) 61 `Layer` to the attribute without a checkpoint dependency, but the `Model` will 62 still track the `Layer` (so it will appear in `Model.layers`, and its 63 variables will appear in `Model.variables`). 64 """ 65 66 __slots__ = ["value"] 67 68 def __init__(self, value): 69 self.value = value 70 71 72def _should_wrap_tuple(t): 73 """Determine if a tuple has any trackable components.""" 74 for element in t: 75 if isinstance(element, NoDependency): 76 return True # We should remove the NoDependency object from the tuple. 77 if isinstance(element, base.Trackable): 78 return True 79 if wrap_or_unwrap(element) is not element: 80 return True 81 # There are no trackable elements or data structures. Tuples are immutable, so 82 # mutation isn't a concern. Don't wrap. 83 return False 84 85 86def wrap_or_unwrap(value): 87 """Wraps basic data structures, unwraps NoDependency objects.""" 88 # pylint: disable=unidiomatic-typecheck 89 # Exact type checking to avoid mucking up custom logic in list/dict 90 # subclasses, e.g. collections.Counter. 91 if isinstance(value, NoDependency): 92 return value.value 93 if isinstance(value, base.Trackable): 94 return value # Skip conversion for already trackable objects. 95 elif type(value) == dict: 96 return _DictWrapper(value) 97 elif type(value) == collections.OrderedDict: 98 return _DictWrapper(value) 99 elif type(value) == list: 100 return ListWrapper(value) 101 elif isinstance(value, tuple) and _should_wrap_tuple(value): 102 # There are trackable elements or data structures. Wrap the tuple. 103 return _TupleWrapper(value) 104 else: 105 return value 106 # pylint: enable=unidiomatic-typecheck 107 108 109def sticky_attribute_assignment(trackable, name, value): 110 """Adds dependencies, generally called from __setattr__. 111 112 This behavior is shared between Trackable and Model. 113 114 Respects NoDependency indicators, but otherwise makes trackable objects 115 out of common data structures and tracks objects by their attribute names. 116 117 Args: 118 trackable: The object to add dependencies to (generally the one having 119 an attribute assigned). 120 name: The attribute name being assigned. 121 value: The value being assigned. Not necessarily a trackable object. 122 123 Returns: 124 The value which should be stored in the attribute (unwrapped from a 125 NoDependency object if necessary). 126 """ 127 if isinstance(value, NoDependency): 128 add_dependency = False 129 else: 130 add_dependency = True 131 value = wrap_or_unwrap(value) 132 if not add_dependency: 133 return value 134 if isinstance(value, base.Trackable): 135 trackable._track_trackable( # pylint: disable=protected-access 136 value, name=name, 137 # Allow the user to switch the Trackable which is tracked by this 138 # name, since assigning a new variable to an attribute has 139 # historically been fine (e.g. Adam did this). 140 overwrite=True) 141 return value 142 143 144class _UntrackableError(ValueError): 145 146 def __init__(self, value): # pylint: disable=super-init-not-called 147 self._value = value 148 149 def __str__(self): 150 return (("Only trackable objects (such as Layers or Optimizers) may be " 151 "stored in a List object. Got %s, which does not inherit from " 152 "Trackable.") % (self._value,)) 153 154 155class TrackableDataStructure(base.Trackable): 156 """Base class for data structures which contain trackable objects.""" 157 158 def __init__(self): 159 # Attributes prefixed with "_self_" for compatibility with 160 # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as 161 # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of 162 # ways. 163 self._self_trainable = True 164 self._self_extra_variables = [] 165 self._self_attribute_sentinel = layer_utils.AttributeSentinel(True) 166 167 @property 168 def _attribute_sentinel(self): 169 return self._self_attribute_sentinel 170 171 @property 172 def trainable(self): 173 return self._self_trainable 174 175 @trainable.setter 176 def trainable(self, value): 177 self._self_trainable = value 178 179 def _track_value(self, value, name): 180 """Add a dependency on `value`.""" 181 value = sticky_attribute_assignment( 182 trackable=self, value=value, name=name) 183 if isinstance(value, variables.Variable): 184 self._self_extra_variables.append(value) 185 if not isinstance(value, base.Trackable): 186 raise _UntrackableError(value) 187 if hasattr(value, "_use_resource_variables"): 188 # In subclassed models, legacy layers (tf.layers) must always use 189 # resource variables. 190 value._use_resource_variables = True # pylint: disable=protected-access 191 value_attribute_sentinel = getattr(value, "_attribute_sentinel", None) 192 if value_attribute_sentinel: 193 value_attribute_sentinel.add_parent(self._attribute_sentinel) 194 return value 195 196 @property 197 def _values(self): 198 """An iterable/sequence which may contain trackable objects.""" 199 raise NotImplementedError("Abstract method") 200 201 @property 202 def _layers(self): 203 """All Layers and Layer containers, including empty containers.""" 204 # Filter objects on demand so that wrapper objects use values from the thing 205 # they're wrapping if out of sync. 206 collected = [] 207 for obj in self._values: 208 if (isinstance(obj, TrackableDataStructure) 209 or layer_utils.is_layer(obj) 210 or layer_utils.has_weights(obj)): 211 collected.append(obj) 212 return collected 213 214 @property 215 def layers(self): 216 return list(layer_utils.filter_empty_layer_containers(self._layers)) 217 218 @property 219 def trainable_weights(self): 220 if not self._self_trainable: 221 return [] 222 trainable_variables = [] 223 for obj in self._values: 224 if isinstance(obj, (TrackableDataStructure, module.Module)): 225 trainable_variables += obj.trainable_variables 226 trainable_extra_variables = [ 227 v for v in self._self_extra_variables if v.trainable 228 ] 229 return trainable_variables + trainable_extra_variables 230 231 @property 232 def non_trainable_weights(self): 233 trainable_extra_variables = [ 234 v for v in self._self_extra_variables if v.trainable 235 ] 236 non_trainable_extra_variables = [ 237 v for v in self._self_extra_variables if not v.trainable 238 ] 239 non_trainable_variables = [] 240 for obj in self._values: 241 if isinstance(obj, (TrackableDataStructure, module.Module)): 242 non_trainable_variables += obj.non_trainable_variables 243 244 if not self._self_trainable: 245 # Return order is all trainable vars, then all non-trainable vars. 246 trainable_variables = [] 247 for obj in self._values: 248 if isinstance(obj, (TrackableDataStructure, module.Module)): 249 trainable_variables += obj.trainable_variables 250 251 non_trainable_variables = ( 252 trainable_variables + trainable_extra_variables + 253 non_trainable_variables + non_trainable_extra_variables) 254 else: 255 non_trainable_variables = ( 256 non_trainable_variables + non_trainable_extra_variables) 257 258 return non_trainable_variables 259 260 @property 261 def weights(self): 262 return self.trainable_weights + self.non_trainable_weights 263 264 @property 265 def trainable_variables(self): 266 return self.trainable_weights 267 268 @property 269 def non_trainable_variables(self): 270 return self.non_trainable_weights 271 272 @property 273 def variables(self): 274 return self.weights 275 276 @property 277 def updates(self): 278 """Aggregate updates from any `Layer` instances.""" 279 # Updates and conditional losses are forwarded as-is rather than being 280 # filtered based on inputs, since this is just a container and won't ever 281 # have any inputs. 282 aggregated = [] 283 for layer in self.layers: 284 if hasattr(layer, "updates"): 285 aggregated += layer.updates 286 return aggregated 287 288 @property 289 def losses(self): 290 """Aggregate losses from any `Layer` instances.""" 291 aggregated = [] 292 for layer in self.layers: 293 if hasattr(layer, "losses"): 294 aggregated += layer.losses 295 return aggregated 296 297 def __hash__(self): 298 # Support object-identity hashing, so these structures can be used as keys 299 # in sets/dicts. 300 return id(self) 301 302 def __eq__(self, other): 303 # Similar to Tensors, trackable data structures use object-identity 304 # equality to support set/dict membership. 305 return self is other 306 307 308class List(TrackableDataStructure, collections_abc.Sequence): 309 """An append-only sequence type which is trackable. 310 311 Maintains checkpoint dependencies on its contents (which must also be 312 trackable), and forwards any `Layer` metadata such as updates and losses. 313 314 Note that `List` is purely a container. It lets a `tf.keras.Model` or 315 other trackable object know about its contents, but does not call any 316 `Layer` instances which are added to it. To indicate a sequence of `Layer` 317 instances which should be called sequentially, use `tf.keras.Sequential`. 318 319 Example usage: 320 ```python 321 class HasList(tf.keras.Model): 322 323 def __init__(self): 324 super(HasList, self).__init__() 325 self.layer_list = List([layers.Dense(3)]) 326 self.layer_list.append(layers.Dense(4)) 327 328 def call(self, x): 329 aggregation = 0. 330 for l in self.layer_list: 331 x = l(x) 332 aggregation += tf.reduce_sum(x) 333 return aggregation 334 ``` 335 336 This kind of wrapping is necessary because `Trackable` objects do not 337 (yet) deeply inspect regular Python data structures, so for example assigning 338 a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a 339 checkpoint dependency and does not add the `Layer` instance's weights to its 340 parent `Model`. 341 """ 342 343 def __init__(self, *args, **kwargs): 344 """Construct a new sequence. Arguments are passed to `list()`.""" 345 super(List, self).__init__() 346 self._storage = self._make_storage(*args, **kwargs) 347 for index, element in enumerate(self._storage): 348 self._storage[index] = self._track_value( 349 element, name=self._name_element(index)) 350 351 def copy(self): 352 return type(self)(copy.copy(self._storage)) 353 354 def __copy__(self): 355 return self.copy() 356 357 def __deepcopy__(self, memo): 358 return type(self)(copy.deepcopy(self._storage, memo)) 359 360 def _make_storage(self, *args, **kwargs): 361 """Determines the backing storage (overridden in subclasses).""" 362 return list(*args, **kwargs) 363 364 def _name_element(self, index): 365 return "%d" % (index,) 366 367 @property 368 def _values(self): 369 """Collect values for TrackableDataStructure.""" 370 return self 371 372 def append(self, value): 373 """Add a new trackable value.""" 374 value = self._track_value(value, self._name_element(len(self._storage))) 375 self._storage.append(value) 376 377 def extend(self, values): 378 """Add a sequence of trackable values.""" 379 for value in values: 380 self.append(value) 381 382 def __iadd__(self, values): 383 self.extend(values) 384 return self 385 386 def __add__(self, other): 387 return self._storage + getattr(other, "_storage", other) 388 389 def __imul__(self, y): 390 if y <= 0: 391 raise ValueError( 392 "List only supports append, multiplying in place by %d removes " 393 "elements." % y) 394 395 n = len(self._storage) 396 for _ in range(y - 1): 397 for i in range(n): 398 self.append(self._storage[i]) 399 400 return self 401 402 def __mul__(self, n): 403 return self._storage * n 404 405 def __rmul__(self, n): 406 return self * n 407 408 def __radd__(self, other): 409 return other + self._storage 410 411 def __getitem__(self, key): 412 return self._storage[key] 413 414 def __getslice__(self, i, j): 415 return self._storage[slice(i, j)] 416 417 def __len__(self): 418 return len(self._storage) 419 420 def __repr__(self): 421 return "List(%s)" % (repr(self._storage),) 422 423 def __sizeof__(self): 424 return super(List, self).__sizeof__() + sys.getsizeof(self._storage) 425 426 427# TODO(tomhennigan) Update to collections.UserList? 428# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop 429# Python 3.4 support (may still be tricky). 430class ListWrapper( 431 List, 432 collections_abc.MutableSequence, 433 # Shadowed, but there for isinstance checks. 434 list): 435 """Wraps the built-in `list` to support restore-on-create for variables. 436 437 Unlike `List`, this sequence type is mutable in the same ways built-in lists 438 are. Instead of throwing an error immediately like `List`, it records 439 problematic mutations (e.g. assigning a new element to a position already 440 occupied, meaning both elements get the same names at different times) and 441 refuses to save. 442 443 On assignment to an attribute of a Model or Trackable object, Python 444 lists are replaced with ListWrapper. Wrapping a list in a 445 `NoDependency` object prevents this. 446 """ 447 448 def __init__(self, wrapped_list): 449 """Construct a new list wrapper. 450 451 Args: 452 wrapped_list: The initial value of the data structure. A shallow copy may 453 be maintained for error checking. `wrapped_list` itself should not be 454 modified directly after constructing the `ListWrapper`, and if changes 455 are detected the `ListWrapper` will throw an exception on save. 456 """ 457 # Monotonic flags which indicate this object would not be restored properly, 458 # and therefore should throw an error on save to avoid giving the impression 459 # that restoring it will work. 460 self._non_append_mutation_value = False 461 self._external_modification_value = False 462 super(ListWrapper, self).__init__(wrapped_list) 463 self._last_wrapped_list_snapshot = list(self._storage) 464 465 @property 466 def _non_append_mutation(self): 467 return self._non_append_mutation_value 468 469 @_non_append_mutation.setter 470 def _non_append_mutation(self, value): 471 # Trackable only cares that a mutation occurred at some point; when 472 # attempting to save it checks whether a mutation occurred and the object is 473 # in a "dirty" state but otherwise the specifics of how it got to that state 474 # are ignored. By contrast, the attribute cache needs to signal the mutation 475 # immediately since a caller could query the value of an attribute (And 476 # should not hit the cached value since the mutation may have affected the 477 # result.) 478 self._attribute_sentinel.invalidate_all() 479 self._non_append_mutation_value = value 480 481 @property 482 def _external_modification(self): 483 return self._external_modification_value 484 485 @_external_modification.setter 486 def _external_modification(self, value): 487 # Invalidate for the same reason as `_non_append_mutation` 488 self._attribute_sentinel.invalidate_all() 489 self._external_modification_value = value 490 491 # pylint: disable=protected-access 492 def __copy__(self): 493 copied = super(ListWrapper, self).__copy__() 494 copied._non_append_mutation = self._non_append_mutation 495 copied._external_modification = self._external_modification 496 return copied 497 498 def __deepcopy__(self, memo): 499 copied = super(ListWrapper, self).__deepcopy__(memo) 500 copied._non_append_mutation = self._non_append_mutation 501 copied._external_modification = self._external_modification 502 return copied 503 # pylint: enable=protected-access 504 505 def __reduce_ex__(self, protocol): 506 return (self.__class__, 507 (self._storage,)) 508 509 def _make_storage(self, wrapped_list): 510 """Use the user's original list for storage.""" 511 return wrapped_list 512 513 def _check_external_modification(self): 514 """Checks for any changes to the wrapped list not through the wrapper.""" 515 if self._external_modification or self._non_append_mutation: 516 return 517 if self._storage != self._last_wrapped_list_snapshot: 518 self._external_modification = True 519 self._last_wrapped_list_snapshot = None 520 521 def _update_snapshot(self): 522 """Acknowledges tracked changes to the wrapped list.""" 523 524 # Mutation tracking for attributes reuses the same infrastructure as 525 # Trackable mutation tracking. 526 self._attribute_sentinel.invalidate_all() 527 if self._external_modification or self._non_append_mutation: 528 return 529 self._last_wrapped_list_snapshot = list(self._storage) 530 531 @property 532 def _checkpoint_dependencies(self): 533 self._check_external_modification() 534 if self._non_append_mutation: 535 raise ValueError( 536 ("Unable to save the object %s (a list wrapper constructed to track " 537 "trackable TensorFlow objects). A list element was replaced " 538 "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), " 539 "or moved (sort). In order to support restoration on object " 540 "creation, tracking is exclusively for append-only data structures." 541 "\n\nIf you don't need this list checkpointed, wrap it in a " 542 "non-trackable object; it will be subsequently ignored." % (self,))) 543 if self._external_modification: 544 raise ValueError( 545 ("Unable to save the object %s (a list wrapper constructed to track " 546 "trackable TensorFlow objects). The wrapped list was modified " 547 "outside the wrapper (its final value was %s, its value when a " 548 "checkpoint dependency was added was %s), which breaks restoration " 549 "on object creation.\n\nIf you don't need this list checkpointed, " 550 "wrap it in a NoDependency object; it will be " 551 "subsequently ignored." % ( 552 self, self._storage, self._last_wrapped_list_snapshot))) 553 return super(ListWrapper, self)._checkpoint_dependencies 554 555 def _has_mutation_or_trackable(self): 556 """Short-circuits a check for trackables if there's already a mutation.""" 557 if self._non_append_mutation: 558 return True 559 return any(isinstance(element, base.Trackable) for element in self._storage) 560 561 def __delitem__(self, key): 562 self._check_external_modification() 563 if self._has_mutation_or_trackable(): 564 self._non_append_mutation = True 565 del self._storage[key] 566 self._update_snapshot() 567 568 def __setitem__(self, key, value): 569 self._check_external_modification() 570 571 if isinstance(key, slice): 572 # Note: this is quite inefficient, but the list API supports a broad range 573 # of slice setters (e.g. truncate, extend, replace) and imitating this 574 # for a range of Python versions is non-trivial. 575 storage_copy = list(self._storage) 576 self._storage[key] = value 577 578 len_before = len(storage_copy) 579 len_now = len(self._storage) 580 for i in range(max(len_before, len_now)): 581 value_now = self._storage[i] if i < len_now else None 582 value_before = storage_copy[i] if i < len_before else None 583 584 if isinstance(value_before, base.Trackable): 585 self._non_append_mutation = True 586 587 if value_now is not None and value_now != value_before: 588 self._storage[i] = self._track_value(self._storage[i], 589 self._name_element(i)) 590 591 else: 592 if isinstance(self._storage[key], base.Trackable): 593 self._non_append_mutation = True 594 self._storage[key] = self._track_value(value, self._name_element(key)) 595 596 self._update_snapshot() 597 598 def append(self, value): 599 """Add a new trackable value.""" 600 self._check_external_modification() 601 super(ListWrapper, self).append(value) 602 self._update_snapshot() 603 604 def extend(self, values): 605 """Add a sequence of trackable values.""" 606 self._check_external_modification() 607 super(ListWrapper, self).extend(values) 608 self._update_snapshot() 609 610 def __imul__(self, y): 611 if y <= 0: 612 self._check_external_modification() 613 if self._has_mutation_or_trackable(): 614 self._non_append_mutation = True 615 self._storage *= y 616 self._update_snapshot() 617 return self 618 619 # Relies on super() calling append, which updates the snapshot. 620 return super(ListWrapper, self).__imul__(y) 621 622 def __eq__(self, other): 623 return self._storage == getattr(other, "_storage", other) 624 625 def __ne__(self, other): 626 return self._storage != getattr(other, "_storage", other) 627 628 def __lt__(self, other): 629 return self._storage < getattr(other, "_storage", other) 630 631 def __le__(self, other): 632 return self._storage <= getattr(other, "_storage", other) 633 634 def __gt__(self, other): 635 return self._storage > getattr(other, "_storage", other) 636 637 def __ge__(self, other): 638 return self._storage >= getattr(other, "_storage", other) 639 640 def __hash__(self): 641 # List wrappers need to compare like regular lists, and so like regular 642 # lists they don't belong in hash tables. 643 raise TypeError("unhashable type: 'ListWrapper'") 644 645 def insert(self, index, obj): 646 self._check_external_modification() 647 if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)): 648 self._non_append_mutation = True 649 self._storage.insert(index, obj) 650 self._update_snapshot() 651 652 def sort(self): 653 self._check_external_modification() 654 if self._has_mutation_or_trackable(): 655 self._non_append_mutation = True 656 self._storage.sort() 657 self._update_snapshot() 658 659 def __setslice__(self, i, j, y): 660 self.__setitem__(slice(i, j), y) 661 662 def __delslice__(self, i, j): 663 self._check_external_modification() 664 if self._has_mutation_or_trackable(): 665 self._non_append_mutation = True 666 del self._storage[slice(i, j)] 667 self._update_snapshot() 668 669 def _track_value(self, value, name): 670 """Allows storage of non-trackable objects.""" 671 try: 672 value = super(ListWrapper, self)._track_value(value=value, name=name) 673 except ValueError: 674 # Even if this value isn't trackable, we need to make sure 675 # NoDependency objects get unwrapped. 676 value = sticky_attribute_assignment( 677 trackable=self, value=value, name=name) 678 return value 679 680 def __repr__(self): 681 return "ListWrapper(%s)" % (repr(self._storage),) 682 683 def _list_functions_for_serialization(self, unused_functions): 684 return { 685 str(key): value for key, value in enumerate(self) 686 if _is_function(value) 687 } 688 689 690class Mapping(TrackableDataStructure, collections_abc.Mapping): 691 """An append-only trackable mapping data structure with string keys. 692 693 Maintains checkpoint dependencies on its contents (which must also be 694 trackable), named based on its keys. 695 696 Note that once a key has been added, it may not be deleted or replaced. 697 """ 698 699 def __init__(self, *args, **kwargs): 700 """Construct a new sequence. Arguments are passed to `dict()`.""" 701 super(Mapping, self).__init__() 702 self._storage = self._make_storage(*args, **kwargs) 703 self._storage.update( 704 {key: self._track_value( 705 value, name=self._name_element(key)) 706 for key, value in self._storage.items()}) 707 708 def __copy__(self): 709 return type(self)(copy.copy(self._storage)) 710 711 def __deepcopy__(self, memo): 712 return type(self)(copy.deepcopy(self._storage, memo)) 713 714 def _make_storage(self, *args, **kwargs): 715 return dict(*args, **kwargs) 716 717 @property 718 def _values(self): 719 """Collect values for TrackableDataStructure.""" 720 # Sort items deterministically by key 721 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 722 if ordered: 723 return ordered[1] 724 return [] 725 726 def _name_element(self, key): 727 if not isinstance(key, six.string_types): 728 raise TypeError( 729 "Mapping accepts only string keys, but got a key %s." 730 % repr(key)) 731 return str(key) 732 733 def __setitem__(self, key, value): 734 name = self._name_element(key) 735 value = self._track_value(value, name=name) 736 current_value = self._storage.setdefault(key, value) 737 if current_value is not value: 738 raise ValueError( 739 ("Mappings are an append-only data structure. Tried to overwrite the " 740 "key '%s' with value %s, but it already contains %s") 741 % (key, value, current_value)) 742 743 def update(self, *args, **kwargs): 744 for key, value in dict(*args, **kwargs).items(): 745 self[key] = value 746 747 def __getitem__(self, key): 748 return self._storage[key] 749 750 def __len__(self): 751 return len(self._storage) 752 753 def __repr__(self): 754 return "Mapping(%s)" % (repr(self._storage),) 755 756 def __iter__(self): 757 return iter(self._storage) 758 759 760class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): 761 """Wraps built-in dicts to support restore-on-create for variables. 762 763 _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping, 764 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 765 keys, reassign values). Like ListWrapper, these mutations mean that 766 _DictWrapper will raise an exception on save. 767 """ 768 769 def __init__(self, wrapped_dict=None): 770 if wrapped_dict is None: 771 # Allow zero-argument construction, e.g. from session.run's re-wrapping. 772 wrapped_dict = {} 773 if not isinstance(wrapped_dict, collections_abc.Mapping): 774 # Allow construction from a sequence, e.g. from nest.pack_sequence_as. 775 wrapped_dict = dict(wrapped_dict) 776 wrapt.ObjectProxy.__init__(self, wrapped_dict) 777 TrackableDataStructure.__init__(self) 778 self._self_non_string_key = False 779 self._self_external_modification = False 780 self.__wrapped__.update( 781 {key: self._track_value( 782 value, name=self._name_element(key)) 783 for key, value in self.__wrapped__.items()}) 784 self._update_snapshot() 785 786 def __reduce_ex__(self, protocol): 787 return (self.__class__, 788 (self.__wrapped__,)) 789 790 def __getattribute__(self, name): 791 if (hasattr(type(self), name) 792 and isinstance(getattr(type(self), name), property)): 793 # Bypass ObjectProxy for properties. Whether this workaround is necessary 794 # appears to depend on the Python version but not the wrapt version: 3.4 795 # in particular seems to look up properties on the wrapped object instead 796 # of the wrapper without this logic. 797 return object.__getattribute__(self, name) 798 else: 799 return super(_DictWrapper, self).__getattribute__(name) 800 801 def copy(self): 802 return copy.copy(self) 803 804 # pylint: disable=protected-access 805 def __copy__(self): 806 copied = _DictWrapper(copy.copy(self.__wrapped__)) 807 copied._self_external_modification = self._self_external_modification 808 copied._self_non_string_key = self._self_non_string_key 809 return copied 810 811 def __deepcopy__(self, memo): 812 copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) 813 copied._self_external_modification = self._self_external_modification 814 copied._self_non_string_key = self._self_non_string_key 815 return copied 816 # pylint: enable=protected-access 817 818 @property 819 def _values(self): 820 """Collect values for TrackableDataStructure.""" 821 # Sort items deterministically by key 822 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 823 if ordered: 824 return ordered[1] 825 return [] 826 827 @property 828 def _checkpoint_dependencies(self): 829 """Check that the object is saveable before listing its dependencies.""" 830 self._check_self_external_modification() 831 if self._self_non_string_key: 832 raise ValueError( 833 "Unable to save the object %s (a dictionary wrapper constructed " 834 "automatically on attribute assignment). The wrapped dictionary " 835 "contains a non-string key which maps to a trackable object or " 836 "mutable data structure.\n\nIf you don't need this dictionary " 837 "checkpointed, wrap it in a non-trackable " 838 "object; it will be subsequently ignored." % (self,)) 839 if self._self_external_modification: 840 raise ValueError( 841 "Unable to save the object %s (a dictionary wrapper constructed " 842 "automatically on attribute assignment). The wrapped dictionary was " 843 "modified outside the wrapper (its final value was %s, its value " 844 "when a checkpoint dependency was added was %s), which breaks " 845 "restoration on object creation.\n\nIf you don't need this " 846 "dictionary checkpointed, wrap it in a " 847 "non-trackable object; it will be subsequently ignored." % ( 848 self, self, self._self_last_wrapped_dict_snapshot)) 849 assert not self._dirty # Any reason for dirtiness should have an exception. 850 return super(_DictWrapper, self)._checkpoint_dependencies 851 852 @property 853 def _dirty(self): 854 """Check if there has already been a mutation which prevents saving.""" 855 return (self._self_external_modification 856 or self._self_non_string_key) 857 858 def _check_self_external_modification(self): 859 """Checks for any changes to the wrapped dict not through the wrapper.""" 860 if self._dirty: 861 return 862 if self != self._self_last_wrapped_dict_snapshot: 863 self._self_external_modification = True 864 self._self_last_wrapped_dict_snapshot = None 865 866 def _update_snapshot(self): 867 """Acknowledges tracked changes to the wrapped dict.""" 868 self._attribute_sentinel.invalidate_all() 869 if self._dirty: 870 return 871 self._self_last_wrapped_dict_snapshot = dict(self) 872 873 def _track_value(self, value, name): 874 """Allows storage of non-trackable objects.""" 875 if isinstance(name, six.string_types): 876 string_key = True 877 else: 878 name = "-non_string_key" 879 string_key = False 880 try: 881 no_dependency = isinstance(value, NoDependency) 882 value = super(_DictWrapper, self)._track_value(value=value, name=name) 883 if not (string_key or no_dependency): 884 # A non-string key maps to a trackable value. This data structure 885 # is not saveable. 886 self._self_non_string_key = True 887 return value 888 except ValueError: 889 # Even if this value isn't trackable, we need to make sure 890 # NoDependency objects get unwrapped. 891 return sticky_attribute_assignment( 892 trackable=self, value=value, name=name) 893 894 def _name_element(self, key): 895 """Tells TrackableDataStructure to use keys as names as-is.""" 896 return key 897 898 def __setitem__(self, key, value): 899 """Allow any modifications, but possibly mark the wrapper as unsaveable.""" 900 self._check_self_external_modification() 901 self._maybe_initialize_trackable() 902 no_dep = isinstance(value, NoDependency) 903 if isinstance(key, six.string_types): 904 value = self._track_value(value, name=key) 905 else: 906 value = wrap_or_unwrap(value) 907 if not no_dep and isinstance(value, base.Trackable): 908 # Non-string keys are OK as long as we have no reason to add a 909 # dependency on the value (either because the value is not 910 # trackable, or because it was wrapped in a NoDependency object). 911 self._self_non_string_key = True 912 self.__wrapped__[key] = value 913 914 self._update_snapshot() 915 916 def __delitem__(self, key): 917 self._check_self_external_modification() 918 del self.__wrapped__[key] 919 self._update_snapshot() 920 921 def __repr__(self): 922 return "DictWrapper(%s)" % (repr(self.__wrapped__),) 923 924 def __hash__(self): 925 raise TypeError("unhashable type: 'DictWrapper'") 926 927 def __eq__(self, other): 928 # Override the TrackableDataStructure "== -> is" forwarding and go back to 929 # the wrapt implementation. 930 return self.__wrapped__ == other 931 932 def update(self, *args, **kwargs): 933 for key, value in six.iteritems(dict(*args, **kwargs)): 934 self[key] = value 935 936 def _list_functions_for_serialization(self, unused_serialization_cache): 937 return { 938 key: value for key, value in self.items() 939 if _is_function(value) 940 } 941 942 943class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy): 944 """Trackable wrapper for tuples and namedtuples.""" 945 946 def __init__(self, original_wrapped_tuple=()): 947 add_dependency = [] 948 substituted_wrapped_tuple = [] 949 for element in original_wrapped_tuple: 950 if isinstance(element, NoDependency): 951 add_dependency.append(False) 952 else: 953 add_dependency.append(True) 954 substituted_wrapped_tuple.append(wrap_or_unwrap(element)) 955 try: 956 fields = original_wrapped_tuple._fields 957 except AttributeError: 958 # Not a namedtuple 959 is_namedtuple = False 960 else: 961 is_namedtuple = True 962 original_type = type(original_wrapped_tuple) 963 # Flag to poison saving if we can't re-construct a namedtupled because its 964 # __new__ takes different keyword arguments than its _fields. 965 self._self_tuple_is_constructable = True 966 if is_namedtuple: 967 try: 968 # NamedTuples take N arguments, unlike tuple which takes a sequence. 969 substituted_wrapped_tuple = original_type( 970 **dict(zip(fields, substituted_wrapped_tuple))) 971 except TypeError: 972 wrapt.ObjectProxy.__init__(self, original_wrapped_tuple) 973 TrackableDataStructure.__init__(self) 974 self._self_tuple_is_constructable = False 975 return 976 else: 977 substituted_wrapped_tuple = original_type(substituted_wrapped_tuple) 978 wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple) 979 TrackableDataStructure.__init__(self) 980 981 if is_namedtuple: 982 # For namedtuples, also track by names for compatibility with 983 # dictionaries. 984 for name, should_depend, element in zip( 985 fields, add_dependency, substituted_wrapped_tuple): 986 if should_depend: 987 self._track_value(element, name=name) 988 989 # Track by index as well, for compatibility with lists. 990 for index, (should_depend, element) in enumerate( 991 zip(add_dependency, substituted_wrapped_tuple)): 992 if should_depend: 993 self._track_value(element, name="%d" % (index,)) 994 995 @property 996 def _values(self): 997 """Collect values for TrackableDataStructure.""" 998 return self 999 1000 def _track_value(self, value, name): 1001 """Allows storage of non-trackable objects.""" 1002 try: 1003 value = super(_TupleWrapper, self)._track_value(value=value, name=name) 1004 except ValueError: 1005 # Even if this value isn't trackable, we need to make sure 1006 # NoDependency objects get unwrapped. 1007 value = sticky_attribute_assignment( 1008 trackable=self, value=value, name=name) 1009 return value 1010 1011 def __repr__(self): 1012 return "_TupleWrapper(%s)" % (repr(self.__wrapped__),) 1013 1014 def __hash__(self): 1015 # Override the TrackableDataStructure hash forwarding and go back to 1016 # the wrapt implementation. 1017 return hash(self.__wrapped__) 1018 1019 def __eq__(self, other): 1020 # Override the TrackableDataStructure "== -> is" forwarding and go back to 1021 # the wrapt implementation. 1022 return self.__wrapped__ == other 1023 1024 def __copy__(self): 1025 return _TupleWrapper(copy.copy(self.__wrapped__)) 1026 1027 def __deepcopy__(self, memo): 1028 return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo)) 1029 1030 def __reduce_ex__(self, protocol): 1031 return (self.__class__, 1032 (self.__wrapped__,)) 1033 1034 # imul and iadd are the only tuple-relevant in-place operators. They need to 1035 # be special-cased to avoid mutating the original proxy object. 1036 def __imul__(self, y): 1037 """Avoid running self.__wrapped__ *= y, which mutates `self`.""" 1038 return self.__wrapped__ * y 1039 1040 def __iadd__(self, y): 1041 """Avoid running self.__wrapped__ += y, which mutates `self`.""" 1042 return self.__wrapped__ + y 1043 1044 @property 1045 def _checkpoint_dependencies(self): 1046 if not self._self_tuple_is_constructable: 1047 raise ValueError( 1048 ("Unable to save because the namedtuple {} is not constructable from " 1049 "its _fields (i.e. __new__ is overridden). Expected keyword " 1050 "arguments {}. If you do not need to save this object, consider " 1051 "wrapping it in a custom object that does not inherit from tuple.") 1052 .format(self.__wrapped__, self.__wrapped__._fields)) 1053 return super(_TupleWrapper, self)._checkpoint_dependencies 1054 1055 def __getattribute__(self, name): 1056 if (hasattr(type(self), name) 1057 and isinstance(getattr(type(self), name), property)): 1058 # Bypass ObjectProxy for properties. Whether this workaround is necessary 1059 # appears to depend on the Python version but not the wrapt version: 3.4 1060 # in particular seems to look up properties on the wrapped object instead 1061 # of the wrapper without this logic. 1062 return object.__getattribute__(self, name) 1063 else: 1064 return super(_TupleWrapper, self).__getattribute__(name) 1065 1066 1067def _is_function(x): 1068 return isinstance(x, (def_function.Function, defun.ConcreteFunction)) 1069 1070 1071revived_types.register_revived_type( 1072 "trackable_dict_wrapper", 1073 lambda obj: isinstance(obj, _DictWrapper), 1074 versions=[revived_types.VersionedTypeRegistration( 1075 # Standard dependencies are enough to reconstruct the trackable 1076 # items in dictionaries, so we don't need to save any extra information. 1077 object_factory=lambda proto: _DictWrapper({}), 1078 version=1, 1079 min_producer_version=1, 1080 min_consumer_version=1, 1081 setter=operator.setitem)]) 1082 1083 1084def _set_list_item(list_object, index_string, value): 1085 item_index = int(index_string) 1086 if len(list_object) <= item_index: 1087 list_object.extend([None] * (1 + item_index - len(list_object))) 1088 list_object[item_index] = value 1089 1090 1091revived_types.register_revived_type( 1092 "trackable_list_wrapper", 1093 lambda obj: isinstance(obj, ListWrapper), 1094 versions=[revived_types.VersionedTypeRegistration( 1095 object_factory=lambda proto: ListWrapper([]), 1096 version=1, 1097 min_producer_version=1, 1098 min_consumer_version=1, 1099 setter=_set_list_item)]) 1100 1101 1102def _set_tuple_item(list_object, index_string, value): 1103 try: 1104 item_index = int(index_string) 1105 except ValueError: 1106 # Ignore namedtuple fields. 1107 return 1108 if len(list_object) <= item_index: 1109 list_object.extend([None] * (1 + item_index - len(list_object))) 1110 list_object[item_index] = value 1111 1112 1113# Revive tuples as lists so we can append any dependencies during loading. 1114revived_types.register_revived_type( 1115 "trackable_tuple_wrapper", 1116 lambda obj: isinstance(obj, _TupleWrapper), 1117 versions=[revived_types.VersionedTypeRegistration( 1118 object_factory=lambda proto: ListWrapper([]), 1119 version=1, 1120 min_producer_version=1, 1121 min_consumer_version=1, 1122 setter=_set_tuple_item)]) 1123