• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Type-based dispatch for TensorFlow's Python APIs.

"Python APIs" refers to Python functions that have been exported with
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
referred to as "ops".

There are currently two dispatch systems for TensorFlow:

  * The "fallback dispatch" system calls an API's standard implementation first,
    and only tries to perform dispatch if that standard implementation raises a
    TypeError (or ValueError) exception.

  * The "type-based dispatch" system checks the types of the parameters passed
    to an API, and performs dispatch if those types match any signatures that
    have been registered for dispatch.

The fallback dispatch system was the original dispatch system, but it was
somewhat brittle and had limitations, such as an inability to support dispatch
for some operations (like convert_to_tensor).  We plan to remove the fallback
dispatch system in favor of the type-based dispatch system, once all users have
been switched over to use it.

### Fallback Dispatch

The fallback dispatch system is based on "operation dispatchers", which can be
used to override the behavior for TensorFlow ops when they are called with
otherwise unsupported argument types.  In particular, when an operation is
called with arguments that would cause it to raise a TypeError, it falls back on
its registered operation dispatchers.  If any registered dispatcher can handle
the arguments, then that dispatcher's result is returned. Otherwise, the
original TypeError is raised.

### Type-based Dispatch

The main interface for the type-based dispatch system is the `dispatch_for_api`
decorator, which overrides the default implementation for a TensorFlow API.
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature.

### Dispatch Support

