# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cache to manage concrete functions and their signatures."""

import collections
from typing import Optional, Hashable, Sequence, Any, NamedTuple, Dict

from tensorflow.core.function import trace_type
from tensorflow.core.function.polymorphism import type_dispatch
from tensorflow.python.types import trace

# TODO(b/182990542): Enable and remove flag when stable.
DELETE_WITH_WEAKREF = False


class FunctionContext(NamedTuple):
  """Contains information regarding tf.function execution context."""
  context: Any


class CaptureSnapshot(trace.TraceType):
  """Store tf.function captures to accommodate its specific tracing logic.

  Captures are stored in mapping format, but their tracing logic differs from
  that of a Python dict. When comparing the types of two ordinary Python dicts
  in function arguments, their keys are required to be the same. When
  comparing types for captures, the keys can differ. This is because
  tf.function maintains a full list of captures and only a subset is active
  for each ConcreteFunction. Before dispatch, which captures are active is
  unknown, so all captures are evaluated for comparison. See also the
  `is_subtype_of` method.

  Attributes:
    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
  """

  def __init__(self, mapping: Dict[Hashable, trace.TraceType]):
    self.mapping = mapping

  def _contain_all_keys_of(self, other: "CaptureSnapshot") -> bool:
    """Returns True if every key of `other.mapping` is in `self.mapping`."""
    return all(key in self.mapping for key in other.mapping)

  def is_subtype_of(self, query: "CaptureSnapshot") -> bool:
    """Checks whether `self` is a subtype of `query`.

    Unlike default_types.Dict, CaptureSnapshot does not require the two key
    sets to match exactly. `self` is a subtype of `query` when `self` contains
    every key of `query` (it may contain additional keys) and, for each key of
    `query`, the TraceType stored in `self` is a subtype of the one stored in
    `query`. In other words, the snapshot with more captures is the more
    specific type, e.g. a snapshot of all captures is a subtype of a snapshot
    holding only a subset of them.

    For example (integers stand in for TraceTypes for brevity):

      a = CaptureSnapshot({'x': 1, 'y': 2})
      b = CaptureSnapshot({'x': 1, 'y': 2, 'z': 3})
      assert not a.is_subtype_of(b)
      assert b.is_subtype_of(a)

    Args:
      query: The CaptureSnapshot to compare against. Its keys must be a subset
        of the keys of `self` for the result to be True.

    Returns:
      True if `self` is a subtype of `query`; False otherwise, including when
      `query` is not a CaptureSnapshot.
    """
    if not isinstance(query, CaptureSnapshot):
      return False

    # `self` must cover every capture that `query` constrains.
    if not self._contain_all_keys_of(query):
      return False
    return all(self.mapping[key].is_subtype_of(item)
               for key, item in query.mapping.items())

  def most_specific_common_supertype(
      self,
      types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]:
    """Returns the most specific common supertype of `self` and `types`.

    The result is keyed by the intersection of all key sets. A supertype must
    not require keys that one of its subtypes lacks (see `is_subtype_of`).
    Each value is the most specific common supertype of the per-key
    TraceTypes.

    Args:
      types: The other TraceTypes to generalize together with `self`.

    Returns:
      A CaptureSnapshot, or None if any element of `types` is not a
      CaptureSnapshot or if some shared key has no common supertype.
    """
    # Mirror `is_subtype_of`: a non-CaptureSnapshot has no common supertype
    # with a CaptureSnapshot (previously this raised AttributeError).
    if not all(isinstance(other, CaptureSnapshot) for other in types):
      return None

    common_keys = set(self.mapping.keys())
    for other in types:
      common_keys.intersection_update(other.mapping.keys())

    new_mapping = {}
    for key in common_keys:
      common = self.mapping[key].most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if common is None:
        return None
      new_mapping[key] = common
    return CaptureSnapshot(new_mapping)

  def _placeholder_value(self) -> Any:
    """Returns a dict mapping each key to its TraceType's placeholder value."""
    return {
        key: value._placeholder_value()  # pylint: disable=protected-access
        for key, value in self.mapping.items()
    }

  def __eq__(self, other: "CaptureSnapshot") -> bool:
    if not isinstance(other, CaptureSnapshot):
      return False

    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Only the key set is hashed; equal mappings always have equal key sets,
    # so this stays consistent with `__eq__`.
    return hash(frozenset(self.mapping.keys()))


