# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type-based dispatch for TensorFlow's Python APIs.

"Python APIs" refers to Python functions that have been exported with
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
referred to as "ops".

There are currently two dispatch systems for TensorFlow:

  * The "fallback dispatch" system calls an API's standard implementation first,
    and only tries to perform dispatch if that standard implementation raises a
    TypeError (or ValueError) exception.

  * The "type-based dispatch" system checks the types of the parameters passed
    to an API, and performs dispatch if those types match any signatures that
    have been registered for dispatch.

The fallback dispatch system was the original dispatch system, but it was
somewhat brittle and had limitations, such as an inability to support dispatch
for some operations (like convert_to_tensor).  We plan to remove the fallback
dispatch system in favor of the type-based dispatch system, once all users have
been switched over to use it.

### Fallback Dispatch

The fallback dispatch system is based on "operation dispatchers", which can be
used to override the behavior for TensorFlow ops when they are called with
otherwise unsupported argument types.  In particular, when an operation is
called with arguments that would cause it to raise a TypeError, it falls back on
its registered operation dispatchers.  If any registered dispatcher can handle
the arguments, then its result is returned.  Otherwise, the original TypeError
is raised.

### Type-based Dispatch

The main interface for the type-based dispatch system is the `dispatch_for_api`
decorator, which overrides the default implementation for a TensorFlow API.
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature.

### Dispatch Support

By default, dispatch support is added to the generated op wrappers for any
visible ops.  APIs/ops that are implemented in Python can opt in to dispatch
support using the `add_dispatch_support` decorator.
"""

import collections
import itertools
import typing  # pylint: disable=unused-import (used in doctests)

from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
from tensorflow.python.framework import ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export as tf_export_lib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import traceback_utils
from tensorflow.python.util import type_annotations
from tensorflow.python.util.tf_export import tf_export


# Private function attributes used to store dispatchers on TensorFlow APIs.
# `FALLBACK_DISPATCH_ATTR` holds a list of `OpDispatcher`s (see
# `add_fallback_dispatch_list`); `TYPE_BASED_DISPATCH_ATTR` holds a
# `PythonAPIDispatcher` (see `add_type_based_api_dispatcher`).
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []


################################################################################
# Fallback Dispatch
################################################################################


@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a fallback dispatch list.
    """
    if not hasattr(op, FALLBACK_DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, FALLBACK_DISPATCH_ATTR).append(self)


@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers."""

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Subclasses override this.  The base implementation declines every call
    (mirroring `OpDispatcher.handle`), so that `dispatch` does not mistake an
    implicit `None` for a successful result.

    Args:
      op: Python function: the operation being dispatched.
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `NOT_SUPPORTED` if this dispatcher can
      not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)


def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op`, and returns the value from the first successful handler.
  If none of them handles the call, each registered `GlobalOpDispatcher` is
  tried, in registration order.

  Args:
    op: Python function: the operation to dispatch for.
    args: The arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  for dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    result = dispatcher.handle(args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  for dispatcher in _GLOBAL_DISPATCHERS:
    result = dispatcher.handle(op, args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  return OpDispatcher.NOT_SUPPORTED


class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    """Creates a dispatcher.

    Args:
      override_func: Called with the op's original `*args, **kwargs` when the
        dispatcher handles a call.
      types: A type or tuple of types, as accepted by `isinstance`.
    """
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any (top-level or list/tuple-nested) arg matches."""
    for arg in itertools.chain(args, kwargs.values()):
      if (isinstance(arg, self._types) or
          (isinstance(arg, (list, tuple)) and
           any(isinstance(elt, self._types) for elt in arg))):
        return True
    return False

  def handle(self, args, kwargs):
    if self._handles(args, kwargs):
      return self._override_func(*args, **kwargs)
    return self.NOT_SUPPORTED


# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: (When decorating) if the decorated function's argspec
      differs from `op`'s.
  """

  def decorator(func):
    if tf_inspect.getargspec(func) != tf_inspect.getargspec(op):
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield


def add_fallback_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Args:
    target: The op to decorate.

  Returns:
    `target`, with an empty fallback dispatch list attached.

  Raises:
    AssertionError: If `target` already has a fallback dispatch list.
  """
  if hasattr(target, FALLBACK_DISPATCH_ATTR):
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, FALLBACK_DISPATCH_ATTR, [])
  return target


# Alias for backwards-compatibility.
add_dispatch_list = add_fallback_dispatch_list


################################################################################
# Type-based Dispatch
################################################################################