By default, dispatch support is added to the generated op wrappers for any
visible ops.  APIs/ops that are implemented in Python can opt in to
dispatch support using the `add_dispatch_support` decorator.
"""
61
62import collections
63import itertools
64import typing  # pylint: disable=unused-import (used in doctests)
65
66from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
67from tensorflow.python.framework import ops
68from tensorflow.python.util import tf_decorator
69from tensorflow.python.util import tf_export as tf_export_lib
70from tensorflow.python.util import tf_inspect
71from tensorflow.python.util import traceback_utils
72from tensorflow.python.util import type_annotations
73from tensorflow.python.util.tf_export import tf_export
74
75
# Private function attributes used to store dispatchers on TensorFlow APIs.
# Both attributes are set directly on the decorated API function objects.
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []


################################################################################
# Fallback Dispatch
################################################################################
88
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation.  Its result is used whenever its `handle` method
  returns any value other than the `OpDispatcher.NOT_SUPPORTED` sentinel.
  """

  # Sentinel value returned by `handle` to indicate that this dispatcher
  # does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    # Base class never handles anything; subclasses override this.
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).
    """
    if hasattr(op, FALLBACK_DISPATCH_ATTR):
      getattr(op, FALLBACK_DISPATCH_ATTR).append(self)
    else:
      raise AssertionError("Dispatching not enabled for %s" % op)
132
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Unlike `OpDispatcher`, a global dispatcher is consulted for *every*
  operation, and its `handle` method receives the op as well.
  """

  # Re-export the shared sentinel so subclasses can return it directly.
  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):
    """Handle the specified operation with the specified arguments."""

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
145
146
def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op`, and returns the value from the first successful handler.
  Op-specific dispatchers are tried before global dispatchers.

  Args:
    op: Python function: the operation to dispatch for.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  not_supported = OpDispatcher.NOT_SUPPORTED
  for handler in getattr(op, FALLBACK_DISPATCH_ATTR):
    outcome = handler.handle(args, kwargs)
    if outcome is not not_supported:
      return outcome
  for handler in _GLOBAL_DISPATCHERS:
    outcome = handler.handle(op, args, kwargs)
    if outcome is not not_supported:
      return outcome
  return not_supported
171
172
class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any (possibly list-nested) argument matches our types."""

    def matches(value):
      if isinstance(value, self._types):
        return True
      if isinstance(value, (list, tuple)):
        return any(isinstance(item, self._types) for item in value)
      return False

    return any(matches(v) for v in itertools.chain(args, kwargs.values()))

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
198
199
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    # The override must be a drop-in replacement for `op`.
    if tf_inspect.getargspec(func) != tf_inspect.getargspec(op):
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield
231
232
def add_fallback_dispatch_list(target):
  """Decorator that adds a fallback dispatch_list attribute to an op."""
  if hasattr(target, FALLBACK_DISPATCH_ATTR):
    raise AssertionError("%s already has a dispatch list" % target)
  # Start with an empty list; `OpDispatcher.register` appends to it.
  setattr(target, FALLBACK_DISPATCH_ATTR, [])
  return target


# Alias for backwards-compatibility.
add_dispatch_list = add_fallback_dispatch_list
243
244
################################################################################
# Type-based Dispatch
################################################################################


@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature.  Signatures are specified using dictionaries
  that map parameter names to type annotations.  E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match.  For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values.   For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value.  This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden.  In particular, parameters must have the same names, and
  must occur in the same order.  The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called.  In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations.  If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  # Build one checker per explicit signature, up front, so signature errors
  # are reported at decoration time.
  api_signature = tf_inspect.signature(api)
  signature_checkers = []
  for signature in signatures:
    signature_checkers.append(_make_signature_checker(api_signature, signature))

  def decorator(dispatch_target):
    """Registers `dispatch_target` with the dispatcher for `api`."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    # Optionally add a `name` parameter (wrapped in tf.name_scope).
    target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, target)

    for checker in signature_checkers:
      dispatcher.Register(checker, target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][target].extend(signatures)

    # With no explicit signatures, derive one from the target's annotations.
    if not signature_checkers:
      annotated = _signature_from_annotations(target)
      dispatcher.Register(
          _make_signature_checker(api_signature, annotated), target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][target].append(annotated)

    return target

  return decorator
366
367
# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
# which can be used for documentation generation and for improved error messages
# when APIs are called with unsupported types.
_TYPE_BASED_DISPATCH_SIGNATURES = {}
372
373
def apis_with_type_based_dispatch():
  """Returns a list of TensorFlow APIs that support type-based dispatch."""

  def full_name(api):
    # Sort by the fully qualified name for a stable, readable ordering.
    return f"{api.__module__}.{api.__name__}"

  return sorted(_TYPE_BASED_DISPATCH_SIGNATURES, key=full_name)
379
380
def type_based_dispatch_signatures_for(cls):
  """Returns dispatch signatures that have been registered for a given class.

  This function is intended for documentation-generation purposes.

  Args:
    cls: The class to search for.  Type signatures are searched recursively, so
      e.g., if `cls=RaggedTensor`, then information will be returned for all
      dispatch targets that have `RaggedTensor` anywhere in their type
      annotations (including nested in `typing.Union` or `typing.List`.)

  Returns:
    A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
    function; and `signatures` is a list of dispatch signatures for `api`
    that include `cls`.  (Each signature is a dict mapping argument names to
    type annotations; see `dispatch_for_api` for more info.)
  """

  def mentions_cls(annotation):
    """Returns True if `annotation` contains `cls` (possibly nested)."""
    if isinstance(annotation, dict):
      return any(mentions_cls(v) for v in annotation.values())
    if annotation is cls:
      return True
    if (type_annotations.is_generic_list(annotation) or
        type_annotations.is_generic_union(annotation)):
      return any(
          mentions_cls(arg)
          for arg in type_annotations.get_generic_type_args(annotation))
    return False

  result = {}
  for api, api_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for signatures in api_signatures.values():
      matching = [sig for sig in signatures if mentions_cls(sig)]
      if matching:
        result.setdefault(api, []).extend(matching)
  return result
419
420
# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg, or if `api_signature` does not
  expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap.  Signature must match `api_signature` (except
      the "name" parameter may be missing).
    api_signature: The signature of the original API (used to find the index for
      the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
  """
  if "name" not in api_signature.parameters:
    return func  # No wrapping needed (API has no name parameter).

  func_signature = tf_inspect.signature(func)
  func_argspec = tf_inspect.getargspec(func)
  if "name" in func_signature.parameters or func_argspec.keywords is not None:
    return func  # No wrapping needed (already accepts a name parameter).

  # Position at which the API expects `name` to be passed positionally.
  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    # Extract `name` (positionally or by keyword) before calling `func`.
    if len(args) > name_index:
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)
    if name is None:
      return func(*args, **kwargs)
    with ops.name_scope(name):
      return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Advertise the extended signature (original parameters plus `name`).
  wrapped_func.__signature__ = func_signature.replace(
      parameters=(list(func_signature.parameters.values()) +
                  [api_signature.parameters["name"]]))
  del wrapped_func._tf_decorator
  return wrapped_func
467
468
@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @tf.experimental.dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using `@dispatch_for`,
      `@dispatch_for_unary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_apis`.
  """
  found = False

  # Case 1: dispatch_target was registered by `@dispatch_for_api`.
  for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in signatures:
      getattr(api, TYPE_BASED_DISPATCH_ATTR).Unregister(dispatch_target)
      del signatures[dispatch_target]
      found = True

  # Case 2: dispatch_target was registered by
  # `@dispatch_for_*_elementwise_apis`; unregister each per-API target.
  keys_to_delete = {
      key for key, handler in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  }
  for key in keys_to_delete:
    for _, target in _ELEMENTWISE_API_TARGETS[key]:
      unregister_dispatch_for(target)
    del _ELEMENTWISE_API_HANDLERS[key]
    del _ELEMENTWISE_API_TARGETS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")
526
527
def register_dispatchable_type(cls):
  """Class decorator that registers a type for use with type-based dispatch.

  Should *not* be used with subclasses of `CompositeTensor` or `ExtensionType`
  (which are automatically registered).

  Note: this function is intended to support internal legacy use cases (such
  as RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`.
  """
  # Registration is delegated to the pywrap (native) dispatcher module.
  _api_dispatcher.register_dispatchable_type(cls)
  return cls
545
546
def add_type_based_api_dispatcher(target):
  """Adds a PythonAPIDispatcher to the given TensorFlow API function."""
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, unwrapped = tf_decorator.unwrap(target)
  argspec = tf_inspect.getargspec(unwrapped)
  if argspec.varargs or argspec.keywords:
    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    return target

  dispatcher = _api_dispatcher.PythonAPIDispatcher(
      unwrapped.__name__, argspec.args, argspec.defaults)
  setattr(target, TYPE_BASED_DISPATCH_ATTR, dispatcher)
  # Each dispatch target for this API accumulates a list of signatures.
  _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target
567
568
def _check_signature(api_signature, func):
  """Checks that a dispatch target's signature is compatible with an API.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: if the signatures are incompatible.  Two signatures are
      considered compatible if they have the same number of parameters, and all
      corresponding parameters have the same `name` and `kind`.  (Parameters
      are not required to have the same default value or the same annotation.)
  """
  # Special case: a pure (*args, **kwargs) target is always accepted.
  func_argspec = tf_inspect.getargspec(func)
  if (func_argspec.varargs is not None and func_argspec.keywords is not None
      and not func_argspec.args):
    return

  func_signature = tf_inspect.signature(func)
  api_params = list(api_signature.parameters.values())
  func_params = list(func_signature.parameters.values())
  compatible = len(api_params) == len(func_params) and all(
      p1.name == p2.name and p1.kind == p2.kind
      for p1, p2 in zip(api_params, func_params))
  if not compatible:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")
598
599
def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names to type annotations.

  Returns:
    A `PySignatureChecker`.
  """
  valid = isinstance(signature, dict) and all(
      isinstance(key, (str, int)) for key in signature)
  if not valid:
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")

  param_names = list(api_signature.parameters)
  checkers = []
  for param_name, param_type in signature.items():
    # Convert positional (index-based) parameters to named parameters.
    if isinstance(param_name, int) and param_name < len(param_names):
      param_name = list(api_signature.parameters.values())[param_name].name

    # Check that the parameter exists, and has an appropriate kind.
    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
                          tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    checkers.append(
        (param_names.index(param_name), make_type_checker(param_type)))

  return _api_dispatcher.PySignatureChecker(checkers)
