1"""Trackable data structures.""" 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import operator 23import sys 24 25import six 26try: 27 import wrapt 28except ImportError: 29 # Fall back to the build-time dependency if the system package is not available. 30 from .....third_party import wrapt 31 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function as defun 34from tensorflow.python.ops import variables 35from tensorflow.python.saved_model import revived_types 36from tensorflow.python.training.tracking import base 37from tensorflow.python.training.tracking import layer_utils 38from tensorflow.python.util.compat import collections_abc 39 40 41class NoDependency(object): 42 """Allows attribute assignment to `Trackable` objects with no dependency. 43 44 Example usage: 45 ```python 46 obj = Trackable() 47 obj.has_dependency = tf.Variable(0., name="dep") 48 obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) 49 assert obj.no_dependency.name == "nodep:0" 50 ``` 51 52 `obj` in this example has a dependency on the variable "dep", and both 53 attributes contain un-wrapped `Variable` objects. 
54 55 `NoDependency` also works with `tf.keras.Model`, but only for checkpoint 56 dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) 57 `Layer` to the attribute without a checkpoint dependency, but the `Model` will 58 still track the `Layer` (so it will appear in `Model.layers`, and its 59 variables will appear in `Model.variables`). 60 """ 61 62 def __init__(self, value): 63 self.value = value 64 65 66def _should_wrap_tuple(t): 67 """Determine if a tuple has any trackable components.""" 68 for element in t: 69 if isinstance(element, NoDependency): 70 return True # We should remove the NoDependency object from the tuple. 71 if isinstance(element, base.Trackable): 72 return True 73 if _wrap_or_unwrap(element) is not element: 74 return True 75 # There are no trackable elements or data structures. Tuples are immutable, so 76 # mutation isn't a concern. Don't wrap. 77 return False 78 79 80def _wrap_or_unwrap(value): 81 """Wraps basic data structures, unwraps NoDependency objects.""" 82 # pylint: disable=unidiomatic-typecheck 83 # Exact type checking to avoid mucking up custom logic in list/dict 84 # subclasses, e.g. collections.Counter. 85 if isinstance(value, NoDependency): 86 return value.value 87 if isinstance(value, base.Trackable): 88 return value # Skip conversion for already trackable objects. 89 elif type(value) == dict: 90 return _DictWrapper(value) 91 elif type(value) == collections.OrderedDict: 92 return _DictWrapper(value) 93 elif type(value) == list: 94 return ListWrapper(value) 95 elif isinstance(value, tuple) and _should_wrap_tuple(value): 96 # There are trackable elements or data structures. Wrap the tuple. 97 return _TupleWrapper(value) 98 else: 99 return value 100 # pylint: enable=unidiomatic-typecheck 101 102 103def sticky_attribute_assignment(trackable, name, value): 104 """Adds dependencies, generally called from __setattr__. 105 106 This behavior is shared between Trackable and Model. 
107 108 Respects NoDependency indicators, but otherwise makes trackable objects 109 out of common data structures and tracks objects by their attribute names. 110 111 Args: 112 trackable: The object to add dependencies to (generally the one having 113 an attribute assigned). 114 name: The attribute name being assigned. 115 value: The value being assigned. Not necessarily a trackable object. 116 117 Returns: 118 The value which should be stored in the attribute (unwrapped from a 119 NoDependency object if necessary). 120 """ 121 if isinstance(value, NoDependency): 122 add_dependency = False 123 else: 124 add_dependency = True 125 value = _wrap_or_unwrap(value) 126 if not add_dependency: 127 return value 128 if isinstance(value, base.Trackable): 129 trackable._track_trackable( # pylint: disable=protected-access 130 value, name=name, 131 # Allow the user to switch the Trackable which is tracked by this 132 # name, since assigning a new variable to an attribute has 133 # historically been fine (e.g. Adam did this). 134 overwrite=True) 135 return value 136 137 138class _UntrackableError(ValueError): 139 140 def __init__(self, value): # pylint: disable=super-init-not-called 141 self._value = value 142 143 def __str__(self): 144 return (("Only trackable objects (such as Layers or Optimizers) may be " 145 "stored in a List object. Got %s, which does not inherit from " 146 "Trackable.") % (self._value,)) 147 148 149class TrackableDataStructure(base.Trackable): 150 """Base class for data structures which contain trackable objects.""" 151 152 def __init__(self): 153 # Attributes prefixed with "_self_" for compatibility with 154 # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as 155 # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of 156 # ways. 
157 self._self_trainable = True 158 self._self_extra_variables = [] 159 self._self_attribute_sentinel = layer_utils.AttributeSentinel(True) 160 161 @property 162 def _attribute_sentinel(self): 163 return self._self_attribute_sentinel 164 165 @property 166 def trainable(self): 167 return self._self_trainable 168 169 @trainable.setter 170 def trainable(self, value): 171 self._self_trainable = value 172 173 def _track_value(self, value, name): 174 """Add a dependency on `value`.""" 175 value = sticky_attribute_assignment( 176 trackable=self, value=value, name=name) 177 if isinstance(value, variables.Variable): 178 self._self_extra_variables.append(value) 179 if not isinstance(value, base.Trackable): 180 raise _UntrackableError(value) 181 if hasattr(value, "_use_resource_variables"): 182 # In subclassed models, legacy layers (tf.layers) must always use 183 # resource variables. 184 value._use_resource_variables = True # pylint: disable=protected-access 185 value_attribute_sentinel = getattr(value, "_attribute_sentinel", None) 186 if value_attribute_sentinel: 187 value_attribute_sentinel.add_parent(self._attribute_sentinel) 188 return value 189 190 @property 191 def _values(self): 192 """An iterable/sequence which may contain trackable objects.""" 193 raise NotImplementedError("Abstract method") 194 195 @property 196 def _layers(self): 197 """All Layers and Layer containers, including empty containers.""" 198 # Filter objects on demand so that wrapper objects use values from the thing 199 # they're wrapping if out of sync. 
200 collected = [] 201 for obj in self._values: 202 if (isinstance(obj, TrackableDataStructure) 203 or layer_utils.is_layer(obj) 204 or layer_utils.has_weights(obj)): 205 collected.append(obj) 206 return collected 207 208 @property 209 def layers(self): 210 return list(layer_utils.filter_empty_layer_containers(self._layers)) 211 212 @property 213 def trainable_weights(self): 214 return layer_utils.gather_trainable_weights( 215 trainable=self.trainable, 216 sub_layers=self._layers, 217 extra_variables=self._self_extra_variables) 218 219 @property 220 def non_trainable_weights(self): 221 return layer_utils.gather_non_trainable_weights( 222 trainable=self.trainable, 223 sub_layers=self._layers, 224 extra_variables=self._self_extra_variables) 225 226 @property 227 def weights(self): 228 return self.trainable_weights + self.non_trainable_weights 229 230 @property 231 def trainable_variables(self): 232 return self.trainable_weights 233 234 @property 235 def non_trainable_variables(self): 236 return self.non_trainable_weights 237 238 @property 239 def variables(self): 240 return self.weights 241 242 @property 243 def updates(self): 244 """Aggregate updates from any `Layer` instances.""" 245 # Updates and conditional losses are forwarded as-is rather than being 246 # filtered based on inputs, since this is just a container and won't ever 247 # have any inputs. 248 aggregated = [] 249 for layer in self.layers: 250 if hasattr(layer, "updates"): 251 aggregated += layer.updates 252 return aggregated 253 254 @property 255 def losses(self): 256 """Aggregate losses from any `Layer` instances.""" 257 aggregated = [] 258 for layer in self.layers: 259 if hasattr(layer, "losses"): 260 aggregated += layer.losses 261 return aggregated 262 263 def __hash__(self): 264 # Support object-identity hashing, so these structures can be used as keys 265 # in sets/dicts. 
266 return id(self) 267 268 def __eq__(self, other): 269 # Similar to Tensors, trackable data structures use object-identity 270 # equality to support set/dict membership. 271 return self is other 272 273 274class List(TrackableDataStructure, collections_abc.Sequence): 275 """An append-only sequence type which is trackable. 276 277 Maintains checkpoint dependencies on its contents (which must also be 278 trackable), and forwards any `Layer` metadata such as updates and losses. 279 280 Note that `List` is purely a container. It lets a `tf.keras.Model` or 281 other trackable object know about its contents, but does not call any 282 `Layer` instances which are added to it. To indicate a sequence of `Layer` 283 instances which should be called sequentially, use `tf.keras.Sequential`. 284 285 Example usage: 286 ```python 287 class HasList(tf.keras.Model): 288 289 def __init__(self): 290 super(HasList, self).__init__() 291 self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)]) 292 self.layer_list.append(layers.Dense(4)) 293 294 def call(self, x): 295 aggregation = 0. 296 for l in self.layer_list: 297 x = l(x) 298 aggregation += tf.reduce_sum(x) 299 return aggregation 300 ``` 301 302 This kind of wrapping is necessary because `Trackable` objects do not 303 (yet) deeply inspect regular Python data structures, so for example assigning 304 a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a 305 checkpoint dependency and does not add the `Layer` instance's weights to its 306 parent `Model`. 307 """ 308 309 def __init__(self, *args, **kwargs): 310 """Construct a new sequence. 
Arguments are passed to `list()`.""" 311 super(List, self).__init__() 312 self._storage = self._make_storage(*args, **kwargs) 313 for index, element in enumerate(self._storage): 314 self._storage[index] = self._track_value( 315 element, name=self._name_element(index)) 316 317 def copy(self): 318 return type(self)(copy.copy(self._storage)) 319 320 def __copy__(self): 321 return self.copy() 322 323 def __deepcopy__(self, memo): 324 return type(self)(copy.deepcopy(self._storage, memo)) 325 326 def _make_storage(self, *args, **kwargs): 327 """Determines the backing storage (overridden in subclasses).""" 328 return list(*args, **kwargs) 329 330 def _name_element(self, index): 331 return "%d" % (index,) 332 333 @property 334 def _values(self): 335 """Collect values for TrackableDataStructure.""" 336 return self 337 338 def append(self, value): 339 """Add a new trackable value.""" 340 value = self._track_value(value, self._name_element(len(self._storage))) 341 self._storage.append(value) 342 343 def extend(self, values): 344 """Add a sequence of trackable values.""" 345 for value in values: 346 self.append(value) 347 348 def __iadd__(self, values): 349 self.extend(values) 350 return self 351 352 def __add__(self, other): 353 return self._storage + getattr(other, "_storage", other) 354 355 def __imul__(self, y): 356 if y <= 0: 357 raise ValueError( 358 "List only supports append, multiplying in place by %d removes " 359 "elements." 
% y) 360 361 n = len(self._storage) 362 for _ in range(y - 1): 363 for i in range(n): 364 self.append(self._storage[i]) 365 366 return self 367 368 def __mul__(self, n): 369 return self._storage * n 370 371 def __rmul__(self, n): 372 return self * n 373 374 def __radd__(self, other): 375 return other + self._storage 376 377 def __getitem__(self, key): 378 return self._storage[key] 379 380 def __getslice__(self, i, j): 381 return self._storage[slice(i, j)] 382 383 def __len__(self): 384 return len(self._storage) 385 386 def __repr__(self): 387 return "List(%s)" % (repr(self._storage),) 388 389 def __sizeof__(self): 390 return super(List, self).__sizeof__() + sys.getsizeof(self._storage) 391 392 393# TODO(tomhennigan) Update to collections.UserList? 394# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop 395# Python 3.4 support (may still be tricky). 396class ListWrapper( 397 List, 398 collections_abc.MutableSequence, 399 # Shadowed, but there for isinstance checks. 400 list): 401 """Wraps the built-in `list` to support restore-on-create for variables. 402 403 Unlike `List`, this sequence type is mutable in the same ways built-in lists 404 are. Instead of throwing an error immediately like `List`, it records 405 problematic mutations (e.g. assigning a new element to a position already 406 occupied, meaning both elements get the same names at different times) and 407 refuses to save. 408 409 On assignment to an attribute of a Model or Trackable object, Python 410 lists are replaced with ListWrapper. Wrapping a list in a 411 `tf.contrib.checkpoint.NoDependency` object prevents this. 412 """ 413 414 def __init__(self, wrapped_list): 415 """Construct a new list wrapper. 416 417 Args: 418 wrapped_list: The initial value of the data structure. A shallow copy may 419 be maintained for error checking. 
`wrapped_list` itself should not be 420 modified directly after constructing the `ListWrapper`, and if changes 421 are detected the `ListWrapper` will throw an exception on save. 422 """ 423 # Monotonic flags which indicate this object would not be restored properly, 424 # and therefore should throw an error on save to avoid giving the impression 425 # that restoring it will work. 426 self._non_append_mutation_value = False 427 self._external_modification_value = False 428 super(ListWrapper, self).__init__(wrapped_list) 429 self._last_wrapped_list_snapshot = list(self._storage) 430 431 @property 432 def _non_append_mutation(self): 433 return self._non_append_mutation_value 434 435 @_non_append_mutation.setter 436 def _non_append_mutation(self, value): 437 # Trackable only cares that a mutation occurred at some point; when 438 # attempting to save it checks whether a mutation occurred and the object is 439 # in a "dirty" state but otherwise the specifics of how it got to that state 440 # are ignored. By contrast, the attribute cache needs to signal the mutation 441 # immediately since a caller could query the value of an attribute (And 442 # should not hit the cached value since the mutation may have affected the 443 # result.) 
444 self._attribute_sentinel.invalidate_all() 445 self._non_append_mutation_value = value 446 447 @property 448 def _external_modification(self): 449 return self._external_modification_value 450 451 @_external_modification.setter 452 def _external_modification(self, value): 453 # Invalidate for the same reason as `_non_append_mutation` 454 self._attribute_sentinel.invalidate_all() 455 self._external_modification_value = value 456 457 # pylint: disable=protected-access 458 def __copy__(self): 459 copied = super(ListWrapper, self).__copy__() 460 copied._non_append_mutation = self._non_append_mutation 461 copied._external_modification = self._external_modification 462 return copied 463 464 def __deepcopy__(self, memo): 465 copied = super(ListWrapper, self).__deepcopy__(memo) 466 copied._non_append_mutation = self._non_append_mutation 467 copied._external_modification = self._external_modification 468 return copied 469 # pylint: enable=protected-access 470 471 def __reduce_ex__(self, protocol): 472 return (self.__class__, 473 (self._storage,)) 474 475 def _make_storage(self, wrapped_list): 476 """Use the user's original list for storage.""" 477 return wrapped_list 478 479 def _check_external_modification(self): 480 """Checks for any changes to the wrapped list not through the wrapper.""" 481 if self._external_modification or self._non_append_mutation: 482 return 483 if self._storage != self._last_wrapped_list_snapshot: 484 self._external_modification = True 485 self._last_wrapped_list_snapshot = None 486 487 def _update_snapshot(self): 488 """Acknowledges tracked changes to the wrapped list.""" 489 490 # Mutation tracking for attributes reuses the same infrastructure as 491 # Trackable mutation tracking. 
492 self._attribute_sentinel.invalidate_all() 493 if self._external_modification or self._non_append_mutation: 494 return 495 self._last_wrapped_list_snapshot = list(self._storage) 496 497 @property 498 def _checkpoint_dependencies(self): 499 self._check_external_modification() 500 if self._non_append_mutation: 501 raise ValueError( 502 ("Unable to save the object %s (a list wrapper constructed to track " 503 "trackable TensorFlow objects). A list element was replaced " 504 "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), " 505 "or moved (sort). In order to support restoration on object " 506 "creation, tracking is exclusively for append-only data structures." 507 "\n\nIf you don't need this list checkpointed, wrap it in a " 508 "tf.contrib.checkpoint.NoDependency object; it will be " 509 "automatically un-wrapped and subsequently ignored." % (self,))) 510 if self._external_modification: 511 raise ValueError( 512 ("Unable to save the object %s (a list wrapper constructed to track " 513 "trackable TensorFlow objects). The wrapped list was modified " 514 "outside the wrapper (its final value was %s, its value when a " 515 "checkpoint dependency was added was %s), which breaks restoration " 516 "on object creation.\n\nIf you don't need this list checkpointed, " 517 "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be " 518 "automatically un-wrapped and subsequently ignored." % ( 519 self, self._storage, self._last_wrapped_list_snapshot))) 520 return super(ListWrapper, self)._checkpoint_dependencies 521 522 def __delitem__(self, key): 523 self._non_append_mutation = True 524 del self._storage[key] 525 526 def __setitem__(self, key, value): 527 self._check_external_modification() 528 529 if isinstance(key, slice): 530 # Note: this is quite inefficient, but the list API supports a broad range 531 # of slice setters (e.g. truncate, extend, replace) and imitating this 532 # for a range of Python versions is non-trivial. 
533 storage_copy = list(self._storage) 534 self._storage[key] = value 535 536 len_before = len(storage_copy) 537 len_now = len(self._storage) 538 for i in range(max(len_before, len_now)): 539 value_now = self._storage[i] if i < len_now else None 540 value_before = storage_copy[i] if i < len_before else None 541 542 if isinstance(value_before, base.Trackable): 543 self._non_append_mutation = True 544 545 if value_now is not None and value_now != value_before: 546 self._storage[i] = self._track_value(self._storage[i], 547 self._name_element(i)) 548 549 else: 550 if isinstance(self._storage[key], base.Trackable): 551 self._non_append_mutation = True 552 self._storage[key] = self._track_value(value, self._name_element(key)) 553 554 self._update_snapshot() 555 556 def append(self, value): 557 """Add a new trackable value.""" 558 self._check_external_modification() 559 super(ListWrapper, self).append(value) 560 self._update_snapshot() 561 562 def extend(self, values): 563 """Add a sequence of trackable values.""" 564 self._check_external_modification() 565 super(ListWrapper, self).extend(values) 566 self._update_snapshot() 567 568 def __imul__(self, y): 569 if y <= 0: 570 self._self_non_append_mutation = True 571 self._storage *= y 572 return self 573 574 # Relies on super() calling append, which updates the snapshot. 
575 return super(ListWrapper, self).__imul__(y) 576 577 def __eq__(self, other): 578 return self._storage == getattr(other, "_storage", other) 579 580 def __ne__(self, other): 581 return self._storage != getattr(other, "_storage", other) 582 583 def __lt__(self, other): 584 return self._storage < getattr(other, "_storage", other) 585 586 def __le__(self, other): 587 return self._storage <= getattr(other, "_storage", other) 588 589 def __gt__(self, other): 590 return self._storage > getattr(other, "_storage", other) 591 592 def __ge__(self, other): 593 return self._storage >= getattr(other, "_storage", other) 594 595 def __hash__(self): 596 # List wrappers need to compare like regular lists, and so like regular 597 # lists they don't belong in hash tables. 598 raise TypeError("unhashable type: 'ListWrapper'") 599 600 def insert(self, index, obj): 601 self._non_append_mutation = True 602 self._storage.insert(index, obj) 603 604 def sort(self): 605 self._non_append_mutation = True 606 self._storage.sort() 607 608 def __setslice__(self, i, j, y): 609 self.__setitem__(slice(i, j), y) 610 611 def __delslice__(self, i, j): 612 self._non_append_mutation = True 613 del self._storage[slice(i, j)] 614 615 def _track_value(self, value, name): 616 """Allows storage of non-trackable objects.""" 617 try: 618 value = super(ListWrapper, self)._track_value(value=value, name=name) 619 except ValueError: 620 # Even if this value isn't trackable, we need to make sure 621 # NoDependency objects get unwrapped. 
622 value = sticky_attribute_assignment( 623 trackable=self, value=value, name=name) 624 return value 625 626 def __repr__(self): 627 return "ListWrapper(%s)" % (repr(self._storage),) 628 629 def _list_functions_for_serialization(self, unused_functions): 630 return { 631 str(key): value for key, value in enumerate(self) 632 if _is_function(value) 633 } 634 635 636class Mapping(TrackableDataStructure, collections_abc.Mapping): 637 """An append-only trackable mapping data structure with string keys. 638 639 Maintains checkpoint dependencies on its contents (which must also be 640 trackable), named based on its keys. 641 642 Note that once a key has been added, it may not be deleted or replaced. If 643 names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`. 644 """ 645 646 def __init__(self, *args, **kwargs): 647 """Construct a new sequence. Arguments are passed to `dict()`.""" 648 super(Mapping, self).__init__() 649 self._storage = self._make_storage(*args, **kwargs) 650 self._storage.update( 651 {key: self._track_value( 652 value, name=self._name_element(key)) 653 for key, value in self._storage.items()}) 654 655 def __copy__(self): 656 return type(self)(copy.copy(self._storage)) 657 658 def __deepcopy__(self, memo): 659 return type(self)(copy.deepcopy(self._storage, memo)) 660 661 def _make_storage(self, *args, **kwargs): 662 return dict(*args, **kwargs) 663 664 @property 665 def _values(self): 666 """Collect values for TrackableDataStructure.""" 667 # Sort items deterministically by key 668 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 669 if ordered: 670 return ordered[1] 671 return [] 672 673 def _name_element(self, key): 674 if not isinstance(key, six.string_types): 675 raise TypeError( 676 "Mapping accepts only string keys, but got a key %s." 
677 % repr(key)) 678 return str(key) 679 680 def __setitem__(self, key, value): 681 name = self._name_element(key) 682 value = self._track_value(value, name=name) 683 current_value = self._storage.setdefault(key, value) 684 if current_value is not value: 685 raise ValueError( 686 ("Mappings are an append-only data structure. Tried to overwrite the " 687 "key '%s' with value %s, but it already contains %s") 688 % (key, value, current_value)) 689 690 def update(self, *args, **kwargs): 691 for key, value in dict(*args, **kwargs).items(): 692 self[key] = value 693 694 def __getitem__(self, key): 695 return self._storage[key] 696 697 def __len__(self): 698 return len(self._storage) 699 700 def __repr__(self): 701 return "Mapping(%s)" % (repr(self._storage),) 702 703 def __iter__(self): 704 return iter(self._storage) 705 706 707class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): 708 """Wraps built-in dicts to support restore-on-create for variables. 709 710 _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping, 711 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 712 keys, reassign values). Like ListWrapper, these mutations mean that 713 _DictWrapper will raise an exception on save. 714 """ 715 716 def __init__(self, wrapped_dict=None): 717 if wrapped_dict is None: 718 # Allow zero-argument construction, e.g. from session.run's re-wrapping. 719 wrapped_dict = {} 720 if not isinstance(wrapped_dict, collections.Mapping): 721 # Allow construction from a sequence, e.g. from nest.pack_sequence_as. 
722 wrapped_dict = dict(wrapped_dict) 723 wrapt.ObjectProxy.__init__(self, wrapped_dict) 724 TrackableDataStructure.__init__(self) 725 self._self_non_string_key = False 726 self._self_external_modification = False 727 self.__wrapped__.update( 728 {key: self._track_value( 729 value, name=self._name_element(key)) 730 for key, value in self.__wrapped__.items()}) 731 self._update_snapshot() 732 733 def __reduce_ex__(self, protocol): 734 return (self.__class__, 735 (self.__wrapped__,)) 736 737 def __getattribute__(self, name): 738 if (hasattr(type(self), name) 739 and isinstance(getattr(type(self), name), property)): 740 # Bypass ObjectProxy for properties. Whether this workaround is necessary 741 # appears to depend on the Python version but not the wrapt version: 3.4 742 # in particular seems to look up properties on the wrapped object instead 743 # of the wrapper without this logic. 744 return object.__getattribute__(self, name) 745 else: 746 return super(_DictWrapper, self).__getattribute__(name) 747 748 def copy(self): 749 return copy.copy(self) 750 751 # pylint: disable=protected-access 752 def __copy__(self): 753 copied = _DictWrapper(copy.copy(self.__wrapped__)) 754 copied._self_external_modification = self._self_external_modification 755 copied._self_non_string_key = self._self_non_string_key 756 return copied 757 758 def __deepcopy__(self, memo): 759 copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) 760 copied._self_external_modification = self._self_external_modification 761 copied._self_non_string_key = self._self_non_string_key 762 return copied 763 # pylint: enable=protected-access 764 765 @property 766 def _values(self): 767 """Collect values for TrackableDataStructure.""" 768 # Sort items deterministically by key 769 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 770 if ordered: 771 return ordered[1] 772 return [] 773 774 @property 775 def _checkpoint_dependencies(self): 776 """Check that the object is saveable before listing 
its dependencies.""" 777 self._check_self_external_modification() 778 if self._self_non_string_key: 779 raise ValueError( 780 "Unable to save the object %s (a dictionary wrapper constructed " 781 "automatically on attribute assignment). The wrapped dictionary " 782 "contains a non-string key which maps to a trackable object or " 783 "mutable data structure.\n\nIf you don't need this dictionary " 784 "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " 785 "object; it will be automatically un-wrapped and subsequently " 786 "ignored." % (self,)) 787 if self._self_external_modification: 788 raise ValueError( 789 "Unable to save the object %s (a dictionary wrapper constructed " 790 "automatically on attribute assignment). The wrapped dictionary was " 791 "modified outside the wrapper (its final value was %s, its value " 792 "when a checkpoint dependency was added was %s), which breaks " 793 "restoration on object creation.\n\nIf you don't need this " 794 "dictionary checkpointed, wrap it in a " 795 "tf.contrib.checkpoint.NoDependency object; it will be automatically " 796 "un-wrapped and subsequently ignored." % ( 797 self, self, self._self_last_wrapped_dict_snapshot)) 798 assert not self._dirty # Any reason for dirtiness should have an exception. 
799 return super(_DictWrapper, self)._checkpoint_dependencies 800 801 @property 802 def _dirty(self): 803 """Check if there has already been a mutation which prevents saving.""" 804 return (self._self_external_modification 805 or self._self_non_string_key) 806 807 def _check_self_external_modification(self): 808 """Checks for any changes to the wrapped dict not through the wrapper.""" 809 if self._dirty: 810 return 811 if self != self._self_last_wrapped_dict_snapshot: 812 self._self_external_modification = True 813 self._self_last_wrapped_dict_snapshot = None 814 815 def _update_snapshot(self): 816 """Acknowledges tracked changes to the wrapped dict.""" 817 self._attribute_sentinel.invalidate_all() 818 if self._dirty: 819 return 820 self._self_last_wrapped_dict_snapshot = dict(self) 821 822 def _track_value(self, value, name): 823 """Allows storage of non-trackable objects.""" 824 if isinstance(name, six.string_types): 825 string_key = True 826 else: 827 name = "-non_string_key" 828 string_key = False 829 try: 830 no_dependency = isinstance(value, NoDependency) 831 value = super(_DictWrapper, self)._track_value(value=value, name=name) 832 if not (string_key or no_dependency): 833 # A non-string key maps to a trackable value. This data structure 834 # is not saveable. 835 self._self_non_string_key = True 836 return value 837 except ValueError: 838 # Even if this value isn't trackable, we need to make sure 839 # NoDependency objects get unwrapped. 
# NOTE(review): presumably the `except ValueError` fallback of
# _DictWrapper._track_value (it mirrors _TupleWrapper._track_value below):
# values TrackableDataStructure refuses still pass through
# sticky_attribute_assignment so NoDependency wrappers get unwrapped. Confirm.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Tells TrackableDataStructure to use keys as names as-is.

    Args:
      key: The dictionary key.

    Returns:
      `key`, unchanged.
    """
    return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable.

    String keys are tracked as dependencies named after the key. Non-string
    keys are allowed, but if the stored value is trackable (and was not wrapped
    in `NoDependency`) the `_self_non_string_key` flag is set.

    Args:
      key: Dictionary key of any hashable type.
      value: Value to store; may be wrapped in `NoDependency`.
    """
    # NOTE(review): by name, this presumably detects edits made to the wrapped
    # dict behind the wrapper's back -- confirm in its definition.
    self._check_self_external_modification()
    self._maybe_initialize_trackable()
    # Remember NoDependency before `value` is unwrapped below.
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, six.string_types):
      value = self._track_value(value, name=key)
    else:
      # No dependency can be named after a non-string key; just wrap basic
      # containers / unwrap NoDependency.
      value = _wrap_or_unwrap(value)
      if not no_dep and isinstance(value, base.Trackable):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # trackable, or because it was wrapped in a NoDependency object).
        self._self_non_string_key = True
    self.__wrapped__[key] = value

    self._update_snapshot()

  def __delitem__(self, key):
    """Delete `key` from the wrapped dict and refresh the snapshot."""
    # Unlike __setitem__, this does not call _maybe_initialize_trackable().
    self._check_self_external_modification()
    del self.__wrapped__[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Like a builtin dict, the (mutable) wrapper is unhashable.
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def update(self, *args, **kwargs):
    """Update from a mapping/iterable of pairs and/or keyword arguments.

    The arguments are normalized with `dict(*args, **kwargs)`, then each item
    is assigned via `self[key] = value` so it goes through `__setitem__`'s
    tracking logic.
    """
    for key, value in six.iteritems(dict(*args, **kwargs)):
      self[key] = value

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Returns the items whose values are functions (see `_is_function`).

    Args:
      unused_serialization_cache: Ignored.

    Returns:
      A dict mapping each such key to its `def_function.Function` or
      `defun.ConcreteFunction` value.
    """
    return {
        key: value for key, value in self.items()
        if _is_function(value)
    }


