1"""Trackable data structures.""" 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import operator 23import sys 24 25import six 26try: 27 import wrapt 28except ImportError: 29 # Fall back to the build-time dependency if the system package is not available. 30 from .....third_party import wrapt # pylint: disable=relative-beyond-top-level 31 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function as defun 34from tensorflow.python.ops import variables 35from tensorflow.python.saved_model import revived_types 36from tensorflow.python.training.tracking import base 37from tensorflow.python.training.tracking import layer_utils 38from tensorflow.python.util import lazy_loader 39from tensorflow.python.util.compat import collections_abc 40from tensorflow.python.util.tf_export import tf_export 41 42 43module = lazy_loader.LazyLoader( 44 "module", globals(), "tensorflow.python.module.module") 45 46 47class NoDependency(object): 48 """Allows attribute assignment to `Trackable` objects with no dependency. 
49 50 Example usage: 51 ```python 52 obj = Trackable() 53 obj.has_dependency = tf.Variable(0., name="dep") 54 obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) 55 assert obj.no_dependency.name == "nodep:0" 56 ``` 57 58 `obj` in this example has a dependency on the variable "dep", and both 59 attributes contain un-wrapped `Variable` objects. 60 61 `NoDependency` also works with `tf.keras.Model`, but only for checkpoint 62 dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) 63 `Layer` to the attribute without a checkpoint dependency, but the `Model` will 64 still track the `Layer` (so it will appear in `Model.layers`, and its 65 variables will appear in `Model.variables`). 66 """ 67 68 __slots__ = ["value"] 69 70 def __init__(self, value): 71 self.value = value 72 73 74def _should_wrap_tuple(t): 75 """Determine if a tuple has any trackable components.""" 76 # pylint: disable=unidiomatic-typecheck 77 # Exact type checking to avoid mucking up custom logic in list/dict 78 # subclasses, e.g. collections.Counter. 79 for element in t: 80 if isinstance(element, NoDependency): 81 return True # We should remove the NoDependency object from the tuple. 82 if isinstance(element, base.Trackable): 83 return True 84 if type(element) == dict: 85 return True 86 if type(element) == collections.OrderedDict: 87 return True 88 if type(element) == list: 89 return True 90 if isinstance(element, tuple) and _should_wrap_tuple(element): 91 return True 92 # There are no trackable elements or data structures. Tuples are immutable, so 93 # mutation isn't a concern. Don't wrap. 94 return False 95 # pylint: enable=unidiomatic-typecheck 96 97 98@tf_export("__internal__.tracking.wrap", v1=[]) 99def wrap_or_unwrap(value): 100 """Wraps input value into trackable data structures. 101 102 This is mostly useful for containers like list, dict, etc, which could contain 103 trackable objects in it. 
Wrapped data structure will be tracked when 104 associated with a `tf.Module`, so that save model/checkpoint can properly 105 track the dependency. 106 107 It will also unwrap NoDependency objects. 108 109 Args: 110 value: the input object to be wrapped. 111 112 Returns: 113 Wrapped trackable data structure. 114 """ 115 # pylint: disable=unidiomatic-typecheck 116 # Exact type checking to avoid mucking up custom logic in list/dict 117 # subclasses, e.g. collections.Counter. 118 if isinstance(value, NoDependency): 119 return value.value 120 if isinstance(value, base.Trackable): 121 return value # Skip conversion for already trackable objects. 122 elif type(value) == dict: 123 return _DictWrapper(value) 124 elif type(value) == collections.OrderedDict: 125 return _DictWrapper(value) 126 elif type(value) == list: 127 return ListWrapper(value) 128 elif isinstance(value, tuple) and _should_wrap_tuple(value): 129 # There are trackable elements or data structures. Wrap the tuple. 130 return _TupleWrapper(value) 131 else: 132 return value 133 # pylint: enable=unidiomatic-typecheck 134 135 136@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[]) 137def sticky_attribute_assignment(trackable, name, value): 138 """Adds dependencies, generally called from __setattr__. 139 140 This behavior is shared between Trackable and Model. 141 142 Respects NoDependency indicators, but otherwise makes trackable objects 143 out of common data structures and tracks objects by their attribute names. 144 145 Args: 146 trackable: The object to add dependencies to (generally the one having 147 an attribute assigned). 148 name: The attribute name being assigned. 149 value: The value being assigned. Not necessarily a trackable object. 150 151 Returns: 152 The value which should be stored in the attribute (unwrapped from a 153 NoDependency object if necessary). 
154 """ 155 if isinstance(value, NoDependency): 156 add_dependency = False 157 else: 158 add_dependency = True 159 value = wrap_or_unwrap(value) 160 if not add_dependency: 161 return value 162 if isinstance(value, base.Trackable): 163 trackable._track_trackable( # pylint: disable=protected-access 164 value, name=name, 165 # Allow the user to switch the Trackable which is tracked by this 166 # name, since assigning a new variable to an attribute has 167 # historically been fine (e.g. Adam did this). 168 overwrite=True) 169 return value 170 171 172class _UntrackableError(ValueError): 173 174 def __init__(self, value): # pylint: disable=super-init-not-called 175 self._value = value 176 177 def __str__(self): 178 return (("Only trackable objects (such as Layers or Optimizers) may be " 179 "stored in a List object. Got %s, which does not inherit from " 180 "Trackable.") % (self._value,)) 181 182 183@tf_export("__internal__.tracking.TrackableDataStructure", v1=[]) 184class TrackableDataStructure(base.Trackable): 185 """Base class for data structures which contain trackable objects.""" 186 187 def __init__(self): 188 # Attributes prefixed with "_self_" for compatibility with 189 # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as 190 # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of 191 # ways. 
192 self._self_trainable = True 193 self._self_extra_variables = [] 194 self._self_attribute_sentinel = layer_utils.AttributeSentinel(True) 195 196 @property 197 def _attribute_sentinel(self): 198 return self._self_attribute_sentinel 199 200 @property 201 def trainable(self): 202 return self._self_trainable 203 204 @trainable.setter 205 def trainable(self, value): 206 self._self_trainable = value 207 208 def _track_value(self, value, name): 209 """Add a dependency on `value`.""" 210 value = sticky_attribute_assignment( 211 trackable=self, value=value, name=name) 212 if isinstance(value, variables.Variable): 213 self._self_extra_variables.append(value) 214 if not isinstance(value, base.Trackable): 215 raise _UntrackableError(value) 216 if hasattr(value, "_use_resource_variables"): 217 # In subclassed models, legacy layers (tf.layers) must always use 218 # resource variables. 219 value._use_resource_variables = True # pylint: disable=protected-access 220 value_attribute_sentinel = getattr(value, "_attribute_sentinel", None) 221 if value_attribute_sentinel: 222 value_attribute_sentinel.add_parent(self._attribute_sentinel) 223 return value 224 225 @property 226 def _values(self): 227 """An iterable/sequence which may contain trackable objects.""" 228 raise NotImplementedError("Abstract method") 229 230 @property 231 def _layers(self): 232 """All Layers and Layer containers, including empty containers.""" 233 # Filter objects on demand so that wrapper objects use values from the thing 234 # they're wrapping if out of sync. 
235 collected = [] 236 for obj in self._values: 237 if (isinstance(obj, TrackableDataStructure) 238 or layer_utils.is_layer(obj) 239 or layer_utils.has_weights(obj)): 240 collected.append(obj) 241 return collected 242 243 @property 244 def layers(self): 245 return list(layer_utils.filter_empty_layer_containers(self._layers)) 246 247 @property 248 def trainable_weights(self): 249 if not self._self_trainable: 250 return [] 251 trainable_variables = [] 252 for obj in self._values: 253 if isinstance(obj, (TrackableDataStructure, module.Module)): 254 trainable_variables += obj.trainable_variables 255 trainable_extra_variables = [ 256 v for v in self._self_extra_variables if v.trainable 257 ] 258 return trainable_variables + trainable_extra_variables 259 260 @property 261 def non_trainable_weights(self): 262 trainable_extra_variables = [ 263 v for v in self._self_extra_variables if v.trainable 264 ] 265 non_trainable_extra_variables = [ 266 v for v in self._self_extra_variables if not v.trainable 267 ] 268 non_trainable_variables = [] 269 for obj in self._values: 270 if isinstance(obj, (TrackableDataStructure, module.Module)): 271 non_trainable_variables += obj.non_trainable_variables 272 273 if not self._self_trainable: 274 # Return order is all trainable vars, then all non-trainable vars. 
275 trainable_variables = [] 276 for obj in self._values: 277 if isinstance(obj, (TrackableDataStructure, module.Module)): 278 trainable_variables += obj.trainable_variables 279 280 non_trainable_variables = ( 281 trainable_variables + trainable_extra_variables + 282 non_trainable_variables + non_trainable_extra_variables) 283 else: 284 non_trainable_variables = ( 285 non_trainable_variables + non_trainable_extra_variables) 286 287 return non_trainable_variables 288 289 @property 290 def weights(self): 291 return self.trainable_weights + self.non_trainable_weights 292 293 @property 294 def trainable_variables(self): 295 return self.trainable_weights 296 297 @property 298 def non_trainable_variables(self): 299 return self.non_trainable_weights 300 301 @property 302 def variables(self): 303 return self.weights 304 305 @property 306 def updates(self): 307 """Aggregate updates from any `Layer` instances.""" 308 # Updates and conditional losses are forwarded as-is rather than being 309 # filtered based on inputs, since this is just a container and won't ever 310 # have any inputs. 311 aggregated = [] 312 for layer in self.layers: 313 if hasattr(layer, "updates"): 314 aggregated += layer.updates 315 return aggregated 316 317 @property 318 def losses(self): 319 """Aggregate losses from any `Layer` instances.""" 320 aggregated = [] 321 for layer in self.layers: 322 if hasattr(layer, "losses"): 323 aggregated += layer.losses 324 return aggregated 325 326 def __hash__(self): 327 # Support object-identity hashing, so these structures can be used as keys 328 # in sets/dicts. 329 return id(self) 330 331 def __eq__(self, other): 332 # Similar to Tensors, trackable data structures use object-identity 333 # equality to support set/dict membership. 334 return self is other 335 336 337class List(TrackableDataStructure, collections_abc.Sequence): 338 """An append-only sequence type which is trackable. 
339 340 Maintains checkpoint dependencies on its contents (which must also be 341 trackable), and forwards any `Layer` metadata such as updates and losses. 342 343 Note that `List` is purely a container. It lets a `tf.keras.Model` or 344 other trackable object know about its contents, but does not call any 345 `Layer` instances which are added to it. To indicate a sequence of `Layer` 346 instances which should be called sequentially, use `tf.keras.Sequential`. 347 348 Example usage: 349 ```python 350 class HasList(tf.keras.Model): 351 352 def __init__(self): 353 super(HasList, self).__init__() 354 self.layer_list = List([layers.Dense(3)]) 355 self.layer_list.append(layers.Dense(4)) 356 357 def call(self, x): 358 aggregation = 0. 359 for l in self.layer_list: 360 x = l(x) 361 aggregation += tf.reduce_sum(x) 362 return aggregation 363 ``` 364 365 This kind of wrapping is necessary because `Trackable` objects do not 366 (yet) deeply inspect regular Python data structures, so for example assigning 367 a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a 368 checkpoint dependency and does not add the `Layer` instance's weights to its 369 parent `Model`. 370 """ 371 372 def __init__(self, *args, **kwargs): 373 """Construct a new sequence. 
Arguments are passed to `list()`.""" 374 super(List, self).__init__() 375 self._storage = self._make_storage(*args, **kwargs) 376 for index, element in enumerate(self._storage): 377 self._storage[index] = self._track_value( 378 element, name=self._name_element(index)) 379 380 def copy(self): 381 return type(self)(copy.copy(self._storage)) 382 383 def __copy__(self): 384 return self.copy() 385 386 def __deepcopy__(self, memo): 387 return type(self)(copy.deepcopy(self._storage, memo)) 388 389 def _make_storage(self, *args, **kwargs): 390 """Determines the backing storage (overridden in subclasses).""" 391 return list(*args, **kwargs) 392 393 def _name_element(self, index): 394 return "%d" % (index,) 395 396 @property 397 def _values(self): 398 """Collect values for TrackableDataStructure.""" 399 return self 400 401 def append(self, value): 402 """Add a new trackable value.""" 403 value = self._track_value(value, self._name_element(len(self._storage))) 404 self._storage.append(value) 405 406 def extend(self, values): 407 """Add a sequence of trackable values.""" 408 for value in values: 409 self.append(value) 410 411 def __iadd__(self, values): 412 self.extend(values) 413 return self 414 415 def __add__(self, other): 416 return self._storage + getattr(other, "_storage", other) 417 418 def __imul__(self, y): 419 if y <= 0: 420 raise ValueError( 421 "List only supports append, multiplying in place by %d removes " 422 "elements." 
% y) 423 424 n = len(self._storage) 425 for _ in range(y - 1): 426 for i in range(n): 427 self.append(self._storage[i]) 428 429 return self 430 431 def __mul__(self, n): 432 return self._storage * n 433 434 def __rmul__(self, n): 435 return self * n 436 437 def __radd__(self, other): 438 return other + self._storage 439 440 def __getitem__(self, key): 441 return self._storage[key] 442 443 def __getslice__(self, i, j): 444 return self._storage[slice(i, j)] 445 446 def __len__(self): 447 return len(self._storage) 448 449 def __repr__(self): 450 return "List(%s)" % (repr(self._storage),) 451 452 def __sizeof__(self): 453 return super(List, self).__sizeof__() + sys.getsizeof(self._storage) 454 455 456# TODO(tomhennigan) Update to collections.UserList? 457# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop 458# Python 3.4 support (may still be tricky). 459class ListWrapper( 460 List, 461 collections_abc.MutableSequence, 462 # Shadowed, but there for isinstance checks. 463 list): 464 """Wraps the built-in `list` to support restore-on-create for variables. 465 466 Unlike `List`, this sequence type is mutable in the same ways built-in lists 467 are. Instead of throwing an error immediately like `List`, it records 468 problematic mutations (e.g. assigning a new element to a position already 469 occupied, meaning both elements get the same names at different times) and 470 refuses to save. 471 472 On assignment to an attribute of a Model or Trackable object, Python 473 lists are replaced with ListWrapper. Wrapping a list in a 474 `NoDependency` object prevents this. 475 """ 476 477 def __init__(self, wrapped_list): 478 """Construct a new list wrapper. 479 480 Args: 481 wrapped_list: The initial value of the data structure. A shallow copy may 482 be maintained for error checking. 
`wrapped_list` itself should not be 483 modified directly after constructing the `ListWrapper`, and if changes 484 are detected the `ListWrapper` will throw an exception on save. 485 """ 486 # Monotonic flags which indicate this object would not be restored properly, 487 # and therefore should throw an error on save to avoid giving the impression 488 # that restoring it will work. 489 self._non_append_mutation_value = False 490 self._external_modification_value = False 491 super(ListWrapper, self).__init__(wrapped_list) 492 self._last_wrapped_list_snapshot = list(self._storage) 493 494 @property 495 def _non_append_mutation(self): 496 return self._non_append_mutation_value 497 498 @_non_append_mutation.setter 499 def _non_append_mutation(self, value): 500 # Trackable only cares that a mutation occurred at some point; when 501 # attempting to save it checks whether a mutation occurred and the object is 502 # in a "dirty" state but otherwise the specifics of how it got to that state 503 # are ignored. By contrast, the attribute cache needs to signal the mutation 504 # immediately since a caller could query the value of an attribute (And 505 # should not hit the cached value since the mutation may have affected the 506 # result.) 
507 self._attribute_sentinel.invalidate_all() 508 self._non_append_mutation_value = value 509 510 @property 511 def _external_modification(self): 512 return self._external_modification_value 513 514 @_external_modification.setter 515 def _external_modification(self, value): 516 # Invalidate for the same reason as `_non_append_mutation` 517 self._attribute_sentinel.invalidate_all() 518 self._external_modification_value = value 519 520 # pylint: disable=protected-access 521 def __copy__(self): 522 copied = super(ListWrapper, self).__copy__() 523 copied._non_append_mutation = self._non_append_mutation 524 copied._external_modification = self._external_modification 525 return copied 526 527 def __deepcopy__(self, memo): 528 copied = super(ListWrapper, self).__deepcopy__(memo) 529 copied._non_append_mutation = self._non_append_mutation 530 copied._external_modification = self._external_modification 531 return copied 532 # pylint: enable=protected-access 533 534 def __reduce_ex__(self, protocol): 535 return (self.__class__, 536 (self._storage,)) 537 538 def _make_storage(self, wrapped_list): 539 """Use the user's original list for storage.""" 540 return wrapped_list 541 542 def _check_external_modification(self): 543 """Checks for any changes to the wrapped list not through the wrapper.""" 544 if self._external_modification or self._non_append_mutation: 545 return 546 if self._storage != self._last_wrapped_list_snapshot: 547 self._external_modification = True 548 self._last_wrapped_list_snapshot = None 549 550 def _update_snapshot(self): 551 """Acknowledges tracked changes to the wrapped list.""" 552 553 # Mutation tracking for attributes reuses the same infrastructure as 554 # Trackable mutation tracking. 
555 self._attribute_sentinel.invalidate_all() 556 if self._external_modification or self._non_append_mutation: 557 return 558 self._last_wrapped_list_snapshot = list(self._storage) 559 560 @property 561 def _checkpoint_dependencies(self): 562 self._check_external_modification() 563 if self._non_append_mutation: 564 raise ValueError( 565 ("Unable to save the object %s (a list wrapper constructed to track " 566 "trackable TensorFlow objects). A list element was replaced " 567 "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), " 568 "or moved (sort). In order to support restoration on object " 569 "creation, tracking is exclusively for append-only data structures." 570 "\n\nIf you don't need this list checkpointed, wrap it in a " 571 "non-trackable object; it will be subsequently ignored." % (self,))) 572 if self._external_modification: 573 raise ValueError( 574 ("Unable to save the object %s (a list wrapper constructed to track " 575 "trackable TensorFlow objects). The wrapped list was modified " 576 "outside the wrapper (its final value was %s, its value when a " 577 "checkpoint dependency was added was %s), which breaks restoration " 578 "on object creation.\n\nIf you don't need this list checkpointed, " 579 "wrap it in a NoDependency object; it will be " 580 "subsequently ignored." 
% ( 581 self, self._storage, self._last_wrapped_list_snapshot))) 582 return super(ListWrapper, self)._checkpoint_dependencies 583 584 def _has_mutation_or_trackable(self): 585 """Short-circuits a check for trackables if there's already a mutation.""" 586 if self._non_append_mutation: 587 return True 588 return any(isinstance(element, base.Trackable) for element in self._storage) 589 590 def __delitem__(self, key): 591 self._check_external_modification() 592 if self._has_mutation_or_trackable(): 593 self._non_append_mutation = True 594 del self._storage[key] 595 self._update_snapshot() 596 597 def __setitem__(self, key, value): 598 self._check_external_modification() 599 600 if isinstance(key, slice): 601 # Note: this is quite inefficient, but the list API supports a broad range 602 # of slice setters (e.g. truncate, extend, replace) and imitating this 603 # for a range of Python versions is non-trivial. 604 storage_copy = list(self._storage) 605 self._storage[key] = value 606 607 len_before = len(storage_copy) 608 len_now = len(self._storage) 609 for i in range(max(len_before, len_now)): 610 value_now = self._storage[i] if i < len_now else None 611 value_before = storage_copy[i] if i < len_before else None 612 613 if isinstance(value_before, base.Trackable): 614 self._non_append_mutation = True 615 616 if value_now is not None and value_now != value_before: 617 self._storage[i] = self._track_value(self._storage[i], 618 self._name_element(i)) 619 620 else: 621 if isinstance(self._storage[key], base.Trackable): 622 self._non_append_mutation = True 623 self._storage[key] = self._track_value(value, self._name_element(key)) 624 625 self._update_snapshot() 626 627 def append(self, value): 628 """Add a new trackable value.""" 629 self._check_external_modification() 630 super(ListWrapper, self).append(value) 631 self._update_snapshot() 632 633 def extend(self, values): 634 """Add a sequence of trackable values.""" 635 self._check_external_modification() 636 
super(ListWrapper, self).extend(values) 637 self._update_snapshot() 638 639 def __imul__(self, y): 640 if y <= 0: 641 self._check_external_modification() 642 if self._has_mutation_or_trackable(): 643 self._non_append_mutation = True 644 self._storage *= y 645 self._update_snapshot() 646 return self 647 648 # Relies on super() calling append, which updates the snapshot. 649 return super(ListWrapper, self).__imul__(y) 650 651 def __eq__(self, other): 652 return self._storage == getattr(other, "_storage", other) 653 654 def __ne__(self, other): 655 return self._storage != getattr(other, "_storage", other) 656 657 def __lt__(self, other): 658 return self._storage < getattr(other, "_storage", other) 659 660 def __le__(self, other): 661 return self._storage <= getattr(other, "_storage", other) 662 663 def __gt__(self, other): 664 return self._storage > getattr(other, "_storage", other) 665 666 def __ge__(self, other): 667 return self._storage >= getattr(other, "_storage", other) 668 669 def __hash__(self): 670 # List wrappers need to compare like regular lists, and so like regular 671 # lists they don't belong in hash tables. 
672 raise TypeError("unhashable type: 'ListWrapper'") 673 674 def insert(self, index, obj): 675 self._check_external_modification() 676 if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)): 677 self._non_append_mutation = True 678 self._storage.insert(index, obj) 679 self._update_snapshot() 680 681 def sort(self): 682 self._check_external_modification() 683 if self._has_mutation_or_trackable(): 684 self._non_append_mutation = True 685 self._storage.sort() 686 self._update_snapshot() 687 688 def __setslice__(self, i, j, y): 689 self.__setitem__(slice(i, j), y) 690 691 def __delslice__(self, i, j): 692 self._check_external_modification() 693 if self._has_mutation_or_trackable(): 694 self._non_append_mutation = True 695 del self._storage[slice(i, j)] 696 self._update_snapshot() 697 698 def _track_value(self, value, name): 699 """Allows storage of non-trackable objects.""" 700 try: 701 value = super(ListWrapper, self)._track_value(value=value, name=name) 702 except ValueError: 703 # Even if this value isn't trackable, we need to make sure 704 # NoDependency objects get unwrapped. 705 value = sticky_attribute_assignment( 706 trackable=self, value=value, name=name) 707 return value 708 709 def __repr__(self): 710 return "ListWrapper(%s)" % (repr(self._storage),) 711 712 def _list_functions_for_serialization(self, unused_functions): 713 return { 714 str(key): value for key, value in enumerate(self) 715 if _is_function(value) 716 } 717 718 719class Mapping(TrackableDataStructure, collections_abc.Mapping): 720 """An append-only trackable mapping data structure with string keys. 721 722 Maintains checkpoint dependencies on its contents (which must also be 723 trackable), named based on its keys. 724 725 Note that once a key has been added, it may not be deleted or replaced. 726 """ 727 728 def __init__(self, *args, **kwargs): 729 """Construct a new sequence. 
Arguments are passed to `dict()`.""" 730 super(Mapping, self).__init__() 731 self._storage = self._make_storage(*args, **kwargs) 732 self._storage.update( 733 {key: self._track_value( 734 value, name=self._name_element(key)) 735 for key, value in self._storage.items()}) 736 737 def __copy__(self): 738 return type(self)(copy.copy(self._storage)) 739 740 def __deepcopy__(self, memo): 741 return type(self)(copy.deepcopy(self._storage, memo)) 742 743 def _make_storage(self, *args, **kwargs): 744 return dict(*args, **kwargs) 745 746 @property 747 def _values(self): 748 """Collect values for TrackableDataStructure.""" 749 # Sort items deterministically by key 750 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 751 if ordered: 752 return ordered[1] 753 return [] 754 755 def _name_element(self, key): 756 if not isinstance(key, six.string_types): 757 raise TypeError( 758 "Mapping accepts only string keys, but got a key %s." 759 % repr(key)) 760 return str(key) 761 762 def __setitem__(self, key, value): 763 name = self._name_element(key) 764 value = self._track_value(value, name=name) 765 current_value = self._storage.setdefault(key, value) 766 if current_value is not value: 767 raise ValueError( 768 ("Mappings are an append-only data structure. Tried to overwrite the " 769 "key '%s' with value %s, but it already contains %s") 770 % (key, value, current_value)) 771 772 def update(self, *args, **kwargs): 773 for key, value in dict(*args, **kwargs).items(): 774 self[key] = value 775 776 def __getitem__(self, key): 777 return self._storage[key] 778 779 def __len__(self): 780 return len(self._storage) 781 782 def __repr__(self): 783 return "Mapping(%s)" % (repr(self._storage),) 784 785 def __iter__(self): 786 return iter(self._storage) 787 788 789class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): 790 """Wraps built-in dicts to support restore-on-create for variables. 791 792 _DictWrapper is to Mapping as ListWrapper is to List. 
Unlike Mapping, 793 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 794 keys, reassign values). Like ListWrapper, these mutations mean that 795 _DictWrapper will raise an exception on save. 796 """ 797 798 def __init__(self, wrapped_dict=None): 799 if wrapped_dict is None: 800 # Allow zero-argument construction, e.g. from session.run's re-wrapping. 801 wrapped_dict = {} 802 if not isinstance(wrapped_dict, collections_abc.Mapping): 803 # Allow construction from a sequence, e.g. from nest.pack_sequence_as. 804 wrapped_dict = dict(wrapped_dict) 805 wrapt.ObjectProxy.__init__(self, wrapped_dict) 806 TrackableDataStructure.__init__(self) 807 self._self_non_string_key = False 808 self._self_external_modification = False 809 self.__wrapped__.update( 810 {key: self._track_value( 811 value, name=self._name_element(key)) 812 for key, value in self.__wrapped__.items()}) 813 self._update_snapshot() 814 815 def __reduce_ex__(self, protocol): 816 return (self.__class__, 817 (self.__wrapped__,)) 818 819 def __getattribute__(self, name): 820 if (hasattr(type(self), name) 821 and isinstance(getattr(type(self), name), property)): 822 # Bypass ObjectProxy for properties. Whether this workaround is necessary 823 # appears to depend on the Python version but not the wrapt version: 3.4 824 # in particular seems to look up properties on the wrapped object instead 825 # of the wrapper without this logic. 
      return object.__getattribute__(self, name)
    else:
      return super(_DictWrapper, self).__getattribute__(name)

  def copy(self):
    """Returns a shallow copy (`copy.copy` dispatches to `__copy__`)."""
    return copy.copy(self)

  # pylint: disable=protected-access
  def __copy__(self):
    """Shallow-copies the wrapped dict into a new `_DictWrapper`.

    The "unsaveable" flags (external modification / non-string key) are
    carried over so the copy is exactly as saveable as the original.
    """
    copied = _DictWrapper(copy.copy(self.__wrapped__))
    copied._self_external_modification = self._self_external_modification
    copied._self_non_string_key = self._self_non_string_key
    return copied

  def __deepcopy__(self, memo):
    """Deep-copies the wrapped dict into a new `_DictWrapper`.

    `memo` is forwarded to `copy.deepcopy` for the wrapped dict. The new
    wrapper itself is not recorded in `memo`. The unsaveable flags are copied
    as in `__copy__`.
    """
    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
    copied._self_external_modification = self._self_external_modification
    copied._self_non_string_key = self._self_non_string_key
    return copied
  # pylint: enable=protected-access

  @property
  def _values(self):
    """Collect values for TrackableDataStructure.

    Returns:
      A tuple of the dict's values ordered by their keys, or `[]` when the
      dict is empty.
    """
    # Sort items deterministically by key
    # NOTE(review): sorting requires mutually comparable keys; e.g. mixing
    # int and str keys raises TypeError on Python 3.
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      # zip(*items) yields (keys, values); index 1 is the values tuple.
      return ordered[1]
    return []

  @property
  def _checkpoint_dependencies(self):
    """Check that the object is saveable before listing its dependencies.

    Raises:
      ValueError: if a non-string key maps to a trackable value, or if the
        wrapped dict was mutated without going through this wrapper.
    """
    self._check_self_external_modification()
    if self._self_non_string_key:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
          "contains a non-string key which maps to a trackable object or "
          "mutable data structure.\n\nIf you don't need this dictionary "
          "checkpointed, wrap it in a non-trackable "
          "object; it will be subsequently ignored." % (self,))
    if self._self_external_modification:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary was "
          "modified outside the wrapper (its final value was %s, its value "
          "when a checkpoint dependency was added was %s), which breaks "
          "restoration on object creation.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "non-trackable object; it will be subsequently ignored." % (
              self, self, self._self_last_wrapped_dict_snapshot))
    assert not self._dirty  # Any reason for dirtiness should have an exception.
    return super(_DictWrapper, self)._checkpoint_dependencies

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving.

    True when either `_self_external_modification` or `_self_non_string_key`
    is set.
    """
    return (self._self_external_modification
            or self._self_non_string_key)

  def _check_self_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper."""
    if self._dirty:
      # Already unsaveable; no need to compare again.
      return
    # The snapshot is refreshed by `_update_snapshot` after every mutation made
    # through the wrapper, so any difference means the underlying dict was
    # modified directly.
    if self != self._self_last_wrapped_dict_snapshot:
      self._self_external_modification = True
      # Drop the snapshot. `_update_snapshot` returns early once `_dirty`, so
      # it is not retaken.
      self._self_last_wrapped_dict_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict."""
    # NOTE(review): `_attribute_sentinel` is set up outside this chunk;
    # presumably this invalidates cached attribute lookups. Confirm.
    # It runs even when the wrapper is already dirty.
    self._attribute_sentinel.invalidate_all()
    if self._dirty:
      return
    self._self_last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Args:
      value: The value being stored under `name`.
      name: The dict key. Non-string keys are all tracked under the shared
        dependency name "-non_string_key".

    Returns:
      The value to store in the wrapped dict. This is the base-class result
      or, if the base class raises ValueError, the result of
      `sticky_attribute_assignment`.

    Side effects:
      Sets `_self_non_string_key` when a non-string key maps to a value the
      base class accepted and that was not wrapped in `NoDependency`.
    """
    if isinstance(name, six.string_types):
      string_key = True
    else:
      name = "-non_string_key"
      string_key = False
    try:
      no_dependency = isinstance(value, NoDependency)
      value = super(_DictWrapper, self)._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a trackable value. This data structure
        # is not saveable.
        self._self_non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Tells TrackableDataStructure to use keys as names as-is."""
    return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable.

    String keys go through `_track_value`. Non-string keys are passed through
    `wrap_or_unwrap` only, and mark the wrapper unsaveable if the resulting
    value is `base.Trackable` and was not wrapped in `NoDependency`.
    """
    # Detect direct mutations of the wrapped dict before applying this one.
    self._check_self_external_modification()
    self._maybe_initialize_trackable()
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, six.string_types):
      value = self._track_value(value, name=key)
    else:
      value = wrap_or_unwrap(value)
      if not no_dep and isinstance(value, base.Trackable):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # trackable, or because it was wrapped in a NoDependency object).
        self._self_non_string_key = True
    self.__wrapped__[key] = value

    # Record the new contents as the expected state.
    self._update_snapshot()

  def __delitem__(self, key):
    """Deletes `key` from the wrapped dict and refreshes the snapshot."""
    self._check_self_external_modification()
    del self.__wrapped__[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Mutable mapping: unhashable, like a plain dict.
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def update(self, *args, **kwargs):
    """Like `dict.update`, but routes every item through `__setitem__`.

    The arguments are first merged with `dict(*args, **kwargs)`, so duplicate
    keys resolve as they would for `dict`. Each resulting item is then tracked
    individually.
    """
    for key, value in six.iteritems(dict(*args, **kwargs)):
      self[key] = value

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Returns {key: value} for the values accepted by `_is_function`."""
    return {
        key: value for key, value in self.items()
        if _is_function(value)
    }


