1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Polymorphic Type Dispatch."""
16
17import collections
18from typing import Optional, Iterable
19
20from tensorflow.python.types import trace
21
# The maximum number of dispatch lookups to cache.
_MAX_DISPATCH_CACHE = 1024


class TypeDispatchTable:
  """Type dispatch table implementation.

  A type dispatch table is a list, L, of target types. Given a request type, R,
  the table selects a target type, T, according to the following dispatch rules:
    1. R == T or R is subtype of T
    2. There does not exist O in L such that R is subtype of O and O is a
       subtype of T (in other words, T is the closest to R, within list L).
    3. If the above two rules are satisfied by multiple targets, the earliest
       inserted one is chosen.
  """

  def __init__(self):
    """Creates a TypeDispatchTable object."""
    # Holds all inserted types as keys mapping to None.
    # (Using OrderedDict as a set for determinism.)
    self._dispatch_table = collections.OrderedDict()

    # LRU cache for dispatch results, ordered oldest-first.
    # Maps request types to target types (see class description).
    # Does not contain exact matches, i.e., if cache[a] is b then a is not b.
    self._dispatch_cache = collections.OrderedDict()

  def add_target(self, target: trace.TraceType) -> None:
    """Adds a new target type and refreshes affected cached dispatches.

    Args:
      target: The TraceType to insert into the table.
    """
    self._dispatch_table[target] = None
    for request in self._dispatch_cache:
      # The new target may only take over a cached entry when BOTH dispatch
      # rules hold for it: the request must dispatch to it (rule 1:
      # request is a subtype of target) AND it must be more specific than
      # the previously cached result (rule 2: target is a subtype of the
      # old result). Checking only the second condition could map a request
      # to a target it is not a subtype of.
      if (request.is_subtype_of(target) and
          target.is_subtype_of(self._dispatch_cache[request])):
        self._dispatch_cache[request] = target

  @property
  def targets(self) -> Iterable[trace.TraceType]:
    """Returns an iterable to all targets in the table (insertion order)."""
    return self._dispatch_table.keys()

  def delete(self, target: trace.TraceType) -> None:
    """Deletes a target in the table if it exists."""
    if target in self._dispatch_table:
      del self._dispatch_table[target]
      # Invalidate every cached lookup that resolved to the removed target;
      # those requests will be re-resolved on their next dispatch.
      for request in list(self._dispatch_cache.keys()):
        if self._dispatch_cache[request] == target:
          del self._dispatch_cache[request]

  # TODO(b/205971333): remove once FunctionCache 'clear' is removed.
  def clear(self) -> None:
    """Deletes all targets in the table."""
    self._dispatch_table.clear()
    self._dispatch_cache.clear()

  def dispatch(self, request: trace.TraceType) -> Optional[trace.TraceType]:
    """Returns the deepest subtype target if it exists in the table.

    Args:
      request: The TraceType to look up a target for.

    Returns:
      The most specific target T such that `request` is a subtype of T,
      or None if no target matches.
    """
    # For known exact matches.
    if request in self._dispatch_table:
      return request

    # For known non-exact matches.
    # (self._dispatch_cache does not contain exact matches.)
    if request in self._dispatch_cache:
      # Refresh LRU position: most recently used entries live at the end.
      self._dispatch_cache.move_to_end(request)
      return self._dispatch_cache[request]

    # Full scan: keep the most specific matching target seen so far.
    # NOTE(review): if two distinct targets are mutually subtypes, the later
    # inserted one wins here, which appears to conflict with rule 3 —
    # preserved as-is since equivalence of distinct keys may be impossible
    # in practice; confirm against TraceType semantics.
    most_specific_subtype = None
    for other in self._dispatch_table:
      if request.is_subtype_of(other):
        if most_specific_subtype is None or other.is_subtype_of(
            most_specific_subtype):
          most_specific_subtype = other

    self._cache_dispatch(request, most_specific_subtype)
    return most_specific_subtype

  def _cache_dispatch(self, request, target):
    """Caches the dispatch lookup result for a target."""
    if target is not None:
      # Evict the least recently used entry BEFORE inserting so the cache
      # never exceeds _MAX_DISPATCH_CACHE entries (a '>' comparison here
      # would let it stabilize at _MAX_DISPATCH_CACHE + 1).
      if len(self._dispatch_cache) >= _MAX_DISPATCH_CACHE:
        self._dispatch_cache.popitem(last=False)
      self._dispatch_cache[request] = target

  def try_generalizing_trace_type(self,
                                  target: trace.TraceType) -> trace.TraceType:
    """Returns a generalized subtype of the one given.

    This heuristic aims to reduce the number of future traces by computing a
    type that represents more general function inputs.

    The original "experimental_relax_shapes" heuristic identified a known type
    which shared a common supertype with the current unknown type and then
    traced with that common supertype. However, the notion of "common supertype"
    was only limited to shapes. This heuristic extends that to TraceType.

    Returns `target` if a common supertype can not be found.

    Args:
      target: The TraceType to generalize
    """
    # Iteratively fold every known target into the running supertype;
    # targets with no common supertype are simply skipped.
    relaxed = target
    for other in self._dispatch_table:
      supertype = relaxed.most_specific_common_supertype([other])
      if supertype is not None:
        relaxed = supertype
    return relaxed
130