640
641
# Cache for InstanceTypeChecker objects (we only want to create one
# InstanceTypeChecker for each type, since each one uses an internal cache
# to avoid repeated calls back into Python's isinstance).
# Keys are either a single type or a tuple of types (for merged union checks).
_is_instance_checker_cache = {}
646
647
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation."""
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then merge them into a
    # single InstanceChecker (shared via the cache).
    simple_types = tuple(
        sorted((t for t in type_args if isinstance(t, type)), key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        _is_instance_checker_cache[simple_types] = (
            _api_dispatcher.MakeInstanceChecker(*simple_types))
      options = [_is_instance_checker_cache[simple_types]]
      options.extend(
          make_type_checker(t) for t in type_args if not isinstance(t, type))
      return _api_dispatcher.MakeUnionChecker(options)

    return _api_dispatcher.MakeUnionChecker(
        [make_type_checker(t) for t in type_args])

  if type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    return _api_dispatcher.MakeListChecker(make_type_checker(type_args[0]))

  if isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      _is_instance_checker_cache[annotation] = (
          _api_dispatcher.MakeInstanceChecker(annotation))
    return _is_instance_checker_cache[annotation]

  if annotation is None:
    # Treat a bare `None` annotation as `type(None)`.
    return make_type_checker(type(None))

  raise ValueError(f"Type annotation {annotation} is not currently supported"
                   " by dispatch.  Supported annotations: type objects, "
                   " List[...], and Union[...]")
689
690
def _signature_from_annotations(func):
  """Builds a dict mapping from parameter names to type annotations."""
  params = tf_inspect.signature(func).parameters
  signature = {
      name: param.annotation
      for name, param in params.items()
      if param.annotation != tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature
703
704
# Registries for elementwise APIs and API handlers.
#
# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered
# as elementwise operations using the `register_*_elementwise_api`
# decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dicts mapping from argument type(s) to API
# handlers that have been registered with the `dispatch_for_*_elementwise_apis`
# decorators.
#
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
# `(api, dispatch_target)` pairs.  Used to implement
# `unregister_elementwise_api_handler`.
_UNARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_ASSERT_APIS = []
_ELEMENTWISE_API_HANDLERS = {}
_ELEMENTWISE_API_TARGETS = {}

# Marker used to distinguish assert-style binary elementwise registrations.
_ASSERT_API_TAG = "ASSERT_API_TAG"
725
726
@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any unary elementwise API whenever the value
  for the first argument (typically named `x`) matches the type annotation
  `x_type`. The elementwise api handler is called with two arguments:

    `elementwise_api_handler(api_func, x)`

  Where `api_func` is a function that takes a single parameter and performs the
  elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
  elementwise api.

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  For unary elementwise operations that take extra arguments beyond `x`, those
  arguments are *not* passed to the elementwise api handler, but are
  automatically added when `api_func` is called.  E.g., in the following
  example, the `dtype` parameter is not passed to
  `unary_elementwise_api_handler`, but is added by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # Handlers are keyed by a 1-tuple so unary and binary registrations
    # share the same registry without colliding.
    key = (x_type,)
    if key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A unary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[(x_type,)]}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[key] = handler
    # Register the handler against every currently-known unary API.
    for api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(api, x_type, handler)

    return handler

  return decorator
792
793
@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations.  The elementwise api handler is called with two arguments:

    `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3 6 9 12 5], mask=[ True True True False False]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
    y_type: A type annotation indicating when the api handler should be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # Handlers for binary APIs are keyed by the pair of argument types.
    registration_key = (x_type, y_type)
    if registration_key in _ELEMENTWISE_API_HANDLERS:
      existing = _ELEMENTWISE_API_HANDLERS[registration_key]
      raise ValueError("A binary elementwise dispatch handler "
                       f"({existing}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[registration_key] = handler
    # Retroactively attach the handler to every already-registered binary API.
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
    return handler

  return decorator
850
851
@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise assert APIs.

  The decorated function (known as the "elementwise assert handler")
  overrides the default implementation for any binary elementwise assert API
  whenever the value for the first two arguments (typically named `x` and `y`)
  match the specified type annotations.  The handler is called with two
  arguments:

    `elementwise_assert_handler(assert_func, x, y)`

  Where `x` and `y` are the first two arguments to the binary elementwise assert
  operation, and `assert_func` is a TensorFlow function that takes two
  parameters and performs the elementwise assert operation (e.g.,
  `tf.debugging.assert_equal`).

  The following example shows how this decorator can be used to update all
  binary elementwise assert operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_assert_api_handler(assert_func, x, y):
  ...   merged_mask = tf.logical_and(x.mask, y.mask)
  ...   selected_x_values = tf.boolean_mask(x.values, merged_mask)
  ...   selected_y_values = tf.boolean_mask(y.values, merged_mask)
  ...   assert_func(selected_x_values, selected_y_values)
  >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
  >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
  >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown

  >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
  >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
  >>> tf.debugging.assert_greater(a, b)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Condition x > y did not hold.

  Args:
    x_type: A type annotation indicating when the api handler should be called.
    y_type: A type annotation indicating when the api handler should be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise assert APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # Assert handlers share the handler registry with ordinary binary
    # handlers, so a sentinel tag is appended to keep the keys distinct.
    registration_key = (x_type, y_type, _ASSERT_API_TAG)
    if registration_key in _ELEMENTWISE_API_HANDLERS:
      existing = _ELEMENTWISE_API_HANDLERS[registration_key]
      raise ValueError("A binary elementwise assert dispatch handler "
                       f"({existing}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[registration_key] = handler
    # Retroactively attach the handler to every registered assert API.
    for api in _BINARY_ELEMENTWISE_ASSERT_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
    return handler

  return decorator
919
920
def register_unary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a unary elementwise API."""
  _UNARY_ELEMENTWISE_APIS.append(func)
  # Attach any unary handlers that were registered before this API was added.
  # Unary handlers are keyed by a 1-tuple of the argument type.
  unary_handlers = [(key[0], handler)
                    for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                    if len(key) == 1]
  for x_type, handler in unary_handlers:
    _add_dispatch_for_unary_elementwise_api(func, x_type, handler)
  return func