# TODO(panzf): Rename `FunctionCacheKey` to `FunctionType`
class FunctionCacheKey(trace.TraceType):
  """The unique key associated with a concrete function.

  Attributes:
    args_signature: A TraceType corresponding to the function arguments.
    captures_signature: A CaptureSnapshot corresponding to the function
      captures.
    call_context: The FunctionContext for when the args_signature was
      generated.
  """

  def __init__(self, args_signature: trace.TraceType,
               captures_signature: CaptureSnapshot,
               call_context: FunctionContext):
    self.args_signature = args_signature
    self.captures_signature = captures_signature
    self.call_context = call_context

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Checks whether `self` is a subtype of `other`.

    Requires an identical call context. Both the argument signature and the
    capture snapshot must be subtypes of their counterparts in `other`.
    """
    if not isinstance(other, FunctionCacheKey):
      return False

    if self.call_context != other.call_context:
      return False

    return (self.args_signature.is_subtype_of(other.args_signature)
            and self.captures_signature.is_subtype_of(other.captures_signature))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["FunctionCacheKey"]:
    """Returns the most specific common supertype of `self` and `others`.

    Args:
      others: The other TraceTypes to generalize together with `self`.

    Returns:
      A FunctionCacheKey sharing `self.call_context`, or None if any element of
      `others` is not a FunctionCacheKey with the same call context, or if
      either the argument or the capture signatures have no common supertype.
    """
    if not all(
        isinstance(other, FunctionCacheKey) and
        self.call_context == other.call_context for other in others):
      return None

    # `args` and `captures` are independent when finding common supertypes.
    args_common = self.args_signature.most_specific_common_supertype(
        [other.args_signature for other in others])

    if args_common is None:
      return None

    captures_common = self.captures_signature.most_specific_common_supertype(
        [other.captures_signature for other in others])

    # Without this check a key with `captures_signature=None` would be built
    # and would fail on its next comparison or placeholder computation.
    if captures_common is None:
      return None

    return FunctionCacheKey(args_common, captures_common, self.call_context)

  def _placeholder_value(self) -> Any:
    """Value used for tracing a function signature with this TraceType."""
    return {"args": self.args_signature._placeholder_value(),  # pylint: disable=protected-access
            "captures": self.captures_signature._placeholder_value()}  # pylint: disable=protected-access

  def __hash__(self) -> int:
    return hash((self.call_context,
                 self.args_signature,
                 self.captures_signature))

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, FunctionCacheKey):
      return False

    return (self.call_context == other.call_context and
            self.args_signature == other.args_signature and
            self.captures_signature == other.captures_signature)

  def __repr__(self) -> str:
    return (
        f"{type(self).__name__}(args_signature={repr(self.args_signature)},"
        f" captures_signature={repr(self.captures_signature)},"
        f" call_context={repr(self.call_context)})")


# TODO(fmuham): Rename to FunctionLibrary.
class FunctionCache:
  """A container for managing concrete functions."""

  # NOTE(review): `_garbage_collectors` is never assigned in this class; it is
  # kept so the slot layout stays unchanged.
  __slots__ = [
      "_primary", "_dispatch_table", "_garbage_collectors"
  ]

  def __init__(self):
    # The primary cache, mapping FunctionCacheKey to a concrete function.
    self._primary = collections.OrderedDict()

    # Maps a FunctionCacheKey K to a FunctionCacheKey V such that it is safe
    # to dispatch K to the concrete function of V that exists in _primary.
    # Used to look up possible concrete functions when K is not in _primary.
    self._dispatch_table = type_dispatch.TypeDispatchTable()

  # Note: Instead of returning any viable function, we can return the most
  # specific one by maintaining trees of traces where children are more
  # specific traces of their parents.
  def lookup(self, key: FunctionCacheKey, use_function_subtyping: bool):
    """Looks up a concrete function based on the key.

    Args:
      key: The FunctionCacheKey to look up.
      use_function_subtyping: If False, only an exact match in the primary
        cache is returned. If True, the dispatch table chooses a stored key
        that `key` may be dispatched to.

    Returns:
      The matching concrete function, or None if there is none.
    """
    if not use_function_subtyping:
      return self._primary.get(key)

    dispatch_key = self._dispatch_table.dispatch(key)
    if dispatch_key is not None:
      return self._primary[dispatch_key]

    return None

  def delete(self, key: FunctionCacheKey) -> bool:
    """Deletes a concrete function given the key it was added with.

    Returns:
      True if `key` was present and has been removed, False otherwise.
    """
    if key not in self._primary:
      return False

    del self._primary[key]
    self._dispatch_table.delete(key)

    return True

  def add(self, key: FunctionCacheKey,
          deletion_observer: trace_type.WeakrefDeletionObserver,
          concrete):
    """Adds a new concrete function alongside its key.

    Args:
      key: A FunctionCacheKey object corresponding to the provided `concrete`.
      deletion_observer: A WeakrefDeletionObserver object for the `key`.
      concrete: The concrete function to be added to the cache.
    """
    self._primary[key] = concrete
    self._dispatch_table.add_target(key)
    # DELETE_WITH_WEAKREF is read when the listener fires, not when it is
    # registered; while it is False the listener is a no-op.
    deletion_observer.add_listener(
        lambda: self.delete(key) if DELETE_WITH_WEAKREF else None)

  def generalize(self, key: FunctionCacheKey) -> FunctionCacheKey:
    """Delegates to the dispatch table's `try_generalizing_trace_type`."""
    return self._dispatch_table.try_generalizing_trace_type(key)

  # TODO(b/205971333): Remove this function.
  def clear(self):
    """Removes all concrete functions from the cache."""
    self._primary.clear()
    self._dispatch_table.clear()

  def values(self):
    """Returns a list of all `ConcreteFunction` instances held by this cache."""
    return list(self._primary.values())