1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Type-based dispatch for TensorFlow ops. 16 17"Operation dispatchers" can be used to override the behavior for TensorFlow ops 18when they are called with otherwise unsupported argument types. In particular, 19when an operation is called with arguments that would cause it to raise a 20TypeError, it falls back on its registered operation dispatchers. If any 21registered dispatchers can handle the arguments, then its result is returned. 22Otherwise, the original TypeError is raised. 23 24By default, dispatch support is added to the generated op wrappers for any 25visible ops by default. Ops that are implemented in Python can opt in to 26dispatch support using the `add_dispatch_support` decorator. 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import itertools 34 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util import tf_inspect 37 38# Private function attribute used to store a list of dispatchers. 39DISPATCH_ATTR = "_tf_dispatchers" 40 41 42class OpDispatcher(object): 43 """Abstract base class for TensorFlow operator dispatchers. 44 45 Each operation dispatcher acts as an override handler for a single 46 TensorFlow operation, and its results are used when the handler indicates 47 that it can handle the operation's arguments (by returning any value other 48 than `OpDispatcher.NOT_SUPPORTED`). 49 """ 50 51 # Sentinel value that can be returned to indicate that an operation 52 # dispatcher does not support a given set of arguments. 53 NOT_SUPPORTED = object() 54 55 def handle(self, args, kwargs): # pylint: disable=unused-argument 56 """Handle this dispatcher's operation with the specified arguments. 57 58 If this operation dispatcher can handle the given arguments, then 59 return an appropriate value (or raise an appropriate exception). 60 61 Args: 62 args: The arguments to the operation. 63 kwargs: They keyword arguments to the operation. 64 65 Returns: 66 The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this 67 dispatcher can not handle the given arguments. 68 """ 69 return self.NOT_SUPPORTED 70 71 def register(self, op): 72 """Register this dispatcher as a handler for `op`. 73 74 Args: 75 op: Python function: the TensorFlow operation that should be handled. Must 76 have a dispatch list (which is added automatically for generated ops, 77 and can be added to Python ops using the `add_dispatch_support` 78 decorator). 79 """ 80 if not hasattr(op, DISPATCH_ATTR): 81 raise AssertionError("Dispatching not enabled for %s" % op) 82 getattr(op, DISPATCH_ATTR).append(self) 83 84 85def dispatch(op, *args, **kwargs): 86 """Returns the result from the first successful dispatcher for a given op. 87 88 Calls the `handle` method of each `OpDispatcher` that has been registered 89 to handle `op`, and returns the value from the first successful handler. 90 91 Args: 92 op: Python function: the operation to dispatch for. 93 *args: The arguments to the operation. 94 **kwargs: They keyword arguments to the operation. 95 96 Returns: 97 The result of the operation, or `NOT_SUPPORTED` if no registered 98 dispatcher can handle the given arguments. 99 """ 100 for dispatcher in getattr(op, DISPATCH_ATTR): 101 result = dispatcher.handle(args, kwargs) 102 if result is not OpDispatcher.NOT_SUPPORTED: 103 return result 104 return OpDispatcher.NOT_SUPPORTED 105 106 107class _TypeBasedDispatcher(OpDispatcher): 108 """Dispatcher that handles op if any arguments have a specified type. 109 110 Checks the types of the arguments and keyword arguments (including elements 111 of lists or tuples), and if any argument values have the indicated type(s), 112 then delegates to an override function. 113 """ 114 115 def __init__(self, override_func, types): 116 self._types = types 117 self._override_func = override_func 118 119 def _handles(self, args, kwargs): 120 for arg in itertools.chain(args, kwargs.values()): 121 if (isinstance(arg, self._types) or 122 (isinstance(arg, (list, tuple)) and 123 any(isinstance(elt, self._types) for elt in arg))): 124 return True 125 return False 126 127 def handle(self, args, kwargs): 128 if self._handles(args, kwargs): 129 return self._override_func(*args, **kwargs) 130 else: 131 return self.NOT_SUPPORTED 132 133 134# pylint: disable=g-doc-return-or-yield 135def dispatch_for_types(op, *types): 136 """Decorator to declare that a Python function overrides an op for a type. 137 138 The decorated function is used to override `op` if any of the arguments or 139 keyword arguments (including elements of lists or tuples) have one of the 140 specified types. 141 142 Example: 143 144 ```python 145 @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue) 146 def ragged_add(x, y, name=None): ... 147 ``` 148 149 Args: 150 op: Python function: the operation that should be overridden. 151 *types: The argument types for which this function should be used. 152 """ 153 154 def decorator(func): 155 if tf_inspect.getargspec(func) != tf_inspect.getargspec(op): 156 raise AssertionError("The decorated function's signature must exactly " 157 "match the signature of the overridden op.") 158 _TypeBasedDispatcher(func, types).register(op) 159 return func 160 161 return decorator 162 163 164# pylint: enable=g-doc-return-or-yield 165 166 167def add_dispatch_list(target): 168 """Decorator that adds a dispatch_list attribute to an op.""" 169 if hasattr(target, DISPATCH_ATTR): 170 raise AssertionError("%s already has a dispatch list" % target) 171 setattr(target, DISPATCH_ATTR, []) 172 return target 173 174 175def add_dispatch_support(target): 176 """Decorator that adds a dispatch handling wrapper to an op.""" 177 def wrapper(*args, **kwargs): 178 """Call target, and fall back on dispatchers if there is a TypeError.""" 179 try: 180 return target(*args, **kwargs) 181 except (TypeError, ValueError): 182 # Note: convert_to_eager_tensor currently raises a ValueError, not a 183 # TypeError, when given unexpected types. So we need to catch both. 184 result = dispatch(wrapper, *args, **kwargs) 185 if result is not OpDispatcher.NOT_SUPPORTED: 186 return result 187 else: 188 raise 189 190 add_dispatch_list(wrapper) 191 return tf_decorator.make_decorator(target, wrapper) 192