class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples."""

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple`, tracking its elements.

    Each element is passed through `wrap_or_unwrap`. Elements that were
    `NoDependency` instances are not tracked. Every other element is tracked
    by index ("0", "1", ...) and, for namedtuples, also by field name.

    If a namedtuple cannot be rebuilt from keyword arguments matching its
    `_fields` (TypeError), the original tuple is proxied unchanged, nothing is
    tracked, and `_self_tuple_is_constructable` is set to False so that saving
    raises.

    Args:
      original_wrapped_tuple: A tuple or namedtuple instance.
    """
    # Parallel lists: whether to track each element, and its wrapped form.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure.

    Returns the wrapper itself. Iterating it yields the wrapped tuple's
    elements through the proxy.
    """
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Returns the base-class result or, if the base class raises ValueError, the
    result of `sticky_attribute_assignment`.
    """
    try:
      value = super(_TupleWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    """Returns a new `_TupleWrapper` around a shallow copy of the tuple."""
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    """Returns a new `_TupleWrapper` around a deep copy of the tuple."""
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    """Pickles as `_TupleWrapper(wrapped_tuple)`; `protocol` is unused."""
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`.

    Returns the result of `wrapped_tuple * y`, not a `_TupleWrapper`.
    """
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`.

    Returns the result of `wrapped_tuple + y`, not a `_TupleWrapper`.
    """
    return self.__wrapped__ + y

  @property
  def _checkpoint_dependencies(self):
    """Lists dependencies, refusing to save non-reconstructable namedtuples.

    Raises:
      ValueError: if `__init__` could not rebuild the namedtuple from its
        `_fields`.
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          ("Unable to save because the namedtuple {} is not constructable from "
           "its _fields (i.e. __new__ is overridden). Expected keyword "
           "arguments {}. If you do not need to save this object, consider "
           "wrapping it in a custom object that does not inherit from tuple.")
          .format(self.__wrapped__, self.__wrapped__._fields))
    return super(_TupleWrapper, self)._checkpoint_dependencies

  def __getattribute__(self, name):
    """Resolves attributes, preferring the wrapped tuple's when both exist.

    Order of lookup:
      1. the wrapped tuple (except for `__wrapped__` itself);
      2. properties defined on the wrapper's class, via `object`;
      3. the normal proxy lookup.
    """
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_TupleWrapper, self).__getattribute__(name)


