# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polymorphic Type Dispatch."""

import collections
from typing import Optional, Iterable

from tensorflow.python.types import trace

# The maximum number of dispatch lookups to cache.
_MAX_DISPATCH_CACHE = 1024


class TypeDispatchTable:
  """Type dispatch table implementation.

  A type dispatch table is a list, L, of target types. Given a request type, R,
  the table selects a target type, T, according to the following dispatch rules:
    1. R == T or R is subtype of T
    2. There does not exist O in L such that R is subtype of O and O is a
       subtype of T (in other words, T is the closest to R, within list L).
    3. If the above two rules are satisfied by multiple targets, the earliest
       inserted one is chosen.
  """

  def __init__(self):
    """Creates a TypeDispatchTable object."""
    # Holds all inserted types as keys mapping to None.
    # (Using OrderedDict as a set for determinism)
    self._dispatch_table = collections.OrderedDict()

    # LRU cache for dispatch results.
    # Maps request types to target types (see class description).
    # Does not contain exact matches, i.e, if cache[a] is b then a is not b.
    # The least recently used entry sits at the front of the OrderedDict and
    # the most recently used entry at the end.
    self._dispatch_cache = collections.OrderedDict()

  def add_target(self, target: trace.TraceType) -> None:
    """Adds a new target type.

    Cached dispatch results are kept consistent with the dispatch rules: a
    cached request is redirected to `target` only when the request can actually
    dispatch to `target` (it is a subtype of `target`) and `target` is more
    specific than the previously cached result. A cached request that equals
    `target` is dropped, because it is now served as an exact match.

    Args:
      target: The TraceType to add to the table.
    """
    self._dispatch_table[target] = None
    # Iterate over a snapshot because entries may be deleted below.
    for request, cached in list(self._dispatch_cache.items()):
      if request == target:
        # Now an exact match answered directly from the table; keep the cache
        # free of exact matches.
        del self._dispatch_cache[request]
      elif request.is_subtype_of(target) and target.is_subtype_of(cached):
        # `target` is a valid and closer match for `request` than `cached`.
        self._dispatch_cache[request] = target

  @property
  def targets(self) -> Iterable[trace.TraceType]:
    """Returns an iterable to all targets in the table."""
    return self._dispatch_table.keys()

  def delete(self, target: trace.TraceType) -> None:
    """Deletes a target in the table if it exists.

    Cached lookups that resolved to `target` are evicted so that they are
    recomputed against the remaining targets on the next dispatch.

    Args:
      target: The TraceType to remove.
    """
    if target in self._dispatch_table:
      del self._dispatch_table[target]
      # Snapshot the keys since entries are deleted while iterating.
      for request in list(self._dispatch_cache.keys()):
        if self._dispatch_cache[request] == target:
          del self._dispatch_cache[request]

  # TODO(b/205971333): remove once FunctionCache 'clear' is removed.
  def clear(self) -> None:
    """Deletes all targets in the table."""
    self._dispatch_table.clear()
    self._dispatch_cache.clear()

  def dispatch(self, request: trace.TraceType) -> Optional[trace.TraceType]:
    """Returns the deepest subtype target if it exists in the table.

    Args:
      request: The TraceType to find a target for.

    Returns:
      `request` itself if it is a target in the table, otherwise the most
      specific target that `request` is a subtype of, or None if there is none.
    """
    # For known exact matches.
    if request in self._dispatch_table:
      return request

    # For known non-exact matches.
    # (self._dispatch_cache does not contain exact matches)
    if request in self._dispatch_cache:
      # Mark as most recently used by moving it to the end of the LRU cache.
      self._dispatch_cache.move_to_end(request)
      return self._dispatch_cache[request]

    # Linear scan in insertion order: a later candidate only replaces the
    # current best when it is at least as specific (a subtype of it).
    most_specific_subtype = None
    for other in self._dispatch_table:
      if request.is_subtype_of(other):
        if most_specific_subtype is None or other.is_subtype_of(
            most_specific_subtype):
          most_specific_subtype = other

    self._cache_dispatch(request, most_specific_subtype)
    return most_specific_subtype

  def _cache_dispatch(self, request, target):
    """Caches the dispatch lookup result for a target.

    Misses (`target is None`) are not cached. The cache never holds more than
    `_MAX_DISPATCH_CACHE` entries; the least recently used entry is evicted to
    make room for a new one.

    Args:
      request: The requested TraceType.
      target: The target TraceType that `request` dispatched to, or None.
    """
    if target is not None:
      # LRU Cache removes oldest item. `>=` (not `>`) so that inserting the
      # new entry keeps the size at most _MAX_DISPATCH_CACHE.
      if len(self._dispatch_cache) >= _MAX_DISPATCH_CACHE:
        self._dispatch_cache.popitem(last=False)
      self._dispatch_cache[request] = target

  def try_generalizing_trace_type(self,
                                  target: trace.TraceType) -> trace.TraceType:
    """Returns a generalized supertype of the one given.

    This heuristic aims to reduce the number of future traces by computing a
    type that represents more general function inputs.

    The original "experimental_relax_shapes" heuristic identified a known type
    which shared a common supertype with the current unknown type and then
    traced with that common supertype. However, the notion of "common supertype"
    was only limited to shapes. This heuristic extends that to TraceType.

    Returns `target` if a common supertype can not be found.

    Args:
      target: The TraceType to generalize

    Returns:
      The result of successively folding `target` with every table entry via
      `most_specific_common_supertype`, skipping entries for which no common
      supertype exists.
    """
    relaxed = target
    for other in self._dispatch_table:
      supertype = relaxed.most_specific_common_supertype([other])
      if supertype is not None:
        relaxed = supertype
    return relaxed