• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operator dispatch for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import numpy as np
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import clip_ops
29from tensorflow.python.ops import gen_bitwise_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import parsing_ops
32from tensorflow.python.ops import string_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.ops.ragged import ragged_array_ops
35from tensorflow.python.ops.ragged import ragged_batch_gather_ops
36from tensorflow.python.ops.ragged import ragged_concat_ops
37from tensorflow.python.ops.ragged import ragged_gather_ops
38from tensorflow.python.ops.ragged import ragged_math_ops
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.ops.ragged import ragged_tensor_shape
41from tensorflow.python.ops.ragged import ragged_util
42from tensorflow.python.ops.ragged import ragged_where_op
43from tensorflow.python.util import dispatch
44from tensorflow.python.util import tf_decorator
45from tensorflow.python.util import tf_export
46from tensorflow.python.util import tf_inspect
47
# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
# When True, each dispatcher constructor in this file appends a note to the
# original op's docstring naming the arguments that may be a `tf.RaggedTensor`.
_UPDATE_DOCSTRINGS = False

# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.  `_get_arg_infos` sets `is_list` when the
# requested argument name is wrapped in square brackets (e.g. '[values]').
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
55
56
def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.  A name wrapped in
      square brackets (e.g. `'[values]'`) marks an argument that expects a
      list of tensors; the brackets are stripped before the name is looked up.

  Returns:
    A list of `_ArgInfo`s, in the same order as `arg_names`.

  Raises:
    ValueError: If a requested name is not one of `func`'s named arguments.
  """
  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  arg_infos = []
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s.  Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos
81
82
83def _is_convertible_to_tensor(value):
84  """Returns true if `value` is convertible to a `Tensor`."""
85  if value is None:
86    return True
87  if isinstance(value,
88                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
89    return True
90  elif isinstance(value, (sparse_tensor.SparseTensor,)):
91    return False
92  else:
93    try:
94      ops.convert_to_tensor(value)
95      return True
96    except (TypeError, ValueError):
97      return False
98
99
class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    """Creates a dispatcher for `original_op`.

    Args:
      original_op: The op to map over the flat values of ragged inputs.  Its
        first named argument is the one that may be ragged.
      arg_is_list: If true, the first argument is a list of values (any of
        which may be ragged) rather than a single value.
    """
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    """Applies the original op to the flat values of the ragged argument(s).

    Args:
      args: Positional arguments of the intercepted call.
      kwargs: Keyword arguments of the intercepted call (not modified).

    Returns:
      The ragged result, or `self.NOT_SUPPORTED` if the first argument is
      missing, contains no ragged value, or (in list mode) is not a list/tuple
      or contains an element that is neither ragged nor convertible to a
      tensor.
    """
    # The first argument may have been passed positionally or by keyword.
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED

    if not self._arg_is_list:
      if not ragged_tensor.is_ragged(x):
        return self.NOT_SUPPORTED
      mapped_values = self._original_op(x.flat_values, *args, **kwargs)
      return x.with_flat_values(mapped_values)

    # Only iterate real sequences.  Iterating anything else (an int, a
    # generator, ...) would either raise from inside the handler or consume
    # the input; declining keeps this consistent with
    # `RaggedDispatcher.is_supported`.
    if not isinstance(x, (list, tuple)):
      return self.NOT_SUPPORTED

    found_ragged = False
    for elt in x:
      if ragged_tensor.is_ragged(elt):
        found_ragged = True
      elif not _is_convertible_to_tensor(elt):
        return self.NOT_SUPPORTED
    if not found_ragged:
      return self.NOT_SUPPORTED

    nested_splits_lists = [
        elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
    ]
    flat_values = [
        elt.flat_values if ragged_tensor.is_ragged(elt) else elt
        for elt in x
    ]
    # The result reuses the splits of the first ragged element, so the op is
    # run under an assertion that all ragged elements have matching splits.
    with ops.control_dependencies(
        ragged_util.assert_splits_match(nested_splits_lists)):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          self._original_op(flat_values, *args, **kwargs),
          nested_splits_lists[0])
150
151
class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    """Creates a dispatcher for `original_op`.

    Args:
      original_op: The binary op to wrap.  Its first two named arguments are
        the operands that may be ragged.
    """
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x, self._y = arg_names[0], arg_names[1]
    if _UPDATE_DOCSTRINGS:
      note = '    `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
          x=self._x, y=self._y)
      original_op.__doc__ = original_op.__doc__.rstrip() + '\n\n' + note

  def handle(self, args, kwargs):
    """Applies the original op to the flat values of the two operands.

    Args:
      args: Positional arguments of the intercepted call.
      kwargs: Keyword arguments of the intercepted call (not modified).

    Returns:
      The ragged result, or `self.NOT_SUPPORTED` if neither operand is ragged
      or a non-ragged operand cannot be converted to a tensor.
    """
    # Pull the two operands out of the positional and/or keyword arguments.
    num_positional = len(args)
    if num_positional > 1:
      x, y = args[0], args[1]
      args = args[2:]
    else:
      kwargs = kwargs.copy()
      if num_positional == 1:
        x = args[0]
        args = args[1:]
      else:
        x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # At least one operand has to be ragged for this dispatcher to apply.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not x_is_ragged and not y_is_ragged:
      return self.NOT_SUPPORTED

    # Turn the non-ragged operand (if any) into a tensor, preferring the
    # other operand's dtype.  Give up if that is not possible.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    # Broadcast when both operands are ragged, or when the dense operand has
    # at least as many dimensions as the ragged operand's flat values.
    if x_is_ragged and y_is_ragged:
      needs_broadcast = True
    elif x_is_ragged:
      needs_broadcast = x.flat_values.shape.ndims <= y.shape.ndims
    else:
      needs_broadcast = y.flat_values.shape.ndims <= x.shape.ndims

    if needs_broadcast:
      shape_cls = ragged_tensor_shape.RaggedTensorDynamicShape
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          shape_cls.from_tensor(x), shape_cls.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    # Raggedness is re-checked here because the operands may have been
    # replaced by the broadcast above.
    x_now_ragged = ragged_tensor.is_ragged(x)
    x_values = x.flat_values if x_now_ragged else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)

    # Wrap the mapped values using the row structure of a ragged operand.
    template = x if x_now_ragged else y
    return template.with_flat_values(mapped_values)
218
219
class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `tensor_args`
  arguments is a RaggedTensor or a RaggedTensorValue; and all of the
  `tensor_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    """Creates a dispatcher that redirects `original_op` to `ragged_op`.

    Args:
      original_op: The op being overridden.
      ragged_op: The ragged implementation; its named arguments must be
        identical to those of `original_op`.
      ragged_args: Names of the arguments that may be ragged.  A name in
        square brackets denotes an argument that takes a list of tensors.

    Raises:
      AssertionError: If the two ops' argument names differ.
    """
    expected_names = tf_inspect.getfullargspec(original_op)[0]
    actual_names = tf_inspect.getfullargspec(ragged_op)[0]
    if expected_names != actual_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, expected_names, actual_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      note = '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list)
      original_op.__doc__ = original_op.__doc__.rstrip() + '\n\n' + note

  def handle(self, args, kwargs):
    """Calls the ragged op if the arguments qualify, else declines."""
    if not self.is_supported(args, kwargs):
      return self.NOT_SUPPORTED
    return self._ragged_op(*args, **kwargs)

  def is_supported(self, args, kwargs):
    """Returns true if the call should be handled by the ragged op.

    Args:
      args: Positional arguments of the intercepted call.
      kwargs: Keyword arguments of the intercepted call.

    Returns:
      `True` if at least one tracked argument (or list element) is ragged and
      every other tracked value is convertible to a tensor; `False` otherwise,
      including when a list argument is not a list or tuple.
    """
    saw_ragged = False
    num_positional = len(args)
    for info in self._ragged_args:
      # Look the argument up positionally first, then by keyword.
      if info.position < num_positional:
        value = args[info.position]
      else:
        value = kwargs.get(info.name, None)

      if info.is_list:
        if not isinstance(value, (list, tuple)):
          return False
        candidates = value
      else:
        candidates = (value,)

      for candidate in candidates:
        if ragged_tensor.is_ragged(candidate):
          saw_ragged = True
        elif not _is_convertible_to_tensor(candidate):
          return False
    return saw_ragged
271
272
def ragged_dispatch(original_op, tensor_args):
  """Returns a decorator that registers a function as a ragged override.

  Args:
    original_op: The op whose calls should be redirected.
    tensor_args: Names of the arguments that may be ragged, in the format
      accepted by `RaggedDispatcher`.

  Returns:
    A decorator that registers the decorated function as the ragged handler
    for `original_op` and returns that function unchanged.
  """

  def decorator(ragged_op):
    # `RaggedDispatcher` is defined in this module, not in `dispatch`.
    RaggedDispatcher(original_op, ragged_op,
                     tensor_args).register(original_op)
    return ragged_op

  return decorator
281
282
# Ops that `register_dispatchers()` wraps in a
# `UnaryRaggedElementwiseDispatcher`.  Each op must appear only once: a
# repeated entry would be registered twice and listed twice by
# `ragged_op_list()`.
_UNARY_ELEMENTWISE_OPS = [
    array_ops.check_numerics,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.ones_like_v2,
    array_ops.zeros_like,
    array_ops.zeros_like_v2,
    clip_ops.clip_by_value,
    gen_bitwise_ops.invert,
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.cast,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.logical_not,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.saturate_cast,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    parsing_ops.decode_compressed,
    string_ops.string_to_number,
    string_ops.as_string,
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.regex_full_match,
    string_ops.regex_replace,
    string_ops.string_strip,
    string_ops.string_to_hash_bucket,
    string_ops.string_to_hash_bucket_fast,
    string_ops.string_to_hash_bucket_strong,
    string_ops.substr,
    string_ops.substr_v2,
    string_ops.string_length,
    string_ops.string_length_v2,
    string_ops.unicode_script,
]
351
# Ops that `register_dispatchers()` wraps in
# `UnaryRaggedElementwiseDispatcher(op, True)`, i.e. with `arg_is_list` set:
# their first argument is treated as a list whose elements may be ragged.
_UNARY_LIST_ELEMENTWISE_OPS = [
    math_ops.add_n,
    string_ops.string_join,
]
356
# Ops that `register_dispatchers()` wraps in a
# `BinaryRaggedElementwiseDispatcher`: either of their first two arguments may
# be ragged.
_BINARY_ELEMENTWISE_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
    math_ops.truncatediv,
    math_ops.truncatemod,
]
390
391
# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler).  But we
# still want to include them in the ragged_op_list() output.
# `_op_is_in_tf_version(op, 1)` reports any op found in this list as part of
# the v1 API even when it has no v1 export names of its own.
_V1_OPS_THAT_DELEGATE_TO_V2_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_any,
    math_ops.reduce_all,
]
404
405
def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  """Adapter that forwards every argument to `ragged_gather_ops.gather`.

  Registered in `_RAGGED_DISPATCH_OPS` as the ragged handler for
  `array_ops.gather`.  `RaggedDispatcher` requires the handler's argument
  names to equal those of the op it overrides, which is why this wrapper
  exists with its own parameter order; all arguments are passed by keyword.
  """
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)
415
416
def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Adapter that forwards to `ragged_array_ops.expand_dims`.

  Args:
    input: Forwarded as `input`.
    axis: Forwarded as `axis` unless `dim` is given.
    name: Forwarded as `name`.
    dim: If not `None`, used in place of `axis`.

  Returns:
    Whatever `ragged_array_ops.expand_dims` returns.
  """
  # `dim` wins over `axis` whenever it is supplied.
  effective_axis = axis if dim is None else dim
  return ragged_array_ops.expand_dims(
      input=input, axis=effective_axis, name=name)
