1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Type-based dispatch for TensorFlow ops. 16 17"Operation dispatchers" can be used to override the behavior for TensorFlow ops 18when they are called with otherwise unsupported argument types. In particular, 19when an operation is called with arguments that would cause it to raise a 20TypeError, it falls back on its registered operation dispatchers. If any 21registered dispatchers can handle the arguments, then its result is returned. 22Otherwise, the original TypeError is raised. 23 24By default, dispatch support is added to the generated op wrappers for any 25visible ops by default. Ops that are implemented in Python can opt in to 26dispatch support using the `add_dispatch_support` decorator. 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import itertools 34 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util import tf_inspect 37from tensorflow.python.util.tf_export import tf_export 38 39 40# Private function attribute used to store a list of dispatchers. 41DISPATCH_ATTR = "_tf_dispatchers" 42 43 44# OpDispatchers which should be used for all operations. 
_GLOBAL_DISPATCHERS = []  # Appended to by `GlobalOpDispatcher.register`.


@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a dispatch list.
    """
    if not hasattr(op, DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, DISPATCH_ATTR).append(self)


@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  A global dispatcher is consulted for every dispatch-enabled op, after that
  op's own `OpDispatcher`s have all returned `NOT_SUPPORTED`.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Subclasses should override this method.

    Args:
      op: Python function: the operation being dispatched.
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `GlobalOpDispatcher.NOT_SUPPORTED` if
      this dispatcher can not handle the given arguments.
    """
    # Explicitly decline: an implicit `None` return would be treated by
    # `dispatch` as a successful result and returned to the op's caller.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)


def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op` (in registration order), then the `handle` method of each
  registered `GlobalOpDispatcher` (in registration order), and returns the
  value from the first handler that does not return `NOT_SUPPORTED`.

  Args:
    op: Python function: the operation to dispatch for.
    args: The arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.

  Raises:
    AttributeError: If `op` does not have a dispatch list.
  """
  for dispatcher in getattr(op, DISPATCH_ATTR):
    result = dispatcher.handle(args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  for dispatcher in _GLOBAL_DISPATCHERS:
    result = dispatcher.handle(op, args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  return OpDispatcher.NOT_SUPPORTED


class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    """Creates a type-based dispatcher.

    Args:
      override_func: Called with the op's original `*args, **kwargs` when this
        dispatcher handles a call.
      types: A type or tuple of types, as accepted by `isinstance`.
    """
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any argument (or list/tuple element) matches a type."""
    for arg in itertools.chain(args, kwargs.values()):
      # Only one level of list/tuple nesting is inspected.
      if (isinstance(arg, self._types) or
          (isinstance(arg, (list, tuple)) and
           any(isinstance(elt, self._types) for elt in arg))):
        return True
    return False

  def handle(self, args, kwargs):
    if self._handles(args, kwargs):
      return self._override_func(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED


# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: (from the returned decorator) If the decorated function's
      signature does not match `op`'s, or if `op` has no dispatch list.
  """

  def decorator(func):
    if tf_inspect.getargspec(func) != tf_inspect.getargspec(op):
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield


def add_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Args:
    target: The function to equip with an (initially empty) dispatch list.

  Returns:
    `target`, with the `DISPATCH_ATTR` attribute set to an empty list.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  if hasattr(target, DISPATCH_ATTR):
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, DISPATCH_ATTR, [])
  return target


@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  The wrapper calls `target`; if that raises a TypeError or ValueError, the
  registered dispatchers are consulted via `dispatch`. If one of them handles
  the call, its result is returned; otherwise the original exception is
  re-raised.

  Args:
    target: The Python function implementing the op.

  Returns:
    The decorated wrapper function, which carries the dispatch list.
  """
  def wrapper(*args, **kwargs):
    """Call target, and fall back on dispatchers if there is a TypeError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # Note: convert_to_eager_tensor currently raises a ValueError, not a
      # TypeError, when given unexpected types. So we need to catch both.
      # Dispatchers are looked up on `wrapper`, the function that
      # `add_dispatch_list` equips below.
      result = dispatch(wrapper, args, kwargs)
      if result is not OpDispatcher.NOT_SUPPORTED:
        return result
      else:
        raise

  add_dispatch_list(wrapper)
  return tf_decorator.make_decorator(target, wrapper)