# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False

# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])


def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.

  Returns:
    A list of `_ArgInfo`s.
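
  Example:
    Illustrative, using a hypothetical `concat(values, axis, name=None)`
    signature: `_get_arg_infos(concat, ['[values]', 'axis'])` would return
    `[_ArgInfo('values', 0, True), _ArgInfo('axis', 1, False)]`.  Wrapping a
    name in square brackets marks it as a list-of-tensors argument.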
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s.  Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos


def _is_convertible_to_tensor(value):
  """Returns true if `value` is convertible to a `Tensor`."""
  if value is None:
    return True
  if isinstance(value,
                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
    return True
  elif isinstance(value, (sparse_tensor.SparseTensor,)):
    return False
  else:
    try:
      ops.convert_to_tensor(value)
      return True
    except (TypeError, ValueError):
      return False


class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED
    if self._arg_is_list:
      found_ragged = False
      for elt in x:
        if ragged_tensor.is_ragged(elt):
          found_ragged = True
        elif not _is_convertible_to_tensor(elt):
          return self.NOT_SUPPORTED
      if found_ragged:
        x = ragged_tensor.match_row_splits_dtypes(*x)
        nested_splits_lists = [
            elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
        ]
        flat_values = [
            elt.flat_values if ragged_tensor.is_ragged(elt) else elt
            for elt in x
        ]
        with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
          return ragged_tensor.RaggedTensor.from_nested_row_splits(
              self._original_op(flat_values, *args, **kwargs),
              nested_splits_lists[0], validate=False)
      else:
        return self.NOT_SUPPORTED
    else:
      found_ragged = ragged_tensor.is_ragged(x)
      if found_ragged:
        mapped_values = self._original_op(x.flat_values, *args, **kwargs)
        return x.with_flat_values(mapped_values)
      else:
        return self.NOT_SUPPORTED

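# For example (an illustrative sketch assuming the public `tf.ragged.constant`
# and `tf.abs` APIs): once `UnaryRaggedElementwiseDispatcher(math_ops.abs)` is
# registered (see `register_dispatchers` below), a call such as
#
#   tf.abs(tf.ragged.constant([[1, -2], [-3]]))
#
# applies `math_ops.abs` to the flat values [1, -2, -3] and returns a
# RaggedTensor equal to [[1, 2], [3]].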

class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    self._y = arg_names[1]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` and `{y}` may each be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))

  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

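    # Broadcast to a common (possibly ragged) shape if both operands are
    # ragged, or if the dense operand's rank is at least that of the ragged
    # operand's flat_values (so it may span the ragged dimensions); otherwise
    # the dense operand can broadcast directly against flat_values below.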
    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)

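# For example (an illustrative sketch assuming the public `tf.ragged.constant`
# and `tf.add` APIs): with `BinaryRaggedElementwiseDispatcher(math_ops.add)`
# registered, an expression such as
#
#   tf.add(tf.ragged.constant([[1, 2], [3]]), 10)
#
# broadcasts the scalar against the ragged operand's flat values and returns a
# RaggedTensor equal to [[11, 12], [13]].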

class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `ragged_args`
  arguments is a RaggedTensor or a RaggedTensorValue; and all of the
  `ragged_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    op_arg_names = tf_inspect.getfullargspec(original_op)[0]
    ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
    if op_arg_names != ragged_arg_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, op_arg_names, ragged_arg_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))

  def handle(self, args, kwargs):
    if self.is_supported(args, kwargs):
      return self._ragged_op(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED

  def is_supported(self, args, kwargs):
    found_ragged = False
    for arg_info in self._ragged_args:
      if arg_info.position < len(args):
        arg = args[arg_info.position]
      else:
        arg = kwargs.get(arg_info.name, None)

      if arg_info.is_list:
        if not isinstance(arg, (list, tuple)):
          return False
        for elt in arg:
          if ragged_tensor.is_ragged(elt):
            found_ragged = True
          elif not _is_convertible_to_tensor(elt):
            return False
      else:
        if ragged_tensor.is_ragged(arg):
          found_ragged = True
        elif not _is_convertible_to_tensor(arg):
          return False
    return found_ragged

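# For example (an illustrative sketch assuming the public `tf.ragged.constant`
# and `tf.concat` APIs): with a `RaggedDispatcher` registered for
# `array_ops.concat` (see `_RAGGED_DISPATCH_OPS` below), a call such as
#
#   tf.concat([tf.ragged.constant([[1, 2], [3]]),
#              tf.ragged.constant([[4], [5, 6]])], axis=0)
#
# is routed to `ragged_concat_ops.concat` and returns a RaggedTensor equal to
# [[1, 2], [3], [4], [5, 6]].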

_UNARY_ELEMENTWISE_OPS = [
    array_ops.check_numerics,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.ones_like_v2,
    array_ops.zeros_like,
    array_ops.zeros_like_v2,
    clip_ops.clip_by_value,
    gen_bitwise_ops.invert,
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.cast,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.logical_not,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.saturate_cast,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    parsing_ops.decode_compressed,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.as_string,
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.regex_full_match,
    string_ops.regex_replace,
    string_ops.string_strip,
    string_ops.string_to_hash_bucket_fast,
    string_ops.string_to_hash_bucket_strong,
    string_ops.substr,
    string_ops.substr_v2,
    string_ops.string_length,
    string_ops.string_length_v2,
    string_ops.unicode_script,
]

_UNARY_LIST_ELEMENTWISE_OPS = [
    math_ops.add_n,
    string_ops.string_join,
]

_BINARY_ELEMENTWISE_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
    math_ops.truncatediv,
    math_ops.truncatemod,
]


# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler).  But we
# still want to include them in the ragged_op_list() output.
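# For example, the v1 `tf.reduce_sum` endpoint forwards to the v2
# `math_ops.reduce_sum` listed below, so the dispatcher registered for the
# v2 op covers both endpoints.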
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_any,
    math_ops.reduce_all,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]


def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)


def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0):
  return ragged_gather_ops.gather_nd(
      params=params,
      indices=indices,
      batch_dims=batch_dims,
      name=name)


def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  if dim is not None:
    axis = dim
  return ragged_array_ops.expand_dims(input=input, axis=axis, name=name)


def _ragged_size_v1(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  return ragged_array_ops.size(input=input, out_type=out_type, name=name)


def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
  axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims',
                                                squeeze_dims)
  return ragged_squeeze_op.squeeze(input, axis, name)


def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
  """RaggedTensor Dispatch override for tf.dynamic_partition."""
  if not isinstance(num_partitions, int) or num_partitions < 0:
    raise TypeError('num_partitions must be a non-negative integer')
  result = ragged_array_ops.stack_dynamic_partitions(data, partitions,
                                                     num_partitions, name)
  return [result[i] for i in range(num_partitions)]
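
# For example (an illustrative sketch assuming the public `tf.ragged.constant`
# and `tf.dynamic_partition` APIs):
#
#   tf.dynamic_partition(tf.ragged.constant([[1], [2, 3], [4]]),
#                        partitions=[1, 0, 1], num_partitions=2)
#
# returns a list of two ragged tensors, equal to [[2, 3]] and [[1], [4]].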

# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
    (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
     ['params', 'indices']),
    (array_ops.concat, ragged_concat_ops.concat, ['[values]']),
    (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']),
    (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
    (array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
    (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
    (array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
    (array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
                                                           'indices']),
    (array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
    (array_ops.rank, ragged_array_ops.rank, ['input']),
    (array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
    (array_ops.size, _ragged_size_v1, ['input']),
    (array_ops.size_v2, ragged_array_ops.size, ['input']),
    (array_ops.squeeze, _ragged_squeeze_v1, ['input']),
    (array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']),
    (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
    (array_ops.tile, ragged_array_ops.tile, ['input']),
    (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
    (data_flow_ops.dynamic_partition, _ragged_dynamic_partition,
     ['data', 'partitions']),
    (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
     ['data', 'segment_ids']),
    (string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']),
    (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
    (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
    (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
    (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
    (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
    (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
    (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]


def register_dispatchers():
  """Constructs & registers OpDispatchers for ragged ops."""

  op_list = (
      _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
      _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
  for op in op_list:
    _, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op,
                   tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)

  for op in _UNARY_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op).register(op)

  for op in _UNARY_LIST_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op, True).register(op)

  for op in _BINARY_ELEMENTWISE_OPS:
    BinaryRaggedElementwiseDispatcher(op).register(op)

  for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
    RaggedDispatcher(original_op, ragged_op, args).register(original_op)


def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.
  for pos in range(-1, -len(argspec.defaults) - 1, -1):
    arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))

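# For example (illustrative; the canonical name is whatever
# `tf_export.get_canonical_name_for_symbol` reports for the op):
#
#   _ragged_op_signature(math_ops.add, [0, 1])
#
# would produce a markdown line similar to
#
#   * `tf.math.add`(**x**, **y**, name=`None`)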

def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')


def ragged_op_list(tf_version=1):
  """Returns a string listing operators that have dispatchers registered."""
  lines = []
  for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0]))
  for op in _BINARY_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0, 1]))
  for op, _, ragged_args in _RAGGED_DISPATCH_OPS:
    if _op_is_in_tf_version(op, tf_version):
      arginfos = _get_arg_infos(op, ragged_args)
      ragged_args = [arginfo.position for arginfo in arginfos]
      lines.append(_ragged_op_signature(op, ragged_args))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')


register_dispatchers()