421
422
def _ragged_size_v1(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  """Adapter that forwards every argument to `ragged_array_ops.size`.

  Registered in `_RAGGED_DISPATCH_OPS` as the ragged handler for
  `array_ops.size`; all arguments are passed by keyword.
  """
  return ragged_array_ops.size(input=input, out_type=out_type, name=name)
425
426
# (original_op, ragged_op, ragged_args)
# `register_dispatchers()` builds one `RaggedDispatcher` per entry, so
# `ragged_op` must have the same argument names as `original_op`.
# `ragged_args` names the arguments that may be ragged; a name wrapped in
# square brackets (e.g. '[values]') denotes an argument that takes a list of
# tensors.
_RAGGED_DISPATCH_OPS = [
    (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
     ['params', 'indices']),
    (array_ops.concat, ragged_concat_ops.concat, ['[values]']),
    (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']),
    (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
    (array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
    (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
    (array_ops.gather_nd, ragged_gather_ops.gather_nd, ['params', 'indices']),
    (array_ops.rank, ragged_array_ops.rank, ['input']),
    (array_ops.size, _ragged_size_v1, ['input']),
    (array_ops.size_v2, ragged_array_ops.size, ['input']),
    (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
    (array_ops.tile, ragged_array_ops.tile, ['input']),
    (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
    (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
     ['data', 'segment_ids']),
    (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
    (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
    (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
    (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
    (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
    (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
    (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]
463
464
def register_dispatchers():
  """Constructs & registers OpDispatchers for ragged ops.

  Every op in the module-level op lists is first checked for the TensorFlow
  API-names attribute on its undecorated function; only then are the
  dispatchers created and registered.

  Raises:
    AssertionError: If one of the listed ops lacks that attribute.
  """
  op_list = (
      _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
      _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
  for op in op_list:
    _, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op,
                   tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
      # Format the op into the message so the failure names the culprit.
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % (op,))

  for op in _UNARY_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op).register(op)

  for op in _UNARY_LIST_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op, True).register(op)

  for op in _BINARY_ELEMENTWISE_OPS:
    BinaryRaggedElementwiseDispatcher(op).register(op)

  for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
    RaggedDispatcher(original_op, ragged_op, args).register(original_op)
489
490
def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold.

  Args:
    op: The op to describe.
    ragged_args: Positions, within the op's named arguments, of the arguments
      that should be marked in bold.

  Returns:
    A one-line markdown bullet of the form '* `tf.<name>`(<args>)', where
    arguments with defaults are rendered as 'name=`<repr of default>`'.
  """
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  # Work on a copy so the argspec's own argument list is never modified.
  arg_names = list(argspec.args)

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.  `defaults` is None for an op without any, and the
  # defaults always belong to the trailing arguments, hence negative indices.
  defaults = argspec.defaults or ()
  for pos in range(-1, -len(defaults) - 1, -1):
    arg_names[pos] += '=`{!r}`'.format(defaults[pos])

  # Add varargs and keyword args
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
512
513
514def _op_is_in_tf_version(op, version):
515  if version == 1:
516    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
517            op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
518  elif version == 2:
519    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
520  else:
521    raise ValueError('Expected version 1 or 2.')
522
523
def ragged_op_list(tf_version=1):
  """Returns a string listing operators that have dispatchers registered.

  Args:
    tf_version: The TensorFlow API version (1 or 2) whose ops should be
      listed.

  Returns:
    A markdown section: a heading, a short explanation, and one sorted bullet
    line per op with its ragged arguments marked in bold, ending in a newline.
  """
  lines = []
  for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0]))
  for op in _BINARY_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0, 1]))
  for op, _, ragged_arg_names in _RAGGED_DISPATCH_OPS:
    if _op_is_in_tf_version(op, tf_version):
      # `_ragged_op_signature` wants positions, not names.
      arginfos = _get_arg_infos(op, ragged_arg_names)
      ragged_positions = [arginfo.position for arginfo in arginfos]
      lines.append(_ragged_op_signature(op, ragged_positions))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')
541
542
543register_dispatchers()
544