• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Cache to manage concrete functions and their signatures."""
16
17import collections
18from typing import Optional, Hashable, Sequence, Any, NamedTuple, Dict
19
20from tensorflow.core.function import trace_type
21from tensorflow.core.function.polymorphism import type_dispatch
22from tensorflow.python.types import trace
23
# TODO(b/182990542): Enable and remove flag when stable.
# Feature flag consulted by the deletion listener that `FunctionCache.add`
# registers on its `deletion_observer`: when True, that listener removes the
# entry from the cache; when False, the listener does nothing. The flag is read
# at the moment the listener runs, not when it is registered.
DELETE_WITH_WEAKREF: bool = False
26
27
class FunctionContext(NamedTuple):
  """Describes the context in which a tf.function is being executed.

  Attributes:
    context: An arbitrary value identifying the execution context. It takes
      part in FunctionCacheKey equality and hashing, so it must be hashable.
  """

  context: Any
31
32
class CaptureSnapshot(trace.TraceType):
  """Store tf.function captures to accommodate their specific tracing logic.

  Captures are stored in mapping format, but their tracing logic differs from
  that of a Python dict. When comparing the types of two normal Python dicts in
  function arguments, their keys are required to be the same. When comparing
  types for captures, keys can differ. This is because tf.function maintains a
  full list of captures and only a subset is active for each ConcreteFunction.
  Before dispatch it is unknown which captures are active, so all captures are
  evaluated for comparison. See also the `is_subtype_of` method.

  Attributes:
    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
  """

  def __init__(self, mapping: Dict[Hashable, trace.TraceType]):
    self.mapping = mapping

  def _contain_all_keys_of(self, other: "CaptureSnapshot") -> bool:
    """Returns True if every key of `other.mapping` is also in `self.mapping`."""
    return all(key in self.mapping for key in other.mapping)

  def is_subtype_of(self, query: "CaptureSnapshot") -> bool:
    """Checks whether `self` is a subtype of `query`.

    `self` is a subtype of `query` when `self` contains every key of `query`
    (it may contain additional keys) and, for each key of `query`, the
    TraceType stored in `self` is a subtype of the one stored in `query`. So a
    snapshot holding more captures can be a subtype of a snapshot holding only
    a subset of them. This differs from default_types.Dict, which requires the
    keys to match exactly.

    For example, with TraceTypes `tx`, `ty`, `tz`:

      a = CaptureSnapshot({'x': tx, 'y': ty})
      b = CaptureSnapshot({'x': tx, 'y': ty, 'z': tz})
      assert not a.is_subtype_of(b)  # `a` lacks the key 'z'.
      assert b.is_subtype_of(a)

    Args:
      query: The CaptureSnapshot to compare against.

    Returns:
      True if `self` is a subtype of `query`. False otherwise, including when
      `query` is not a CaptureSnapshot.
    """
    if not isinstance(query, CaptureSnapshot):
      return False

    if not self._contain_all_keys_of(query):
      return False
    return all(self.mapping[key].is_subtype_of(item)
               for key, item in query.mapping.items())

  def most_specific_common_supertype(
      self,
      types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]:
    """See base class.

    The result keeps only the keys present in `self` and in every snapshot of
    `types`. Each kept key maps to the most specific common supertype of the
    corresponding values.

    Args:
      types: The other TraceTypes to find a common supertype with.

    Returns:
      The common supertype. Returns None if an element of `types` is not a
      CaptureSnapshot, or if some shared key has no common supertype.
    """
    # Anything that is not a CaptureSnapshot has no `mapping`; report "no
    # common supertype" instead of failing with AttributeError.
    if not all(isinstance(other, CaptureSnapshot) for other in types):
      return None

    # `is_subtype_of` requires a subtype to hold every key of its supertype,
    # so the supertype may only keep keys that all inputs share.
    common_keys = set(self.mapping.keys())
    for other in types:
      common_keys = common_keys.intersection(other.mapping.keys())

    new_mapping = {}
    for key in common_keys:
      common = self.mapping[key].most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if common is None:
        return None
      new_mapping[key] = common
    return CaptureSnapshot(new_mapping)

  def _placeholder_value(self) -> Any:
    """Returns a dict mapping each key to its value's placeholder."""
    return {
        key: value._placeholder_value()  # pylint: disable=protected-access
        for key, value in self.mapping.items()
    }

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, CaptureSnapshot):
      return False

    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Only the key set is hashed. This is consistent with `__eq__`, because
    # equal mappings always have equal key sets.
    return hash(frozenset(self.mapping.keys()))

  def __repr__(self) -> str:
    return f"{type(self).__name__}(mapping={self.mapping!r})"
