• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type-based dispatch for TensorFlow ops.
16
17"Operation dispatchers" can be used to override the behavior for TensorFlow ops
18when they are called with otherwise unsupported argument types.  In particular,
19when an operation is called with arguments that would cause it to raise a
20TypeError, it falls back on its registered operation dispatchers.  If any
21registered dispatchers can handle the arguments, then its result is returned.
22Otherwise, the original TypeError is raised.
23
By default, dispatch support is added to the generated op wrappers for any
visible ops.  Ops that are implemented in Python can opt in to dispatch
support using the `add_dispatch_support` decorator.
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import itertools
34
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util import tf_inspect
37from tensorflow.python.util import traceback_utils
38from tensorflow.python.util.tf_export import tf_export
39
40
# Name of the private attribute attached to a dispatch-enabled op function.
# It holds the list of `OpDispatcher`s registered for that op; `dispatch()`
# consults this list when the op raises a TypeError/ValueError.
DISPATCH_ATTR = "_tf_dispatchers"


# `GlobalOpDispatcher`s that are consulted for *all* operations, after the
# op-specific dispatchers stored in each op's dispatch list.
_GLOBAL_DISPATCHERS = []
47
48
49@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.  Identity (`is`)
  # comparison is used to check for it, so it must be a unique object.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    # Base class declines everything; subclasses override this.
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a dispatch list.
    """
    if not hasattr(op, DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, DISPATCH_ATTR).append(self)
91
92
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Unlike `OpDispatcher`, which handles one op, a global dispatcher is
  consulted for every operation (after the op-specific dispatchers), and
  its `handle` method receives the op itself as an extra argument.
  """

  # Sentinel shared with `OpDispatcher`: return it from `handle` to indicate
  # that this dispatcher cannot handle the given arguments.
  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):
    """Handle the specified operation with the specified arguments."""

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
105
106
def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Consults each `OpDispatcher` registered on `op` (in registration order),
  then every global dispatcher, returning the first value that is not
  `OpDispatcher.NOT_SUPPORTED`.

  Args:
    op: Python function: the operation to dispatch for.
    args: The arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  # Op-specific dispatchers take precedence over global ones.
  for handler in getattr(op, DISPATCH_ATTR):
    outcome = handler.handle(args, kwargs)
    if outcome is not OpDispatcher.NOT_SUPPORTED:
      return outcome
  # Global dispatchers also receive the op itself.
  for handler in _GLOBAL_DISPATCHERS:
    outcome = handler.handle(op, args, kwargs)
    if outcome is not OpDispatcher.NOT_SUPPORTED:
      return outcome
  return OpDispatcher.NOT_SUPPORTED
131
132
class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    # An argument matches if it has one of the target types itself, or is a
    # list/tuple containing at least one element with a target type.
    def _matches(value):
      if isinstance(value, self._types):
        return True
      return (isinstance(value, (list, tuple)) and
              any(isinstance(item, self._types) for item in value))

    return any(_matches(v) for v in itertools.chain(args, kwargs.values()))

  def handle(self, args, kwargs):
    # Delegate to the override function only when some argument matches.
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
158
159
160# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    # The override is called with the op's original arguments, so its
    # signature must line up exactly with the op's.
    op_spec = tf_inspect.getargspec(op)
    func_spec = tf_inspect.getargspec(func)
    if func_spec != op_spec:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    dispatcher = _TypeBasedDispatcher(func, types)
    dispatcher.register(op)
    return func

  return decorator
188
189
190# pylint: enable=g-doc-return-or-yield
191
192
def add_dispatch_list(target):
  """Decorator that attaches an (initially empty) dispatch list to an op.

  Args:
    target: Python function: the op to enable dispatch for.

  Returns:
    `target`, with the dispatch-list attribute added.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  already_enabled = hasattr(target, DISPATCH_ATTR)
  if already_enabled:
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, DISPATCH_ATTR, [])
  return target
199
200
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  The wrapper calls `target` directly; if that call raises a `TypeError` or
  `ValueError`, it falls back on the dispatchers registered for this op, and
  re-raises the original error when none of them can handle the arguments.
  """

  @traceback_utils.filter_traceback
  def op_dispatch_handler(*args, **kwargs):
    """Call target, and fall back on dispatchers if there is a TypeError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # Note: convert_to_eager_tensor currently raises a ValueError, not a
      # TypeError, when given unexpected types.  So we need to catch both.
      fallback = dispatch(op_dispatch_handler, args, kwargs)
      if fallback is OpDispatcher.NOT_SUPPORTED:
        raise
      return fallback

  add_dispatch_list(op_dispatch_handler)
  return tf_decorator.make_decorator(target, op_dispatch_handler)
221