# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type-based dispatch for TensorFlow ops.

"Operation dispatchers" can be used to override the behavior for TensorFlow ops
when they are called with otherwise unsupported argument types.  In particular,
when an operation is called with arguments that would cause it to raise a
TypeError, it falls back on its registered operation dispatchers.  If any
registered dispatchers can handle the arguments, then its result is returned.
Otherwise, the original TypeError is raised.

By default, dispatch support is added to the generated op wrappers for any
visible ops.  Ops that are implemented in Python can opt in to dispatch
support using the `add_dispatch_support` decorator.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.tf_export import tf_export


# Private function attribute used to store a list of dispatchers.
DISPATCH_ATTR = "_tf_dispatchers"


# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []


@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a dispatch list.
    """
    if not hasattr(op, DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, DISPATCH_ATTR).append(self)


@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Unlike `OpDispatcher`, a global dispatcher is consulted for every op that
  has dispatch support, and therefore receives the op itself as an argument
  to `handle`.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Args:
      op: Python function: the operation being dispatched.
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `GlobalOpDispatcher.NOT_SUPPORTED` if
      this dispatcher can not handle the given arguments.
    """
    # `dispatch` treats every value other than NOT_SUPPORTED (including None)
    # as a successful result, so the default implementation must return the
    # sentinel explicitly; otherwise a dispatcher that does not override this
    # method would silently "handle" every op by returning None.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)


def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op`, and returns the value from the first successful handler.
  If none of them succeeds, the registered `GlobalOpDispatcher`s are tried
  in registration order.

  Args:
    op: Python function: the operation to dispatch for.
    args: The arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  # Op-specific dispatchers take priority over global ones.
  for dispatcher in getattr(op, DISPATCH_ATTR):
    result = dispatcher.handle(args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  # Global dispatchers additionally receive the op being dispatched.
  for dispatcher in _GLOBAL_DISPATCHERS:
    result = dispatcher.handle(op, args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  return OpDispatcher.NOT_SUPPORTED


class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    """Creates a type-based dispatcher.

    Args:
      override_func: Function called with the op's `*args` and `**kwargs`
        when at least one argument matches `types`.
      types: A type or tuple of types, passed as the second argument of
        `isinstance`.
    """
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any argument (or list/tuple element) matches a type."""
    for arg in itertools.chain(args, kwargs.values()):
      if isinstance(arg, self._types):
        return True
      # Only one level of list/tuple nesting is inspected.
      if (isinstance(arg, (list, tuple)) and
          any(isinstance(elt, self._types) for elt in arg)):
        return True
    return False

  def handle(self, args, kwargs):
    """Calls the override function if any argument has a handled type.

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The override function's result, or `NOT_SUPPORTED` if no argument has
      one of the handled types.
    """
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)


# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  The decorated function's signature must exactly match the signature of
  `op`; otherwise an `AssertionError` is raised when the decorator is applied.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    if tf_inspect.getargspec(func) != tf_inspect.getargspec(op):
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    # The function is returned unchanged; only the registration is a side
    # effect.
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield


def add_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Args:
    target: The object to which an empty dispatch list should be attached.

  Returns:
    `target`, with an empty list stored under `DISPATCH_ATTR`.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  if hasattr(target, DISPATCH_ATTR):
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, DISPATCH_ATTR, [])
  return target


@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  Args:
    target: Python function: the op implementation to wrap.

  Returns:
    The result of `tf_decorator.make_decorator` applied to `target` and a
    wrapper that calls `target` and falls back on the registered dispatchers
    if `target` raises a `TypeError` or `ValueError`.
  """

  @traceback_utils.filter_traceback
  def op_dispatch_handler(*args, **kwargs):
    """Call target, and fall back on dispatchers if there is a TypeError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # Note: convert_to_eager_tensor currently raises a ValueError, not a
      # TypeError, when given unexpected types.  So we need to catch both.
      result = dispatch(op_dispatch_handler, args, kwargs)
      if result is OpDispatcher.NOT_SUPPORTED:
        # No dispatcher could handle the arguments: re-raise the original
        # exception unchanged.
        raise
      return result

  # The dispatch list lives on the wrapper, which is also the `op` passed to
  # `dispatch` above.
  add_dispatch_list(op_dispatch_handler)
  return tf_decorator.make_decorator(target, op_dispatch_handler)