• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type-based dispatch for TensorFlow ops.
16
17"Operation dispatchers" can be used to override the behavior for TensorFlow ops
18when they are called with otherwise unsupported argument types.  In particular,
19when an operation is called with arguments that would cause it to raise a
TypeError (or a ValueError), it falls back on its registered operation
dispatchers.  If any registered dispatcher can handle the arguments, then its
result is returned.  Otherwise, the original exception is re-raised.
23
By default, dispatch support is added to the generated op wrappers for all
visible ops.  Ops that are implemented in Python can opt in to dispatch
support using the `add_dispatch_support` decorator.
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import itertools
34
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util import tf_inspect
37
# Name of the private attribute, attached to an op's Python function, that
# holds the list of `OpDispatcher`s registered for that op.
DISPATCH_ATTR = "_tf_dispatchers"
40
41
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  An `OpDispatcher` is an override handler for one TensorFlow operation.  When
  the operation falls back on dispatch, the dispatcher's `handle` method is
  consulted.  Returning the `OpDispatcher.NOT_SUPPORTED` sentinel means "these
  arguments are not mine".  Any other return value is used as the result of
  the operation.
  """

  # Unique sentinel object returned by `handle` to decline a call.  Callers
  # compare against it by identity, so it cannot collide with a real result.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handles this dispatcher's operation for the given arguments, if able.

    Subclasses override this.  When the arguments are supported, return the
    operation's result (or raise an appropriate exception).  The base
    implementation supports nothing.

    Args:
      args: The positional arguments passed to the operation.
      kwargs: The keyword arguments passed to the operation.

    Returns:
      The operation's result, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher cannot handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Appends this dispatcher to the dispatch list of `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled.
        It must already carry a dispatch list.  Generated ops get one
        automatically; Python ops can get one from the `add_dispatch_support`
        decorator.

    Raises:
      AssertionError: If `op` has no dispatch list.
    """
    if hasattr(op, DISPATCH_ATTR):
      dispatch_list = getattr(op, DISPATCH_ATTR)
      dispatch_list.append(self)
    else:
      raise AssertionError("Dispatching not enabled for %s" % op)
83
84
def dispatch(op, *args, **kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  The dispatchers registered on `op` are tried in list order.  The first one
  whose `handle` method returns something other than
  `OpDispatcher.NOT_SUPPORTED` wins, and later dispatchers are not consulted.

  Args:
    op: Python function: the operation to dispatch for.  It must carry a
      dispatch list; otherwise the attribute lookup raises `AttributeError`.
    *args: The positional arguments to forward to each dispatcher.
    **kwargs: The keyword arguments to forward to each dispatcher.

  Returns:
    The first supported result, or `OpDispatcher.NOT_SUPPORTED` if no
    registered dispatcher can handle the given arguments.
  """
  registered = getattr(op, DISPATCH_ATTR)
  for candidate in registered:
    outcome = candidate.handle(args, kwargs)
    if outcome is OpDispatcher.NOT_SUPPORTED:
      continue  # This dispatcher declined; try the next one.
    return outcome
  return OpDispatcher.NOT_SUPPORTED
105
106
class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles an op if any argument has one of given types.

  Positional and keyword argument values are checked.  For list or tuple
  arguments, their direct elements are checked as well (one level deep).  On
  a match, the call is delegated to the override function.
  """

  def __init__(self, override_func, types):
    # `override_func` is called with the op's original arguments on a match.
    self._override_func = override_func
    # `types` is passed straight to `isinstance`, so a tuple of types works.
    self._types = types

  def _handles(self, args, kwargs):
    """Returns True if any argument (or list/tuple element) matches a type."""
    for value in itertools.chain(args, kwargs.values()):
      if isinstance(value, self._types):
        return True
      # Look one level into list/tuple arguments, e.g. the inputs to a
      # concat-style op.
      if isinstance(value, (list, tuple)):
        if any(isinstance(item, self._types) for item in value):
          return True
    return False

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
132
133
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is registered as an override for `op`.  It is used
  whenever any positional or keyword argument, or any direct element of a
  list or tuple argument, is an instance of one of `types`.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.  It must
      carry a dispatch list (see `add_dispatch_support`).
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: Raised by the returned decorator in two cases: when the
      decorated function's argspec differs from that of `op`, or when `op`
      has no dispatch list.
  """

  def register_override(override_func):
    override_spec = tf_inspect.getargspec(override_func)
    op_spec = tf_inspect.getargspec(op)
    if override_spec != op_spec:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(override_func, types).register(op)
    # Hand back the function unchanged so it stays directly callable.
    return override_func

  return register_override


# pylint: enable=g-doc-return-or-yield
165
166
def add_dispatch_list(target):
  """Decorator that attaches an empty dispatch list to an op.

  The list is stored on `target` under the `DISPATCH_ATTR` attribute, and
  `target` itself is returned.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  already_has_list = hasattr(target, DISPATCH_ATTR)
  if already_has_list:
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, DISPATCH_ATTR, [])
  return target
173
174
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  The wrapper calls `target` directly.  If that call raises a `TypeError` or
  `ValueError`, the same arguments are offered to the dispatchers registered
  on the wrapper.  If one of them supports the arguments, its result is
  returned.  Otherwise, the original exception is re-raised.

  Args:
    target: The Python function implementing the op.

  Returns:
    The wrapper, with an empty dispatch list attached and decorated by
    `tf_decorator.make_decorator(target, wrapper)`.
  """

  def wrapper(*args, **kwargs):
    """Calls target, falling back on dispatchers on TypeError/ValueError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # ValueError is caught as well because convert_to_eager_tensor (at the
      # time of writing) raises ValueError, not TypeError, on unexpected
      # types.
      dispatched = dispatch(wrapper, *args, **kwargs)
      if dispatched is OpDispatcher.NOT_SUPPORTED:
        raise  # No dispatcher applies: re-raise the original error.
      return dispatched

  # The dispatch list lives on `wrapper`, which is also the key that the
  # fallback above passes to `dispatch`.
  add_dispatch_list(wrapper)
  decorated = tf_decorator.make_decorator(target, wrapper)
  return decorated
192