928
929
def register_binary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise API."""
  _BINARY_ELEMENTWISE_APIS.append(func)
  # Attach any binary handlers that were registered before this API was added.
  # Binary (non-assert) handlers are keyed by a pair of argument types.
  for key, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(key) != 2:
      continue
    x_type, y_type = key
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func
937
938
def register_binary_elementwise_assert_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise assert API.

  Different from `dispatch_for_binary_elementwise_apis`, this decorator is used
  for assert apis, such as assert_equal, assert_none_equal, etc, which return
  None in eager mode and an op in graph mode.

  Args:
    func: The function that implements the binary elementwise assert API.

  Returns:
    `func`
  """
  _BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
  # Attach any assert handlers that were registered before this API was added.
  # Assert handlers are keyed by (x_type, y_type, _ASSERT_API_TAG).
  for key, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(key) == 3 and key[2] is _ASSERT_API_TAG:
      _add_dispatch_for_binary_elementwise_api(func, key[0], key[1], handler)
  return func
957
958
def unary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as unary elementwise.

  Returns:
    A tuple (an immutable snapshot of the registry, so callers cannot mutate
    it) of the APIs registered via `register_unary_elementwise_api`.
  """
  return tuple(_UNARY_ELEMENTWISE_APIS)
962
963
def binary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as binary elementwise.

  Returns:
    A tuple (an immutable snapshot of the registry, so callers cannot mutate
    it) of the APIs registered via `register_binary_elementwise_api`.
  """
  return tuple(_BINARY_ELEMENTWISE_APIS)
