# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False

# Information about an argument to an operation: the name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])


def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.

  Returns:
    A tuple of `_ArgInfo`s.
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s. Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos


def _is_convertible_to_tensor(value):
  """Returns true if `value` is convertible to a `Tensor`."""
  if value is None:
    return True
  if isinstance(value,
                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
    return True
  elif isinstance(value, (sparse_tensor.SparseTensor,)):
    return False
  else:
    try:
      ops.convert_to_tensor(value)
      return True
    except (TypeError, ValueError):
      return False


class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED
    if self._arg_is_list:
      found_ragged = False
      for elt in x:
        if ragged_tensor.is_ragged(elt):
          found_ragged = True
        elif not _is_convertible_to_tensor(elt):
          return self.NOT_SUPPORTED
      if found_ragged:
        x = ragged_tensor.match_row_splits_dtypes(*x)
        nested_splits_lists = [
            elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
        ]
        flat_values = [
            elt.flat_values if ragged_tensor.is_ragged(elt) else elt
            for elt in x
        ]
        with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
          return ragged_tensor.RaggedTensor.from_nested_row_splits(
              self._original_op(flat_values, *args, **kwargs),
              nested_splits_lists[0], validate=False)
      else:
        return self.NOT_SUPPORTED
    else:
      found_ragged = ragged_tensor.is_ragged(x)
      if found_ragged:
        mapped_values = self._original_op(x.flat_values, *args, **kwargs)
        return x.with_flat_values(mapped_values)
      else:
        return self.NOT_SUPPORTED
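

# Example: a minimal sketch of what the unary dispatcher above does once it is
# registered (assumes a TF build where RaggedTensor dispatch is enabled and
# the public API is imported as `tf`):
#
#   rt = tf.ragged.constant([[1.0, -2.0], [-3.0]])
#   tf.abs(rt)  # ==> <tf.RaggedTensor [[1.0, 2.0], [3.0]]>
#   # Roughly equivalent to: rt.with_flat_values(tf.abs(rt.flat_values))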


class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    self._y = arg_names[1]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))

  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)


class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `tensor_args`
  arguments is a RaggedTensor or a RaggedTensorValue; and all of the
  `tensor_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    op_arg_names = tf_inspect.getfullargspec(original_op)[0]
    ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
    if op_arg_names != ragged_arg_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, op_arg_names, ragged_arg_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))

  def handle(self, args, kwargs):
    if self.is_supported(args, kwargs):
      return self._ragged_op(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED

  def is_supported(self, args, kwargs):
    found_ragged = False
    for arg_info in self._ragged_args:
      if arg_info.position < len(args):
        arg = args[arg_info.position]
      else:
        arg = kwargs.get(arg_info.name, None)

      if arg_info.is_list:
        if not isinstance(arg, (list, tuple)):
          return False
        for elt in arg:
          if ragged_tensor.is_ragged(elt):
            found_ragged = True
          elif not _is_convertible_to_tensor(elt):
            return False
      else:
        if ragged_tensor.is_ragged(arg):
          found_ragged = True
        elif not _is_convertible_to_tensor(arg):
          return False
    return found_ragged
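

# Example: a minimal sketch of binary elementwise dispatch with broadcasting
# (assumes a TF build where RaggedTensor dispatch is enabled and the public
# API is imported as `tf`):
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.add(rt, 10)  # ==> <tf.RaggedTensor [[11, 12], [13]]>
#   # The scalar 10 is broadcast to rt's shape, tf.add is applied to the
#   # flat_values, and the result reuses rt's row splits.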


_UNARY_ELEMENTWISE_OPS = [
    array_ops.check_numerics,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.ones_like_v2,
    array_ops.zeros_like,
    array_ops.zeros_like_v2,
    clip_ops.clip_by_value,
    gen_bitwise_ops.invert,
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.cast,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.logical_not,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.saturate_cast,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    parsing_ops.decode_compressed,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.as_string,
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.regex_full_match,
    string_ops.regex_replace,
    string_ops.string_strip,
    string_ops.string_to_hash_bucket,
    string_ops.string_to_hash_bucket_fast,
    string_ops.string_to_hash_bucket_strong,
    string_ops.substr,
    string_ops.substr_v2,
    string_ops.string_length,
    string_ops.string_length_v2,
    string_ops.unicode_script,
]

_UNARY_LIST_ELEMENTWISE_OPS = [
    math_ops.add_n,
    string_ops.string_join,
]

_BINARY_ELEMENTWISE_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
    math_ops.truncatediv,
    math_ops.truncatemod,
]


# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler).  But we
# still want to include them in the ragged_op_list() output.
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_any,
    math_ops.reduce_all,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]


def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)


def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0):
  return ragged_gather_ops.gather_nd(
      params=params,
      indices=indices,
      batch_dims=batch_dims,
      name=name)


def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  if dim is not None:
    axis = dim
  return ragged_array_ops.expand_dims(input=input, axis=axis, name=name)


def _ragged_size_v1(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  return ragged_array_ops.size(input=input, out_type=out_type, name=name)


def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
  axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims',
                                                squeeze_dims)
  return ragged_squeeze_op.squeeze(input, axis, name)


def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
  """RaggedTensor Dispatch override for tf.dynamic_partition."""
  if not isinstance(num_partitions, int) or num_partitions < 0:
    raise TypeError('num_partitions must be a non-negative integer')
  result = ragged_array_ops.stack_dynamic_partitions(data, partitions,
                                                     num_partitions, name)
  return [result[i] for i in range(num_partitions)]
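

# Example: a minimal sketch of why the `_ragged_*_v1` wrappers above exist.
# RaggedDispatcher requires the ragged op's argspec to exactly match the
# original op's, so v1 ops with legacy argument names (e.g. `dim` for
# tf.compat.v1.expand_dims, `squeeze_dims` for tf.compat.v1.squeeze) get thin
# adapters that forward to the shared ragged implementations (assumes a TF
# build where RaggedTensor dispatch is enabled):
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.compat.v1.expand_dims(rt, dim=0)  # handled via _ragged_expand_dims_v1
#   tf.expand_dims(rt, axis=0)           # handled via ragged_array_ops.expand_dims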


# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
    (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
     ['params', 'indices']),
    (array_ops.concat, ragged_concat_ops.concat, ['[values]']),
    (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']),
    (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
    (array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
    (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
    (array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
    (array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
                                                           'indices']),
    (array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
    (array_ops.rank, ragged_array_ops.rank, ['input']),
    (array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
    (array_ops.size, _ragged_size_v1, ['input']),
    (array_ops.size_v2, ragged_array_ops.size, ['input']),
    (array_ops.squeeze, _ragged_squeeze_v1, ['input']),
    (array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']),
    (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
    (array_ops.tile, ragged_array_ops.tile, ['input']),
    (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
    (data_flow_ops.dynamic_partition, _ragged_dynamic_partition,
     ['data', 'partitions']),
    (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
     ['data', 'segment_ids']),
    (string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']),
    (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
    (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
    (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
    (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
    (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
    (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
    (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]


def register_dispatchers():
  """Constructs & registers OpDispatchers for ragged ops."""

  op_list = (
      _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
      _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
  for op in op_list:
    _, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op,
                   tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)

  for op in _UNARY_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op).register(op)

  for op in _UNARY_LIST_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op, True).register(op)

  for op in _BINARY_ELEMENTWISE_OPS:
    BinaryRaggedElementwiseDispatcher(op).register(op)

  for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
    RaggedDispatcher(original_op, ragged_op, args).register(original_op)


def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.
  for pos in range(-1, -len(argspec.defaults) - 1, -1):
    arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
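

# Example: a rough sketch of the Markdown signature lines produced above (the
# exact canonical names and argument lists depend on the installed TF
# version):
#
#   _ragged_op_signature(math_ops.reduce_sum, [0])
#   # ==> '* `tf.math.reduce_sum`(**input_tensor**, axis=`None`, ...)'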


def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')


def ragged_op_list(tf_version=1):
  """Returns a string listing operators that have dispatchers registered."""
  lines = []
  for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0]))
  for op in _BINARY_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0, 1]))
  for op, _, ragged_args in _RAGGED_DISPATCH_OPS:
    if _op_is_in_tf_version(op, tf_version):
      arginfos = _get_arg_infos(op, ragged_args)
      ragged_args = [arginfo.position for arginfo in arginfos]
      lines.append(_ragged_op_signature(op, ragged_args))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')


register_dispatchers()
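
# Example: a minimal sketch of the module's end-to-end effect (assumes this
# module has been imported, which runs register_dispatchers() above, and that
# the public API is imported as `tf`):
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.reduce_sum(rt, axis=1)    # ==> [3, 3]              (RaggedDispatcher)
#   tf.concat([rt, rt], axis=0)  # ==> 4-row RaggedTensor  (RaggedDispatcher)
#   tf.size(rt)                  # ==> 3                   (RaggedDispatcher)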