# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False

# Information about an argument to an operation: the name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])


def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.  Names wrapped in
      square brackets (e.g. '[values]') indicate list-valued arguments.

  Returns:
    A list of `_ArgInfo`s.
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s. Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos
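
# Illustrative sketch of `_get_arg_infos` (the `gather` function below is
# hypothetical, not part of this module): given
#   def gather(params, indices, axis=0): ...
# the call `_get_arg_infos(gather, ['params', '[indices]'])` would return
#   [_ArgInfo('params', 0, False), _ArgInfo('indices', 1, True)]
# where the brackets around 'indices' mark it as a list-of-tensors argument.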


def _is_convertible_to_tensor(value):
  """Returns true if `value` is convertible to a `Tensor`."""
  if value is None:
    return True
  if isinstance(value,
                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
    return True
  elif isinstance(value, (sparse_tensor.SparseTensor,)):
    return False
  else:
    try:
      ops.convert_to_tensor(value)
      return True
    except (TypeError, ValueError):
      return False


class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED
    if self._arg_is_list:
      found_ragged = False
      for elt in x:
        if ragged_tensor.is_ragged(elt):
          found_ragged = True
        elif not _is_convertible_to_tensor(elt):
          return self.NOT_SUPPORTED
      if found_ragged:
        nested_splits_lists = [
            elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
        ]
        flat_values = [
            elt.flat_values if ragged_tensor.is_ragged(elt) else elt
            for elt in x
        ]
        with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
          return ragged_tensor.RaggedTensor.from_nested_row_splits(
              self._original_op(flat_values, *args, **kwargs),
              nested_splits_lists[0])
      else:
        return self.NOT_SUPPORTED
    else:
      found_ragged = ragged_tensor.is_ragged(x)
      if found_ragged:
        mapped_values = self._original_op(x.flat_values, *args, **kwargs)
        return x.with_flat_values(mapped_values)
      else:
        return self.NOT_SUPPORTED
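
# A minimal usage sketch (hypothetical values, assuming eager execution and
# that `register_dispatchers()` below has run): once `tf.sqrt` is registered
# with a `UnaryRaggedElementwiseDispatcher`, calling
#   tf.sqrt(tf.ragged.constant([[1.0, 4.0], [9.0]]))
# applies the op to the flat values [1.0, 4.0, 9.0] and rebuilds the ragged
# structure, yielding [[1.0, 2.0], [3.0]].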


class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    self._y = arg_names[1]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` and `{y}` may each be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))

  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)


class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `tensor_args`
  arguments is a RaggedTensor or a RaggedTensorValue, and all of the
  `tensor_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    op_arg_names = tf_inspect.getfullargspec(original_op)[0]
    ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
    if op_arg_names != ragged_arg_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, op_arg_names, ragged_arg_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))

  def handle(self, args, kwargs):
    if self.is_supported(args, kwargs):
      return self._ragged_op(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED

  def is_supported(self, args, kwargs):
    found_ragged = False
    for arg_info in self._ragged_args:
      if arg_info.position < len(args):
        arg = args[arg_info.position]
      else:
        arg = kwargs.get(arg_info.name, None)

      if arg_info.is_list:
        if not isinstance(arg, (list, tuple)):
          return False
        for elt in arg:
          if ragged_tensor.is_ragged(elt):
            found_ragged = True
          elif not _is_convertible_to_tensor(elt):
            return False
      else:
        if ragged_tensor.is_ragged(arg):
          found_ragged = True
        elif not _is_convertible_to_tensor(arg):
          return False
    return found_ragged


def ragged_dispatch(original_op, tensor_args):
  """Decorator that registers a function as the ragged handler for an op."""

  def decorator(ragged_op):
    RaggedDispatcher(original_op, ragged_op,
                     tensor_args).register(original_op)
    return ragged_op

  return decorator
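
# Hypothetical use of the `ragged_dispatch` decorator above (illustrative
# only; the registrations in this module go through `register_dispatchers()`
# and the `_RAGGED_DISPATCH_OPS` table instead):
#
#   @ragged_dispatch(array_ops.reverse, ['tensor'])
#   def _ragged_reverse(tensor, axis, name=None):
#     ...
#
# After this, calls to the original op dispatch to `_ragged_reverse` whenever
# `tensor` is a `RaggedTensor`.  Note that `RaggedDispatcher` requires the two
# signatures to match exactly.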


_UNARY_ELEMENTWISE_OPS = [
    array_ops.check_numerics,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.ones_like_v2,
    array_ops.zeros_like,
    array_ops.zeros_like_v2,
    clip_ops.clip_by_value,
    gen_bitwise_ops.invert,
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.cast,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.logical_not,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.saturate_cast,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    parsing_ops.decode_compressed,
    string_ops.string_to_number,
    string_ops.as_string,
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.regex_full_match,
    string_ops.regex_replace,
    string_ops.string_strip,
    string_ops.string_to_hash_bucket,
    string_ops.string_to_hash_bucket_fast,
    string_ops.string_to_hash_bucket_strong,
    string_ops.substr,
    string_ops.substr_v2,
    string_ops.string_length,
    string_ops.string_length_v2,
    string_ops.unicode_script,
]

_UNARY_LIST_ELEMENTWISE_OPS = [
    math_ops.add_n,
    string_ops.string_join,
]

_BINARY_ELEMENTWISE_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
    math_ops.truncatediv,
    math_ops.truncatemod,
]
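
# Dispatch sketch for a binary elementwise op (hypothetical values): with
#   x = tf.ragged.constant([[1, 2], [3]])
# `tf.add(x, 10)` broadcasts the scalar against x's flat values and returns
# [[11, 12], [13]], while `tf.add(x, x)` returns [[2, 4], [6]].  Ragged
# operands whose row splits are incompatible fail at runtime.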


# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler).  But we
# still want to include them in the ragged_op_list() output.
_V1_OPS_THAT_DELEGATE_TO_V2_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_any,
    math_ops.reduce_all,
]


def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)


def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  if dim is not None:
    axis = dim
  return ragged_array_ops.expand_dims(input=input, axis=axis, name=name)


def _ragged_size_v1(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  return ragged_array_ops.size(input=input, out_type=out_type, name=name)


# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
    (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
     ['params', 'indices']),
    (array_ops.concat, ragged_concat_ops.concat, ['[values]']),
    (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']),
    (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
    (array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
    (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
    (array_ops.gather_nd, ragged_gather_ops.gather_nd, ['params', 'indices']),
    (array_ops.rank, ragged_array_ops.rank, ['input']),
    (array_ops.size, _ragged_size_v1, ['input']),
    (array_ops.size_v2, ragged_array_ops.size, ['input']),
    (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
    (array_ops.tile, ragged_array_ops.tile, ['input']),
    (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
    (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
     ['data', 'segment_ids']),
    (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
    (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
    (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
    (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
    (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
    (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
    (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]
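
# For example, the (array_ops.concat, ragged_concat_ops.concat, ['[values]'])
# entry above means that tf.concat dispatches to ragged_concat_ops.concat
# whenever any element of the `values` list is a RaggedTensor; the square
# brackets mark `values` as a list-of-tensors argument (see `_get_arg_infos`).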


def register_dispatchers():
  """Constructs & registers OpDispatchers for ragged ops."""

  op_list = (
      _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
      _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
  for op in op_list:
    _, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op,
                   tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)

  for op in _UNARY_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op).register(op)

  for op in _UNARY_LIST_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op, True).register(op)

  for op in _BINARY_ELEMENTWISE_OPS:
    BinaryRaggedElementwiseDispatcher(op).register(op)

  for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
    RaggedDispatcher(original_op, ragged_op, args).register(original_op)


def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults (guarding against ops with no defaults).
  if argspec.defaults:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))


def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')


def ragged_op_list(tf_version=1):
  """Returns a string listing operators that have dispatchers registered."""
  lines = []
  for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0]))
  for op in _BINARY_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0, 1]))
  for op, _, ragged_args in _RAGGED_DISPATCH_OPS:
    if _op_is_in_tf_version(op, tf_version):
      arginfos = _get_arg_infos(op, ragged_args)
      ragged_args = [arginfo.position for arginfo in arginfos]
      lines.append(_ragged_op_signature(op, ragged_args))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')


register_dispatchers()
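
# Illustrative fragment of ragged_op_list() output (format only; the exact
# entries and canonical op names depend on the registered ops and the
# requested TF version):
#   * `tf.concat`(**values**, axis, name=`'concat'`)
#   * `tf.rank`(**input**, name=`None`)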