class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Proxies (via `wrapt.ObjectProxy`) a tuple of the input's type whose elements
  have been passed through `_wrap_or_unwrap` (NoDependency wrappers removed,
  basic data structures wrapped). Each element not originally wrapped in
  `NoDependency` is tracked under its index ("0", "1", ...); namedtuple
  elements are additionally tracked under their field names.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple`.

    Args:
      original_wrapped_tuple: A tuple or namedtuple (namedtuples are detected
        by the presence of a `_fields` attribute).
    """
    # Parallel lists: whether to depend on element i, and element i itself
    # after NoDependency unwrapping / container wrapping.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(_wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    # `fields` is only bound when is_namedtuple is True, and only used then.
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    # (wrapt keeps `_self_`-prefixed attributes on the proxy itself.)
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # Fall back to proxying the ORIGINAL tuple (elements not unwrapped)
        # with no dependencies tracked; _checkpoint_dependencies will raise.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # The proxy itself is the sequence of (already substituted) values.
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Args:
      value: Element to track.
      name: Dependency name (an index string or a namedtuple field name).

    Returns:
      The tracked value, or the result of `sticky_attribute_assignment` when
      the parent implementation raises `ValueError`.
    """
    try:
      value = super(_TupleWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation. Fails exactly when hashing the wrapped tuple
    # fails (e.g. it holds unhashable elements).
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    # Re-runs __init__ on the copy, so dependencies are tracked afresh.
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    # Pickle as "call the class on the wrapped tuple"; `protocol` is unused.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  # NOTE(review): both return the raw result of the wrapped tuple's operator,
  # not a new _TupleWrapper, so the result carries no dependency tracking.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    return self.__wrapped__ + y

  @property
  def _checkpoint_dependencies(self):
    """Parent dependencies; raises ValueError if the namedtuple can't be rebuilt."""
    if not self._self_tuple_is_constructable:
      raise ValueError(
          ("Unable to save because the namedtuple {} is not constructable from "
           "its _fields (i.e. __new__ is overridden). Expected keyword "
           "arguments {}. If you do not need to save this object, consider "
           "wrapping it in a custom object that does not inherit from tuple.")
          .format(self.__wrapped__, self.__wrapped__._fields))
    return super(_TupleWrapper, self)._checkpoint_dependencies

  def __getattribute__(self, name):
    # Properties defined on this class (e.g. _values,
    # _checkpoint_dependencies) are resolved on the wrapper itself.
    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_TupleWrapper, self).__getattribute__(name)


def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or `ConcreteFunction`."""
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))


# On revival an empty _DictWrapper is created and each child is attached with
# operator.setitem, i.e. `dict_wrapper[name] = child` via __setitem__.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem)])


def _set_list_item(list_object, index_string, value):
  """Sets `list_object[int(index_string)] = value`, padding with None.

  Args:
    list_object: The (revived) list to modify in place.
    index_string: Dependency name holding a decimal index, e.g. "0".
      `int()` raises ValueError for non-integer names.
    value: The object to store.
  """
  item_index = int(index_string)
  if len(list_object) <= item_index:
    # Grow the list so that `item_index` is a valid position.
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: ListWrapper([]),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_set_list_item)])


def _set_tuple_item(list_object, index_string, value):
  """Like `_set_list_item`, but silently skips non-integer dependency names.

  _TupleWrapper tracks namedtuple elements under both field names and
  indices; only the index-named dependencies are used to rebuild the list.
  """
  # NOTE(review): the padding/assignment below duplicates _set_list_item.
  try:
    item_index = int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


# Revive tuples as lists so we can append any dependencies during loading.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: ListWrapper([]),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_set_tuple_item)])