119
120
# TODO(panzf): Rename `FunctionCacheKey` to `FunctionType`
class FunctionCacheKey(trace.TraceType):
  """The unique key associated with a concrete function.

  Attributes:
    args_signature: A TraceType corresponding to the function arguments.
    captures_signature: A CaptureSnapshot corresponding to the function
      captures.
    call_context: The FunctionContext for when the args_signature was
      generated.
  """

  def __init__(self, args_signature: trace.TraceType,
               captures_signature: CaptureSnapshot,
               call_context: FunctionContext):
    self.args_signature = args_signature
    self.captures_signature = captures_signature
    self.call_context = call_context

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `self` is a subtype of `other`.

    `other` must be a FunctionCacheKey with an equal call context. In
    addition, both the argument signature and the capture signature of `self`
    must be subtypes of the corresponding signatures of `other`.
    """
    if not isinstance(other, FunctionCacheKey):
      return False

    # Keys from different call contexts are never subtypes of each other.
    if self.call_context != other.call_context:
      return False

    return (self.args_signature.is_subtype_of(other.args_signature)
            and self.captures_signature.is_subtype_of(other.captures_signature))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["FunctionCacheKey"]:
    """Returns the most specific common supertype of `self` and `others`.

    Args:
      others: The other TraceTypes to find a common supertype with.

    Returns:
      A FunctionCacheKey with the same call context. Returns None if any
      element of `others` is not a FunctionCacheKey or has a different call
      context. Also returns None if either the argument signatures or the
      capture signatures have no common supertype.
    """
    if not all(
        isinstance(other, FunctionCacheKey) and
        self.call_context == other.call_context for other in others):
      return None

    # `args` and `captures` are independent when finding common supertypes.
    args_common = self.args_signature.most_specific_common_supertype(
        [other.args_signature for other in others])
    if args_common is None:
      return None

    captures_common = self.captures_signature.most_specific_common_supertype(
        [other.captures_signature for other in others])
    # Without this check we would build a key whose captures_signature is None.
    # `is_subtype_of` and `_placeholder_value` on such a key would then fail.
    if captures_common is None:
      return None

    return FunctionCacheKey(args_common, captures_common, self.call_context)

  def _placeholder_value(self) -> Any:
    """Value used for tracing a function signature with this TraceType."""
    return {"args": self.args_signature._placeholder_value(),  # pylint: disable=protected-access
            "captures": self.captures_signature._placeholder_value()}  # pylint: disable=protected-access

  def __hash__(self) -> int:
    return hash((self.call_context,
                 self.args_signature,
                 self.captures_signature))

  def __eq__(self, other) -> bool:
    # Let Python try the reflected comparison for non-TraceType operands.
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, FunctionCacheKey):
      return False

    return (self.call_context == other.call_context and
            self.args_signature == other.args_signature and
            self.captures_signature == other.captures_signature)

  def __repr__(self) -> str:
    return (
        f"{type(self).__name__}(args_signature={self.args_signature!r}, "
        f"captures_signature={self.captures_signature!r}, "
        f"call_context={self.call_context!r})")
195
196
# TODO(fmuham): Rename to FunctionLibrary.
class FunctionCache:
  """A container for managing concrete functions.

  Concrete functions are stored under the FunctionCacheKey they were added
  with. A lookup can either require an equal key, or resolve the key through a
  `type_dispatch.TypeDispatchTable` that holds every stored key as a target.
  """

  # NOTE(review): "_garbage_collectors" is never assigned by this class. It is
  # left in place in case code elsewhere sets it; TODO confirm and remove.
  __slots__ = [
      "_primary", "_dispatch_table", "_garbage_collectors"
  ]

  def __init__(self):
    # The primary cache, mapping FunctionCacheKey to a concrete function.
    # Being an OrderedDict, `values()` lists functions in key-insertion order.
    self._primary = collections.OrderedDict()

    # Maps a FunctionCacheKey K to a FunctionCacheKey V such that it is safe
    # to dispatch K to the concrete function of V that exists in _primary.
    # Used to look up possible concrete functions when K is not in _primary.
    self._dispatch_table = type_dispatch.TypeDispatchTable()

  # Note: Instead of returning any viable function, we can return the most
  # specific one by maintaining trees of traces where children are more
  # specific traces of their parents.
  def lookup(self, key: FunctionCacheKey, use_function_subtyping: bool):
    """Looks up a concrete function based on the key.

    Args:
      key: The FunctionCacheKey to look up.
      use_function_subtyping: If False, only a function stored under a key
        equal to `key` is returned. If True, `key` is resolved through the
        dispatch table, which may select a different stored key.

    Returns:
      The matching concrete function, or None if there is none.
    """
    if not use_function_subtyping:
      return self._primary.get(key, None)

    dispatch_key = self._dispatch_table.dispatch(key)
    if dispatch_key is None:
      return None
    # Every dispatch target is registered by `add` together with a `_primary`
    # entry, and `delete`/`clear` remove from both, so this lookup is safe.
    return self._primary[dispatch_key]

  def delete(self, key: FunctionCacheKey) -> bool:
    """Deletes a concrete function given the key it was added with.

    Args:
      key: The FunctionCacheKey the function was added with.

    Returns:
      True if a function was stored under `key` and has been removed. False
      if `key` was not in the cache.
    """
    if key not in self._primary:
      return False

    del self._primary[key]
    self._dispatch_table.delete(key)
    return True

  def add(self, key: FunctionCacheKey,
          deletion_observer: trace_type.WeakrefDeletionObserver,
          concrete):
    """Adds a new concrete function alongside its key.

    Args:
      key: A FunctionCacheKey object corresponding to the provided `concrete`.
      deletion_observer: A WeakrefDeletionObserver object for the `key`.
      concrete: The concrete function to be added to the cache.
    """
    self._primary[key] = concrete
    self._dispatch_table.add_target(key)
    # DELETE_WITH_WEAKREF is read when the listener runs, not when it is
    # registered; while it is False the listener is a no-op.
    deletion_observer.add_listener(
        lambda: self.delete(key) if DELETE_WITH_WEAKREF else None)

  def generalize(self, key: FunctionCacheKey) -> FunctionCacheKey:
    """Returns a generalized form of `key` as chosen by the dispatch table.

    Delegates to `TypeDispatchTable.try_generalizing_trace_type`.
    """
    return self._dispatch_table.try_generalizing_trace_type(key)

  # TODO(b/205971333): Remove this function.
  def clear(self):
    """Removes all concrete functions from the cache."""
    self._primary.clear()
    self._dispatch_table.clear()

  def values(self):
    """Returns a list of all `ConcreteFunction` instances held by this cache."""
    return list(self._primary.values())
265