@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature.  Signatures are specified using dictionaries
  that map parameter names to type annotations.  E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match.  For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y': MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values.   For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value.  This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden.  In particular, parameters must have the same names, and
  must occur in the same order.  The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called.  In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations.  If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  Raises:
    ValueError: If `api` does not support type-based dispatch, or if a
      signature or the dispatch target's signature is incompatible with `api`.
    TypeError: If a signature is not a dict keyed by names/indices, or (when
      decorating) if the dispatch target is not callable.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  api_signature = tf_inspect.signature(api)
  # Build checkers eagerly so malformed signatures are reported at decoration
  # time rather than on first call.
  signature_checkers = [
      _make_signature_checker(api_signature, signature)
      for signature in signatures
  ]

  def decorator(dispatch_target):
    """Decorator that registers the given dispatch target."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    for signature_checker in signature_checkers:
      dispatcher.Register(signature_checker, dispatch_target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)

    if not signature_checkers:
      # No explicit signatures: derive one from the target's annotations.
      signature = _signature_from_annotations(dispatch_target)
      checker = _make_signature_checker(api_signature, signature)
      dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)

    return dispatch_target

  return decorator


# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
# which can be used for documentation generation and for improved error messages
# when APIs are called with unsupported types.
371_TYPE_BASED_DISPATCH_SIGNATURES = {} 372 373 374def apis_with_type_based_dispatch(): 375 """Returns a list of TensorFlow APIs that support type-based dispatch.""" 376 return sorted( 377 _TYPE_BASED_DISPATCH_SIGNATURES, 378 key=lambda api: f"{api.__module__}.{api.__name__}") 379 380 381def type_based_dispatch_signatures_for(cls): 382 """Returns dispatch signatures that have been registered for a given class. 383 384 This function is intended for documentation-generation purposes. 385 386 Args: 387 cls: The class to search for. Type signatures are searched recursively, so 388 e.g., if `cls=RaggedTensor`, then information will be returned for all 389 dispatch targets that have `RaggedTensor` anywhere in their type 390 annotations (including nested in `typing.Union` or `typing.List`.) 391 392 Returns: 393 A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API 394 function; and `signatures` is a list of dispatch signatures for `api` 395 that include `cls`. (Each signature is a dict mapping argument names to 396 type annotations; see `dispatch_for_api` for more info.) 
397 """ 398 399 def contains_cls(x): 400 """Returns true if `x` contains `cls`.""" 401 if isinstance(x, dict): 402 return any(contains_cls(v) for v in x.values()) 403 elif x is cls: 404 return True 405 elif (type_annotations.is_generic_list(x) or 406 type_annotations.is_generic_union(x)): 407 type_args = type_annotations.get_generic_type_args(x) 408 return any(contains_cls(arg) for arg in type_args) 409 else: 410 return False 411 412 result = {} 413 for api, api_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items(): 414 for _, signatures in api_signatures.items(): 415 filtered = list(filter(contains_cls, signatures)) 416 if filtered: 417 result.setdefault(api, []).extend(filtered) 418 return result 419 420 421# TODO(edloper): Consider using a mechanism like this to automatically add 422# the `name` argument to all TensorFlow APIs that are implemented in Python 423# (so each Python function doesn't need to do it manually). 424def _add_name_scope_wrapper(func, api_signature): 425 """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`. 426 427 If `func` already expects a "name" arg, or if `api_signature` does not 428 expect a "name" arg, then returns `func` as-is. 429 430 Args: 431 func: The function to wrap. Signature must match `api_signature` (except 432 the "name" parameter may be missing. 433 api_signature: The signature of the original API (used to find the index for 434 the "name" parameter). 435 436 Returns: 437 The wrapped function (or the original function if no wrapping is needed). 438 """ 439 if "name" not in api_signature.parameters: 440 return func # no wrapping needed (API has no name parameter). 441 442 func_signature = tf_inspect.signature(func) 443 func_argspec = tf_inspect.getargspec(func) 444 if "name" in func_signature.parameters or func_argspec.keywords is not None: 445 return func # No wrapping needed (already has name parameter). 
446 447 name_index = list(api_signature.parameters).index("name") 448 449 def wrapped_func(*args, **kwargs): 450 if name_index < len(args): 451 name = args[name_index] 452 args = args[:name_index] + args[name_index + 1:] 453 else: 454 name = kwargs.pop("name", None) 455 if name is None: 456 return func(*args, **kwargs) 457 else: 458 with ops.name_scope(name): 459 return func(*args, **kwargs) 460 461 wrapped_func = tf_decorator.make_decorator(func, wrapped_func) 462 wrapped_func.__signature__ = func_signature.replace( 463 parameters=(list(func_signature.parameters.values()) + 464 [api_signature.parameters["name"]])) 465 del wrapped_func._tf_decorator 466 return wrapped_func 467 468 469@tf_export("experimental.unregister_dispatch_for") 470def unregister_dispatch_for(dispatch_target): 471 """Unregisters a function that was registered with `@dispatch_for_*`. 472 473 This is primarily intended for testing purposes. 474 475 Example: 476 477 >>> # Define a type and register a dispatcher to override `tf.abs`: 478 >>> class MyTensor(tf.experimental.ExtensionType): 479 ... value: tf.Tensor 480 >>> @tf.experimental.dispatch_for_api(tf.abs) 481 ... def my_abs(x: MyTensor): 482 ... return MyTensor(tf.abs(x.value)) 483 >>> tf.abs(MyTensor(5)) 484 MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>) 485 486 >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`. 487 >>> unregister_dispatch_for(my_abs) 488 >>> tf.abs(MyTensor(5)) 489 Traceback (most recent call last): 490 ... 491 ValueError: Attempt to convert a value ... to a Tensor. 492 493 Args: 494 dispatch_target: The function to unregister. 495 496 Raises: 497 ValueError: If `dispatch_target` was not registered using `@dispatch_for`, 498 `@dispatch_for_unary_elementwise_apis`, or 499 `@dispatch_for_binary_elementwise_apis`. 
500 """ 501 found = False 502 503 # Check if dispatch_target registered by `@dispatch_for_api` 504 for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items(): 505 if dispatch_target in signatures: 506 dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR) 507 dispatcher.Unregister(dispatch_target) 508 del signatures[dispatch_target] 509 found = True 510 511 # Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis` 512 elementwise_keys_to_delete = [ 513 key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items() 514 if handler is dispatch_target 515 ] 516 for key in set(elementwise_keys_to_delete): 517 for _, target in _ELEMENTWISE_API_TARGETS[key]: 518 unregister_dispatch_for(target) 519 del _ELEMENTWISE_API_HANDLERS[key] 520 del _ELEMENTWISE_API_TARGETS[key] 521 found = True 522 523 if not found: 524 raise ValueError(f"Function {dispatch_target} was not registered using " 525 "a `@dispatch_for_*` decorator.") 526 527 528def register_dispatchable_type(cls): 529 """Class decorator that registers a type for use with type-based dispatch. 530 531 Should *not* be used with subclasses of `CompositeTensor` or `ExtensionType` 532 (which are automatically registered). 533 534 Note: this function is intended to support internal legacy use cases (such 535 as RaggedTensorValue), and will probably not be exposed as a public API. 536 537 Args: 538 cls: The class to register. 539 540 Returns: 541 `cls`. 
542 """ 543 _api_dispatcher.register_dispatchable_type(cls) 544 return cls 545 546 547def add_type_based_api_dispatcher(target): 548 """Adds a PythonAPIDispatcher to the given TensorFlow API function.""" 549 if hasattr(target, TYPE_BASED_DISPATCH_ATTR): 550 raise ValueError(f"{target} already has a type-based API dispatcher.") 551 552 _, unwrapped = tf_decorator.unwrap(target) 553 target_argspec = tf_inspect.getargspec(unwrapped) 554 if target_argspec.varargs or target_argspec.keywords: 555 # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs 556 # and keywords. Examples of APIs that take varargs and kwargs: meshgrid, 557 # einsum, map_values, map_flat_values. 558 return target 559 560 setattr( 561 target, TYPE_BASED_DISPATCH_ATTR, 562 _api_dispatcher.PythonAPIDispatcher(unwrapped.__name__, 563 target_argspec.args, 564 target_argspec.defaults)) 565 _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list) 566 return target 567 568 569def _check_signature(api_signature, func): 570 """Checks that a dispatch target's signature is compatible with an API. 571 572 Args: 573 api_signature: The signature of the TensorFlow API. 574 func: The dispatch target. 575 576 Raises: 577 ValueError: if the signatures are incompatible. Two signatures are 578 considered compatible if they have the same number of parameters, and all 579 corresponding parameters have the same `name` and `kind`. (Parameters 580 are not required to have the same default value or the same annotation.) 581 """ 582 # Special case: if func_signature is (*args, **kwargs), then assume it's ok. 
583 func_argspec = tf_inspect.getargspec(func) 584 if (func_argspec.varargs is not None and func_argspec.keywords is not None 585 and not func_argspec.args): 586 return 587 588 func_signature = tf_inspect.signature(func) 589 ok = len(api_signature.parameters) == len(func_signature.parameters) 590 if ok: 591 for param_1, param_2 in zip(api_signature.parameters.values(), 592 func_signature.parameters.values()): 593 if (param_1.name != param_2.name) or (param_1.kind != param_2.kind): 594 ok = False 595 if not ok: 596 raise ValueError(f"Dispatch function's signature {func_signature} does " 597 f"not match API's signature {api_signature}.") 598 599 600def _make_signature_checker(api_signature, signature): 601 """Builds a PySignatureChecker for the given type signature. 602 603 Args: 604 api_signature: The `inspect.Signature` of the API whose signature is 605 being checked. 606 signature: Dictionary mapping parameter names to type annotations. 607 608 Returns: 609 A `PySignatureChecker`. 610 """ 611 if not (isinstance(signature, dict) and 612 all(isinstance(k, (str, int)) for k in signature)): 613 raise TypeError("signatures must be dictionaries mapping parameter names " 614 "to type annotations.") 615 checkers = [] 616 617 param_names = list(api_signature.parameters) 618 for param_name, param_type in signature.items(): 619 # Convert positional parameters to named parameters. 620 if (isinstance(param_name, int) and 621 param_name < len(api_signature.parameters)): 622 param_name = list(api_signature.parameters.values())[param_name].name 623 624 # Check that the parameter exists, and has an appropriate kind. 
625 param = api_signature.parameters.get(param_name, None) 626 if param is None: 627 raise ValueError("signature includes annotation for unknown " 628 f"parameter {param_name!r}.") 629 if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY, 630 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD): 631 raise ValueError("Dispatch currently only supports type annotations " 632 "for positional parameters; can't handle annotation " 633 f"for {param.kind!r} parameter {param_name}.") 634 635 checker = make_type_checker(param_type) 636 index = param_names.index(param_name) 637 checkers.append((index, checker)) 638 639 return _api_dispatcher.PySignatureChecker(checkers) 640 641 642# Cache for InstanceTypeChecker objects (we only want to create one 643# InstanceTypeChecker for each type, since each one uses an internal cache 644# to avoid repeated calls back into Python's isinstance). 645_is_instance_checker_cache = {} 646 647 648def make_type_checker(annotation): 649 """Builds a PyTypeChecker for the given type annotation.""" 650 if type_annotations.is_generic_union(annotation): 651 type_args = type_annotations.get_generic_type_args(annotation) 652 653 # If the union contains two or more simple types, then use a single 654 # InstanceChecker to check them. 
655 simple_types = [t for t in type_args if isinstance(t, type)] 656 simple_types = tuple(sorted(simple_types, key=id)) 657 if len(simple_types) > 1: 658 if simple_types not in _is_instance_checker_cache: 659 checker = _api_dispatcher.MakeInstanceChecker(*simple_types) 660 _is_instance_checker_cache[simple_types] = checker 661 options = ([_is_instance_checker_cache[simple_types]] + 662 [make_type_checker(t) for t in type_args 663 if not isinstance(t, type)]) 664 return _api_dispatcher.MakeUnionChecker(options) 665 666 options = [make_type_checker(t) for t in type_args] 667 return _api_dispatcher.MakeUnionChecker(options) 668 669 elif type_annotations.is_generic_list(annotation): 670 type_args = type_annotations.get_generic_type_args(annotation) 671 if len(type_args) != 1: 672 raise AssertionError("Expected List[...] to have a single type parameter") 673 elt_type = make_type_checker(type_args[0]) 674 return _api_dispatcher.MakeListChecker(elt_type) 675 676 elif isinstance(annotation, type): 677 if annotation not in _is_instance_checker_cache: 678 checker = _api_dispatcher.MakeInstanceChecker(annotation) 679 _is_instance_checker_cache[annotation] = checker 680 return _is_instance_checker_cache[annotation] 681 682 elif annotation is None: 683 return make_type_checker(type(None)) 684 685 else: 686 raise ValueError(f"Type annotation {annotation} is not currently supported" 687 " by dispatch. 
Supported annotations: type objects, " 688 " List[...], and Union[...]") 689 690 691def _signature_from_annotations(func): 692 """Builds a dict mapping from parameter names to type annotations.""" 693 func_signature = tf_inspect.signature(func) 694 695 signature = dict([(name, param.annotation) 696 for (name, param) in func_signature.parameters.items() 697 if param.annotation != tf_inspect.Parameter.empty]) 698 if not signature: 699 raise ValueError("The dispatch_for_api decorator must be called with at " 700 "least one signature, or applied to a function that " 701 "has type annotations on its parameters.") 702 return signature 703 704 705# Registries for elementwise APIs and API handlers. 706# 707# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered 708# as elementwise operations using the `register_*_elementwise_api` 709# decorators. 710# 711# _ELEMENTWISE_API_HANDLERS: Dicts mapping from argument type(s) to API 712# handlers that have been registered with the `dispatch_for_*_elementwise_apis` 713# decorators. 714# 715# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of 716# `(api, dispatch_target)` pairs. Used to impelement 717# `unregister_elementwise_api_handler`. 718_UNARY_ELEMENTWISE_APIS = [] 719_BINARY_ELEMENTWISE_APIS = [] 720_BINARY_ELEMENTWISE_ASSERT_APIS = [] 721_ELEMENTWISE_API_HANDLERS = {} 722_ELEMENTWISE_API_TARGETS = {} 723 724_ASSERT_API_TAG = "ASSERT_API_TAG" 725 726 727@tf_export("experimental.dispatch_for_unary_elementwise_apis") 728def dispatch_for_unary_elementwise_apis(x_type): 729 """Decorator to override default implementation for unary elementwise APIs. 730 731 The decorated function (known as the "elementwise api handler") overrides 732 the default implementation for any unary elementwise API whenever the value 733 for the first argument (typically named `x`) matches the type annotation 734 `x_type`. 
The elementwise api handler is called with two arguments: 735 736 `elementwise_api_handler(api_func, x)` 737 738 Where `api_func` is a function that takes a single parameter and performs the 739 elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the 740 elementwise api. 741 742 The following example shows how this decorator can be used to update all 743 unary elementwise operations to handle a `MaskedTensor` type: 744 745 >>> class MaskedTensor(tf.experimental.ExtensionType): 746 ... values: tf.Tensor 747 ... mask: tf.Tensor 748 >>> @dispatch_for_unary_elementwise_apis(MaskedTensor) 749 ... def unary_elementwise_api_handler(api_func, x): 750 ... return MaskedTensor(api_func(x.values), x.mask) 751 >>> mt = MaskedTensor([1, -2, -3], [True, False, True]) 752 >>> abs_mt = tf.abs(mt) 753 >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}") 754 values=[1 2 3], mask=[ True False True] 755 756 For unary elementwise operations that take extra arguments beyond `x`, those 757 arguments are *not* passed to the elementwise api handler, but are 758 automatically added when `api_func` is called. E.g., in the following 759 example, the `dtype` parameter is not passed to 760 `unary_elementwise_api_handler`, but is added by `api_func`. 761 762 >>> ones_mt = tf.ones_like(mt, dtype=tf.float32) 763 >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}") 764 values=[1.0 1.0 1.0], mask=[ True False True] 765 766 Args: 767 x_type: A type annotation indicating when the api handler should be called. 768 See `dispatch_for_api` for a list of supported annotation types. 769 770 Returns: 771 A decorator. 
772 773 #### Registered APIs 774 775 The unary elementwise APIs are: 776 777 <<API_LIST>> 778 """ 779 780 def decorator(handler): 781 if (x_type,) in _ELEMENTWISE_API_HANDLERS: 782 raise ValueError("A unary elementwise dispatch handler " 783 f"({_ELEMENTWISE_API_HANDLERS[(x_type,)]}) " 784 f"has already been registered for {x_type}.") 785 _ELEMENTWISE_API_HANDLERS[(x_type,)] = handler 786 for api in _UNARY_ELEMENTWISE_APIS: 787 _add_dispatch_for_unary_elementwise_api(api, x_type, handler) 788 789 return handler 790 791 return decorator 792 793 794@tf_export("experimental.dispatch_for_binary_elementwise_apis") 795def dispatch_for_binary_elementwise_apis(x_type, y_type): 796 """Decorator to override default implementation for binary elementwise APIs. 797 798 The decorated function (known as the "elementwise api handler") overrides 799 the default implementation for any binary elementwise API whenever the value 800 for the first two arguments (typically named `x` and `y`) match the specified 801 type annotations. The elementwise api handler is called with two arguments: 802 803 `elementwise_api_handler(api_func, x, y)` 804 805 Where `x` and `y` are the first two arguments to the elementwise api, and 806 `api_func` is a TensorFlow function that takes two parameters and performs the 807 elementwise operation (e.g., `tf.add`). 808 809 The following example shows how this decorator can be used to update all 810 binary elementwise operations to handle a `MaskedTensor` type: 811 812 >>> class MaskedTensor(tf.experimental.ExtensionType): 813 ... values: tf.Tensor 814 ... mask: tf.Tensor 815 >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor) 816 ... def binary_elementwise_api_handler(api_func, x, y): 817 ... 
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask) 818 >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False]) 819 >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True]) 820 >>> c = tf.add(a, b) 821 >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}") 822 values=[ 3 6 9 12 5], mask=[ True True True False False] 823 824 Args: 825 x_type: A type annotation indicating when the api handler should be called. 826 y_type: A type annotation indicating when the api handler should be called. 827 828 Returns: 829 A decorator. 830 831 #### Registered APIs 832 833 The binary elementwise APIs are: 834 835 <<API_LIST>> 836 """ 837 838 def decorator(handler): 839 if (x_type, y_type) in _ELEMENTWISE_API_HANDLERS: 840 raise ValueError("A binary elementwise dispatch handler " 841 f"({_ELEMENTWISE_API_HANDLERS[x_type, y_type]}) " 842 f"has already been registered for ({x_type}, {y_type}).") 843 _ELEMENTWISE_API_HANDLERS[x_type, y_type] = handler 844 for api in _BINARY_ELEMENTWISE_APIS: 845 _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler) 846 847 return handler 848 849 return decorator 850 851 852@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis") 853def dispatch_for_binary_elementwise_assert_apis(x_type, y_type): 854 """Decorator to override default implementation for binary elementwise assert APIs. 855 856 The decorated function (known as the "elementwise assert handler") 857 overrides the default implementation for any binary elementwise assert API 858 whenever the value for the first two arguments (typically named `x` and `y`) 859 match the specified type annotations. 
The handler is called with two 860 arguments: 861 862 `elementwise_assert_handler(assert_func, x, y)` 863 864 Where `x` and `y` are the first two arguments to the binary elementwise assert 865 operation, and `assert_func` is a TensorFlow function that takes two 866 parameters and performs the elementwise assert operation (e.g., 867 `tf.debugging.assert_equal`). 868 869 The following example shows how this decorator can be used to update all 870 binary elementwise assert operations to handle a `MaskedTensor` type: 871 872 >>> class MaskedTensor(tf.experimental.ExtensionType): 873 ... values: tf.Tensor 874 ... mask: tf.Tensor 875 >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor) 876 ... def binary_elementwise_assert_api_handler(assert_func, x, y): 877 ... merged_mask = tf.logical_and(x.mask, y.mask) 878 ... selected_x_values = tf.boolean_mask(x.values, merged_mask) 879 ... selected_y_values = tf.boolean_mask(y.values, merged_mask) 880 ... assert_func(selected_x_values, selected_y_values) 881 >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True]) 882 >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False]) 883 >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown 884 885 >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True]) 886 >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True]) 887 >>> tf.debugging.assert_greater(a, b) 888 Traceback (most recent call last): 889 ... 890 InvalidArgumentError: Condition x > y did not hold. 891 892 Args: 893 x_type: A type annotation indicating when the api handler should be called. 894 y_type: A type annotation indicating when the api handler should be called. 895 896 Returns: 897 A decorator. 
898 899 #### Registered APIs 900 901 The binary elementwise assert APIs are: 902 903 <<API_LIST>> 904 """ 905 906 def decorator(handler): 907 api_handler_key = (x_type, y_type, _ASSERT_API_TAG) 908 if api_handler_key in _ELEMENTWISE_API_HANDLERS: 909 raise ValueError("A binary elementwise assert dispatch handler " 910 f"({_ELEMENTWISE_API_HANDLERS[api_handler_key]}) " 911 f"has already been registered for ({x_type}, {y_type}).") 912 _ELEMENTWISE_API_HANDLERS[api_handler_key] = handler 913 for api in _BINARY_ELEMENTWISE_ASSERT_APIS: 914 _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler) 915 916 return handler 917 918 return decorator 919 920 921def register_unary_elementwise_api(func): 922 """Decorator that registers a TensorFlow op as a unary elementwise API.""" 923 _UNARY_ELEMENTWISE_APIS.append(func) 924 for args, handler in _ELEMENTWISE_API_HANDLERS.items(): 925 if len(args) == 1: 926 _add_dispatch_for_unary_elementwise_api(func, args[0], handler) 927 return func 928 929 930def register_binary_elementwise_api(func): 931 """Decorator that registers a TensorFlow op as a binary elementwise API.""" 932 _BINARY_ELEMENTWISE_APIS.append(func) 933 for args, handler in _ELEMENTWISE_API_HANDLERS.items(): 934 if len(args) == 2: 935 _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler) 936 return func 937 938 939def register_binary_elementwise_assert_api(func): 940 """Decorator that registers a TensorFlow op as a binary elementwise assert API. 941 942 Different from `dispatch_for_binary_elementwise_apis`, this decorator is used 943 for assert apis, such as assert_equal, assert_none_equal, etc, which return 944 None in eager mode and an op in graph mode. 945 946 Args: 947 func: The function that implements the binary elementwise assert API. 
948 949 Returns: 950 `func` 951 """ 952 _BINARY_ELEMENTWISE_ASSERT_APIS.append(func) 953 for args, handler in _ELEMENTWISE_API_HANDLERS.items(): 954 if len(args) == 3 and args[2] is _ASSERT_API_TAG: 955 _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler) 956 return func 957 958 959def unary_elementwise_apis(): 960 """Returns a list of APIs that have been registered as unary elementwise.""" 961 return tuple(_UNARY_ELEMENTWISE_APIS) 962 963 964def binary_elementwise_apis(): 965 """Returns a list of APIs that have been registered as binary elementwise.""" 966 return tuple(_BINARY_ELEMENTWISE_APIS) 967 968 969def _add_dispatch_for_unary_elementwise_api(api, x_type, 970 elementwise_api_handler): 971 """Registers a unary elementwise handler as a dispatcher for a given API.""" 972 api_signature = tf_inspect.signature(api) 973 x_name = list(api_signature.parameters)[0] 974 name_index = _find_name_index(api_signature) 975 976 need_to_bind_api_args = ( 977 len(api_signature.parameters) > 2 or 978 "name" not in api_signature.parameters) 979 980 @dispatch_for_api(api, {x_name: x_type}) 981 def dispatch_target(*args, **kwargs): 982 args, kwargs, name = _extract_name_arg(args, kwargs, name_index) 983 if args: 984 x, args = args[0], args[1:] 985 else: 986 x = kwargs.pop(x_name) 987 988 if need_to_bind_api_args: 989 tensor_api = lambda v: api(v, *args, **kwargs) 990 else: 991 tensor_api = api 992 993 if name is None: 994 return elementwise_api_handler(tensor_api, x) 995 else: 996 with ops.name_scope(name, None, [x]): 997 return elementwise_api_handler(tensor_api, x) 998 999 dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__ 1000 dispatch_target.__qualname__ = dispatch_target.__name__ 1001 # Keep track of what targets we've registered (so we can unregister them). 
1002 target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), []) 1003 target_list.append((api, dispatch_target)) 1004 1005 1006def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, 1007 elementwise_api_handler): 1008 """Registers a binary elementwise handler as a dispatcher for a given API.""" 1009 api_signature = tf_inspect.signature(api) 1010 x_name, y_name = list(api_signature.parameters)[:2] 1011 name_index = _find_name_index(api_signature) 1012 1013 need_to_bind_api_args = (len(api_signature.parameters) > 3 or 1014 "name" not in api_signature.parameters) 1015 1016 @dispatch_for_api(api, {x_name: x_type, y_name: y_type}) 1017 def dispatch_target(*args, **kwargs): 1018 args, kwargs, name = _extract_name_arg(args, kwargs, name_index) 1019 if len(args) > 1: 1020 x, y, args = args[0], args[1], args[2:] 1021 elif args: 1022 x, args = args[0], args[1:] 1023 y = kwargs.pop(y_name, None) 1024 else: 1025 x = kwargs.pop(x_name, None) 1026 y = kwargs.pop(y_name, None) 1027 1028 if need_to_bind_api_args: 1029 tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs) 1030 else: 1031 tensor_api = api 1032 1033 if name is None: 1034 return elementwise_api_handler(tensor_api, x, y) 1035 else: 1036 with ops.name_scope(name, None, [x, y]): 1037 return elementwise_api_handler(tensor_api, x, y) 1038 1039 dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__ 1040 dispatch_target.__qualname__ = dispatch_target.__name__ 1041 # Keep track of what targets we've registered (so we can unregister them). 
1042 target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), []) 1043 target_list.append((api, dispatch_target)) 1044 1045 1046def _find_name_index(signature): 1047 """Returns the index of the `name` parameter, or -1 if it's not present.""" 1048 try: 1049 return list(signature.parameters).index("name") 1050 except ValueError: 1051 return -1 1052 1053 1054def _extract_name_arg(args, kwargs, name_index): 1055 """Extracts the parameter `name` and returns `(args, kwargs, name_value)`.""" 1056 if name_index < 0: 1057 name_value = None 1058 elif name_index < len(args): 1059 name_value = args[name_index] 1060 args = args[:name_index] + args[name_index + 1:] 1061 else: 1062 name_value = kwargs.pop("name", None) 1063 return args, kwargs, name_value 1064 1065 1066def update_docstrings_with_api_lists(): 1067 """Updates the docstrings of dispatch decorators with API lists. 1068 1069 Updates docstrings for `dispatch_for_api`, 1070 `dispatch_for_unary_elementwise_apis`, and 1071 `dispatch_for_binary_elementwise_apis`, by replacing the string '<<API_LIST>>' 1072 with a list of APIs that have been registered for that decorator. 
1073 """ 1074 _update_docstring_with_api_list(dispatch_for_unary_elementwise_apis, 1075 _UNARY_ELEMENTWISE_APIS) 1076 _update_docstring_with_api_list(dispatch_for_binary_elementwise_apis, 1077 _BINARY_ELEMENTWISE_APIS) 1078 _update_docstring_with_api_list(dispatch_for_binary_elementwise_assert_apis, 1079 _BINARY_ELEMENTWISE_ASSERT_APIS) 1080 _update_docstring_with_api_list(dispatch_for_api, 1081 _TYPE_BASED_DISPATCH_SIGNATURES) 1082 1083 1084def _update_docstring_with_api_list(target, api_list): 1085 """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.""" 1086 lines = [] 1087 for func in api_list: 1088 name = tf_export_lib.get_canonical_name_for_symbol( 1089 func, add_prefix_to_v1_names=True) 1090 if name is not None: 1091 params = tf_inspect.signature(func).parameters.keys() 1092 lines.append(f" * `tf.{name}({', '.join(params)})`") 1093 lines.sort() 1094 target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines)) 1095 1096 1097################################################################################ 1098# Dispatch Support 1099################################################################################ 1100@tf_export("__internal__.dispatch.add_dispatch_support", v1=[]) 1101def add_dispatch_support(target=None, iterable_parameters=None): 1102 """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API. 1103 1104 This wrapper adds the decorated function as an API that can be overridden 1105 using the `@dispatch_for_api` decorator. In the following example, we first 1106 define a new API (`double`) that supports dispatch, then define a custom type 1107 (`MaskedTensor`) and finally use `dispatch_for_api` to override the default 1108 implementation of `double` when called with `MaskedTensor` values: 1109 1110 >>> @add_dispatch_support 1111 ... def double(x): 1112 ... return x * 2 1113 >>> class MaskedTensor(tf.experimental.ExtensionType): 1114 ... values: tf.Tensor 1115 ... 
mask: tf.Tensor 1116 >>> @dispatch_for_api(double, {'x': MaskedTensor}) 1117 ... def masked_double(x): 1118 ... return MaskedTensor(x.values * 2, y.mask) 1119 1120 The optional `iterable_parameter` argument can be used to mark parameters that 1121 can take arbitrary iterable values (such as generator expressions). These 1122 need to be handled specially during dispatch, since just iterating over an 1123 iterable uses up its values. In the following example, we define a new API 1124 whose second argument can be an iterable value; and then override the default 1125 implementatio of that API when the iterable contains MaskedTensors: 1126 1127 >>> @add_dispatch_support(iterable_parameters=['ys']) 1128 ... def add_tensor_to_list_of_tensors(x, ys): 1129 ... return [x + y for y in ys] 1130 >>> @dispatch_for_api(add_tensor_to_list_of_tensors, 1131 ... {'ys': typing.List[MaskedTensor]}) 1132 ... def masked_add_tensor_to_list_of_tensors(x, ys): 1133 ... return [MaskedTensor(x+y.values, y.mask) for y in ys] 1134 1135 (Note: the only TensorFlow API that currently supports iterables is `add_n`.) 1136 1137 Args: 1138 target: The TensorFlow API that should support dispatch. 1139 iterable_parameters: Optional list of parameter names that may be called 1140 with iterables (such as the `inputs` parameter for `tf.add_n`). 1141 1142 Returns: 1143 A decorator. 1144 """ 1145 1146 if not (iterable_parameters is None or 1147 (isinstance(iterable_parameters, (list, tuple)) and 1148 all(isinstance(p, str) for p in iterable_parameters))): 1149 raise TypeError("iterable_parameters should be a list or tuple of string.") 1150 1151 def decorator(dispatch_target): 1152 1153 # Get the name & index for each iterable parameter. 
1154 if iterable_parameters is None: 1155 iterable_params = None 1156 else: 1157 arg_names = tf_inspect.getargspec(dispatch_target).args 1158 iterable_params = [ 1159 (name, arg_names.index(name)) for name in iterable_parameters 1160 ] 1161 1162 @traceback_utils.filter_traceback 1163 def op_dispatch_handler(*args, **kwargs): 1164 """Call `dispatch_target`, peforming dispatch when appropriate.""" 1165 1166 # Type-based dispatch system (dispatch v2): 1167 if api_dispatcher is not None: 1168 if iterable_params is not None: 1169 args, kwargs = replace_iterable_params(args, kwargs, iterable_params) 1170 result = api_dispatcher.Dispatch(args, kwargs) 1171 if result is not NotImplemented: 1172 return result 1173 1174 # Fallback dispatch system (dispatch v1): 1175 try: 1176 return dispatch_target(*args, **kwargs) 1177 except (TypeError, ValueError): 1178 # Note: convert_to_eager_tensor currently raises a ValueError, not a 1179 # TypeError, when given unexpected types. So we need to catch both. 1180 result = dispatch(op_dispatch_handler, args, kwargs) 1181 if result is not OpDispatcher.NOT_SUPPORTED: 1182 return result 1183 else: 1184 raise 1185 1186 add_fallback_dispatch_list(op_dispatch_handler) 1187 op_dispatch_handler = tf_decorator.make_decorator(dispatch_target, 1188 op_dispatch_handler) 1189 add_type_based_api_dispatcher(op_dispatch_handler) 1190 api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR, 1191 None) 1192 return op_dispatch_handler 1193 1194 if target is None: 1195 return decorator 1196 else: 1197 return decorator(target) 1198 1199 1200def replace_iterable_params(args, kwargs, iterable_params): 1201 """Returns (args, kwargs) with any iterable parameters converted to lists. 1202 1203 Args: 1204 args: Positional rguments to a function 1205 kwargs: Keyword arguments to a function. 1206 iterable_params: A list of (name, index) tuples for iterable parameters. 
1207 1208 Returns: 1209 A tuple (args, kwargs), where any positional or keyword parameters in 1210 `iterable_params` have their value converted to a `list`. 1211 """ 1212 args = list(args) 1213 for name, index in iterable_params: 1214 if index < len(args): 1215 args[index] = list(args[index]) 1216 elif name in kwargs: 1217 kwargs[name] = list(kwargs[name]) 1218 return tuple(args), kwargs 1219