def _is_function(x):
  """True if `x` is a `def_function.Function` or `defun.ConcreteFunction`."""
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))


# Revival for `_DictWrapper`: an empty `_DictWrapper` is created by the
# factory, with `operator.setitem` registered as the setter.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem)])


def _set_list_item(list_object, index_string, value):
  """Sets `list_object[int(index_string)] = value`, padding with None.

  Args:
    list_object: A mutable list-like object (e.g. a `ListWrapper`).
    index_string: The element's name, e.g. "3"; must parse as an int.
    value: The value to place at that index.
  """
  item_index = int(index_string)
  # Extend with None placeholders so item_index is a valid position.
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


# Revival for `ListWrapper`: an empty `ListWrapper` is created by the factory,
# with `_set_list_item` registered as the setter.
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: ListWrapper([]),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_set_list_item)])


def _set_tuple_item(list_object, index_string, value):
  """Like `_set_list_item`, but ignores non-integer names.

  `_TupleWrapper` tracks namedtuple elements both by field name and by index.
  Only the index names are used to rebuild the sequence here.
  """
  try:
    item_index = int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  # NOTE(review): the padding/assignment below duplicates `_set_list_item`.
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


# Revive tuples as lists so we can append any dependencies during loading.
# Tuple wrappers are matched by type and revived as an initially empty
# `ListWrapper`. `_set_tuple_item` is the registered setter; it skips the
# namedtuple field-name dependencies and keeps only the index-based ones.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_tuple_item,
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
        ),
    ],
)