1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities related to layer/model functionality.""" 16 17# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils 18# once __init__ files no longer require all of tf.keras to be imported together. 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import collections 25import functools 26import weakref 27 28from tensorflow.python.util import object_identity 29 30try: 31 # typing module is only used for comment type annotations. 32 import typing # pylint: disable=g-import-not-at-top, unused-import 33except ImportError: 34 pass 35 36 37def is_layer(obj): 38 """Implicit check for Layer-like objects.""" 39 # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). 40 return hasattr(obj, "_is_layer") and not isinstance(obj, type) 41 42 43def has_weights(obj): 44 """Implicit check for Layer-like objects.""" 45 # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). 46 has_weight = (hasattr(type(obj), "trainable_weights") 47 and hasattr(type(obj), "non_trainable_weights")) 48 49 return has_weight and not isinstance(obj, type) 50 51 52def invalidate_recursive_cache(key): 53 """Convenience decorator to invalidate the cache when setting attributes.""" 54 def outer(f): 55 @functools.wraps(f) 56 def wrapped(self, value): 57 sentinel = getattr(self, "_attribute_sentinel") # type: AttributeSentinel 58 sentinel.invalidate(key) 59 return f(self, value) 60 return wrapped 61 return outer 62 63 64class MutationSentinel(object): 65 """Container for tracking whether a property is in a cached state.""" 66 _in_cached_state = False 67 68 def mark_as(self, value): # type: (MutationSentinel, bool) -> bool 69 may_affect_upstream = (value != self._in_cached_state) 70 self._in_cached_state = value 71 return may_affect_upstream 72 73 @property 74 def in_cached_state(self): 75 return self._in_cached_state 76 77 78class AttributeSentinel(object): 79 """Container for managing attribute cache state within a Layer. 80 81 The cache can be invalidated either on an individual basis (for instance when 82 an attribute is mutated) or a layer-wide basis (such as when a new dependency 83 is added). 84 """ 85 86 def __init__(self, always_propagate=False): 87 self._parents = weakref.WeakSet() 88 self.attributes = collections.defaultdict(MutationSentinel) 89 90 # The trackable data structure containers are simple pass throughs. They 91 # don't know or care about particular attributes. As a result, they will 92 # consider themselves to be in a cached state, so it's up to the Layer 93 # which contains them to terminate propagation. 94 self.always_propagate = always_propagate 95 96 def __repr__(self): 97 return "{}\n {}".format( 98 super(AttributeSentinel, self).__repr__(), 99 {k: v.in_cached_state for k, v in self.attributes.items()}) 100 101 def add_parent(self, node): 102 # type: (AttributeSentinel, AttributeSentinel) -> None 103 104 # Properly tracking removal is quite challenging; however since this is only 105 # used to invalidate a cache it's alright to be overly conservative. We need 106 # to invalidate the cache of `node` (since it has implicitly gained a child) 107 # but we don't need to invalidate self since attributes should not depend on 108 # parent Layers. 109 self._parents.add(node) 110 node.invalidate_all() 111 112 def get(self, key): 113 # type: (AttributeSentinel, str) -> bool 114 return self.attributes[key].in_cached_state 115 116 def _set(self, key, value): 117 # type: (AttributeSentinel, str, bool) -> None 118 may_affect_upstream = self.attributes[key].mark_as(value) 119 if may_affect_upstream or self.always_propagate: 120 for node in self._parents: # type: AttributeSentinel 121 node.invalidate(key) 122 123 def mark_cached(self, key): 124 # type: (AttributeSentinel, str) -> None 125 self._set(key, True) 126 127 def invalidate(self, key): 128 # type: (AttributeSentinel, str) -> None 129 self._set(key, False) 130 131 def invalidate_all(self): 132 # Parents may have different keys than their children, so we locally 133 # invalidate but use the `invalidate_all` method of parents. 134 for key in self.attributes.keys(): 135 self.attributes[key].mark_as(False) 136 137 for node in self._parents: 138 node.invalidate_all() 139 140 141def filter_empty_layer_containers(layer_list): 142 """Filter out empty Layer-like containers and uniquify.""" 143 # TODO(b/130381733): Make this an attribute in base_layer.Layer. 144 existing = object_identity.ObjectIdentitySet() 145 to_visit = layer_list[::-1] 146 while to_visit: 147 obj = to_visit.pop() 148 if obj in existing: 149 continue 150 existing.add(obj) 151 if is_layer(obj): 152 yield obj 153 else: 154 sub_layers = getattr(obj, "layers", None) or [] 155 156 # Trackable data structures will not show up in ".layers" lists, but 157 # the layers they contain will. 158 to_visit.extend(sub_layers[::-1]) 159