967
968
def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API.

  Args:
    api: The TensorFlow API (a callable) whose behavior is being overridden.
    x_type: Type annotation for the API's first argument; the dispatch target
      fires when that argument matches this type.
    elementwise_api_handler: Callable taking `(api_func, x)` that implements
      the overridden behavior.
  """
  api_signature = tf_inspect.signature(api)
  # By convention, the first parameter is the elementwise operand (`x`).
  x_name = list(api_signature.parameters)[0]
  name_index = _find_name_index(api_signature)

  # If the API takes parameters beyond `x` and `name`, those extra arguments
  # must be re-bound onto `api` before handing it to the handler (which only
  # receives a unary function plus `x`).
  need_to_bind_api_args = (
      len(api_signature.parameters) > 2 or
      "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    # Pull `name` out so it isn't forwarded to the handler; it is applied
    # separately via a name scope below.
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` may arrive positionally or by keyword.
    if args:
      x, args = args[0], args[1:]
    else:
      x = kwargs.pop(x_name)

    if need_to_bind_api_args:
      # Close over the remaining args so the handler sees a unary function.
      tensor_api = lambda v: api(v, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x)
    else:
      with ops.name_scope(name, None, [x]):
        return elementwise_api_handler(tensor_api, x)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), [])
  target_list.append((api, dispatch_target))
1004
1005
def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
                                             elementwise_api_handler):
  """Registers a binary elementwise handler as a dispatcher for a given API.

  Args:
    api: The TensorFlow API (a callable) whose behavior is being overridden.
    x_type: Type annotation for the API's first argument.
    y_type: Type annotation for the API's second argument.
    elementwise_api_handler: Callable taking `(api_func, x, y)` that implements
      the overridden behavior.
  """
  api_signature = tf_inspect.signature(api)
  # By convention, the first two parameters are the elementwise operands.
  x_name, y_name = list(api_signature.parameters)[:2]
  name_index = _find_name_index(api_signature)

  # If the API takes parameters beyond `x`, `y`, and `name`, those extra
  # arguments must be re-bound onto `api` before handing it to the handler.
  need_to_bind_api_args = (len(api_signature.parameters) > 3 or
                           "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type, y_name: y_type})
  def dispatch_target(*args, **kwargs):
    # Pull `name` out so it isn't forwarded to the handler; it is applied
    # separately via a name scope below.
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` and `y` may each arrive positionally or by keyword.
    if len(args) > 1:
      x, y, args = args[0], args[1], args[2:]
    elif args:
      x, args = args[0], args[1:]
      y = kwargs.pop(y_name, None)
    else:
      x = kwargs.pop(x_name, None)
      y = kwargs.pop(y_name, None)

    if need_to_bind_api_args:
      # Close over the remaining args so the handler sees a binary function.
      tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x, y)
    else:
      with ops.name_scope(name, None, [x, y]):
        return elementwise_api_handler(tensor_api, x, y)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), [])
  target_list.append((api, dispatch_target))
1044
1045
1046def _find_name_index(signature):
1047  """Returns the index of the `name` parameter, or -1 if it's not present."""
1048  try:
1049    return list(signature.parameters).index("name")
1050  except ValueError:
1051    return -1
1052
1053
1054def _extract_name_arg(args, kwargs, name_index):
1055  """Extracts the parameter `name` and returns `(args, kwargs, name_value)`."""
1056  if name_index < 0:
1057    name_value = None
1058  elif name_index < len(args):
1059    name_value = args[name_index]
1060    args = args[:name_index] + args[name_index + 1:]
1061  else:
1062    name_value = kwargs.pop("name", None)
1063  return args, kwargs, name_value
1064
1065
def update_docstrings_with_api_lists():
  """Updates the docstrings of dispatch decorators with API lists.

  Updates docstrings for `dispatch_for_api`,
  `dispatch_for_unary_elementwise_apis`, and
  `dispatch_for_binary_elementwise_apis`, by replacing the string '<<API_LIST>>'
  with a list of APIs that have been registered for that decorator.
  """
  # Each decorator is paired with the registry that feeds its API list.
  decorator_registries = (
      (dispatch_for_unary_elementwise_apis, _UNARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_apis, _BINARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_assert_apis,
       _BINARY_ELEMENTWISE_ASSERT_APIS),
      (dispatch_for_api, _TYPE_BASED_DISPATCH_SIGNATURES),
  )
  for decorator, registry in decorator_registries:
    _update_docstring_with_api_list(decorator, registry)
1082
1083
def _update_docstring_with_api_list(target, api_list):
  """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs."""
  entries = []
  for func in api_list:
    canonical_name = tf_export_lib.get_canonical_name_for_symbol(
        func, add_prefix_to_v1_names=True)
    if canonical_name is None:
      continue  # Symbol was not exported; omit it from the list.
    param_names = tf_inspect.signature(func).parameters.keys()
    entries.append(f"  * `tf.{canonical_name}({', '.join(param_names)})`")
  target.__doc__ = target.__doc__.replace("  <<API_LIST>>",
                                          "\n".join(sorted(entries)))
1095
1096
1097################################################################################
1098# Dispatch Support
1099################################################################################
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target=None, iterable_parameters=None):
  """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.

  This wrapper adds the decorated function as an API that can be overridden
  using the `@dispatch_for_api` decorator.  In the following example, we first
  define a new API (`double`) that supports dispatch, then define a custom type
  (`MaskedTensor`) and finally use `dispatch_for_api` to override the default
  implementation of `double` when called with `MaskedTensor` values:

  >>> @add_dispatch_support
  ... def double(x):
  ...   return x * 2
  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_api(double, {'x': MaskedTensor})
  ... def masked_double(x):
  ...   return MaskedTensor(x.values * 2, x.mask)

  The optional `iterable_parameters` argument can be used to mark parameters
  that can take arbitrary iterable values (such as generator expressions).
  These need to be handled specially during dispatch, since just iterating over
  an iterable uses up its values.  In the following example, we define a new
  API whose second argument can be an iterable value; and then override the
  default implementation of that API when the iterable contains MaskedTensors:

  >>> @add_dispatch_support(iterable_parameters=['ys'])
  ... def add_tensor_to_list_of_tensors(x, ys):
  ...   return [x + y for y in ys]
  >>> @dispatch_for_api(add_tensor_to_list_of_tensors,
  ...               {'ys': typing.List[MaskedTensor]})
  ... def masked_add_tensor_to_list_of_tensors(x, ys):
  ...   return [MaskedTensor(x+y.values, y.mask) for y in ys]

  (Note: the only TensorFlow API that currently supports iterables is `add_n`.)

  Args:
    target: The TensorFlow API that should support dispatch.
    iterable_parameters: Optional list of parameter names that may be called
      with iterables (such as the `inputs` parameter for `tf.add_n`).

  Returns:
    A decorator.
  """

  if not (iterable_parameters is None or
          (isinstance(iterable_parameters, (list, tuple)) and
           all(isinstance(p, str) for p in iterable_parameters))):
    raise TypeError("iterable_parameters should be a list or tuple of string.")

  def decorator(dispatch_target):

    # Get the name & index for each iterable parameter.
    if iterable_parameters is None:
      iterable_params = None
    else:
      arg_names = tf_inspect.getargspec(dispatch_target).args
      iterable_params = [
          (name, arg_names.index(name)) for name in iterable_parameters
      ]

    @traceback_utils.filter_traceback
    def op_dispatch_handler(*args, **kwargs):
      """Call `dispatch_target`, performing dispatch when appropriate."""

      # Type-based dispatch system (dispatch v2):
      # NOTE: `api_dispatcher` is read via closure from the enclosing scope;
      # it is assigned below, after this function is defined, but before any
      # call can reach this handler.
      if api_dispatcher is not None:
        if iterable_params is not None:
          # Materialize iterable arguments as lists so that inspecting them
          # for dispatch doesn't consume their values.
          args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
        result = api_dispatcher.Dispatch(args, kwargs)
        if result is not NotImplemented:
          return result

      # Fallback dispatch system (dispatch v1):
      try:
        return dispatch_target(*args, **kwargs)
      except (TypeError, ValueError):
        # Note: convert_to_eager_tensor currently raises a ValueError, not a
        # TypeError, when given unexpected types.  So we need to catch both.
        result = dispatch(op_dispatch_handler, args, kwargs)
        if result is not OpDispatcher.NOT_SUPPORTED:
          return result
        else:
          raise

    add_fallback_dispatch_list(op_dispatch_handler)
    op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
                                                      op_dispatch_handler)
    add_type_based_api_dispatcher(op_dispatch_handler)
    # Bind the dispatcher created by `add_type_based_api_dispatcher` (if any);
    # `op_dispatch_handler` reads this via closure on every call.
    api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                             None)
    return op_dispatch_handler

  # Support both `@add_dispatch_support` (bare) and
  # `@add_dispatch_support(iterable_parameters=...)` (parameterized) usage.
  if target is None:
    return decorator
  else:
    return decorator(target)
1198
1199
def replace_iterable_params(args, kwargs, iterable_params):
  """Returns (args, kwargs) with any iterable parameters converted to lists.

  Args:
    args: Positional arguments to a function.
    kwargs: Keyword arguments to a function.
    iterable_params: A list of (name, index) tuples for iterable parameters.

  Returns:
    A tuple (args, kwargs), where any positional or keyword parameters in
    `iterable_params` have their value converted to a `list`.
  """
  updated_args = list(args)
  for param_name, param_index in iterable_params:
    if param_index < len(updated_args):
      # Parameter was passed positionally: materialize it in place.
      updated_args[param_index] = list(updated_args[param_index])
    elif param_name in kwargs:
      # Parameter was passed by keyword: materialize it in the kwargs dict.
      kwargs[param_name] = list(kwargs[param_name])
  return tuple(updated_args), kwargs
1219