# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor.  One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # Infer dtype if not explicitly provided.
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]


ops.no_gradient('RaggedRange')

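# Illustrative example of the dtype inference above (a sketch, assuming eager
# execution; values chosen for exposition only):
#
#   >>> tf.ragged.range([2, 3], deltas=[0.5, 1]).to_list()
#   [[0.0, 0.5, 1.0, 1.5], [0.0, 1.0, 2.0]]
#
# Here `starts` defaults to 0 and the integer `limits` are cast to float32 to
# match `deltas`, following _infer_matching_dtype's hierarchy.
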
#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      All values must be greater than or equal to zero and less than
      `num_segments`.  `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string.  Defaults to `None`.  The separator to use
      when joining.  Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A non-None separator means a string op (such as unsorted_segment_join)
      # is being used, which takes an extra `separator` argument.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ], axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths
    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)

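# Worked example of the alignment logic above (illustrative, assuming eager
# execution):
#
#   data        = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
#   segment_ids = [0, 0, 1]
#   segment_sum(data, segment_ids, num_segments=2).to_list()
#   ==> [[4, 2], [4, 5, 6]]
#
# Segment 0 combines rows [1, 2] and [3] position-wise (the output row length
# is the maximum of the combined rows' lengths), and segment 1 contains only
# the row [4, 5, 6].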

def segment_sum(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))


def segment_prod(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


def segment_min(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


def segment_max(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count


def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)

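# Worked example for segment_sqrt_n (illustrative): with
# data = [[1., 2.], [3.]] and segment_ids = [0, 0], the elementwise total is
# [1 + 3, 2] and the elementwise count is [2, 1], so the single output row is
# [4 / sqrt(2), 2 / sqrt(1)] = [2.828..., 2.0].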

def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values.  If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.

  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank)`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `input_tensor`, and its shape is given by removing
    the dimensions specified in `axis` from `input_tensor.shape`.  The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  #### Example:
    %(example)s
"""
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12, 4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30, 4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4.  , 4.  ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3.        , 9.        , 4.        ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_variance(rt, axis=0).numpy()
    array([1.25, 0.  , 0.  ])
    >>> tf.math.reduce_variance(rt, axis=1).numpy()
    array([2.  , 0.25, 0.  , 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_std(rt, axis=0).numpy()
    array([1.11803399, 0.47140452])
    >>> tf.math.reduce_std(rt, axis=1).numpy()
    array([0.5, 0.5, 0. , 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False,  True, False,  True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True,  True, False,  True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True,  True,  True])
"""

def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value.  Must be in the
      range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string.  Defaults to `None`.  The separator to use
      when joining.  It must only be set for string data; a non-`None`
      separator indicates that the string ops are used.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `rt_input`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    if separator is None:
      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
    else:
      # When `separator` is not None, we infer that the dtype is string and
      # that a string reduction (e.g. `reduce_join`) will be called.
      return reduce_op(
          rt_input, axis, keepdims=keepdims, name=name, separator=separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, we reduce them one at a time (see
        # below), so any negative axes must be converted to positive up
        # front; otherwise sorting a mix of negative and positive axes
        # would process them in the wrong order.  See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time.  This is less
        # efficient, and only works for associative ops.  (In particular, it
        # does not work for reduce_mean.)  However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))

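# Illustrative sketch of the multi-axis handling above (assumes eager
# execution):
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.reduce_sum(rt, axis=[0, 1]).numpy()  ==> 6
#
# Reducing axis=[0, 1] is implemented as reduce(axis=1) followed by
# reduce(axis=0), which is why only associative ops support multiple axes.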

def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))


def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count

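# For example (illustrative): reduce_mean divides an elementwise sum by an
# elementwise count, so rows of different lengths are handled correctly:
#
#   rt = tf.ragged.constant([[1., 5.], [3.]])
#   tf.reduce_mean(rt, axis=0).numpy()  ==> array([2., 5.], dtype=float32)
#
# Column 0 averages two values ((1 + 3) / 2); column 1 averages a single one.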

def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    return mean_of_square - square_of_mean


def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)

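# Note (illustrative): reduce_variance uses the identity
# Var(x) = E[x**2] - E[x]**2.  For the row [2., 4.]:
#   mean_of_square = (4 + 16) / 2 = 10 and square_of_mean = 3**2 = 9,
# so the variance is 1.0, and reduce_std returns sqrt(1.0) = 1.0.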

def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)

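# Sketch of the cast trick above (illustrative): booleans are reduced via
# integer arithmetic, since there are no unsorted_segment_all/any ops:
#
#   reduce_all(x)  <=>  cast(reduce_prod(cast(x, int32)), bool)
#   reduce_any(x)  <=>  cast(reduce_sum(cast(x, int32)), bool)
#
# For example, [True, False] -> [1, 0]: the product 0 casts back to False,
# and the sum 1 casts back to True.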

def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions.  The inputs must have the same dtype.  See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
  """
  if transpose_a and adjoint_a:
    raise ValueError('Only one of transpose_a and adjoint_a can be True.')
  if transpose_b and adjoint_b:
    raise ValueError('Only one of transpose_b and adjoint_b can be True.')

  kwargs = dict(
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      adjoint_a=adjoint_a,
      adjoint_b=adjoint_b,
      a_is_sparse=a_is_sparse,
      b_is_sparse=b_is_sparse,
      output_type=output_type)

  with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name:
    a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a')
    b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b')

    a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor)
    b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor)
    if not (a_is_ragged or b_is_ragged):
      return math_ops.matmul(a, b, **kwargs)

    if a.dtype != b.dtype:
      raise ValueError('`a` and `b` must have the same dtype.')

    # TODO(edloper): Support broadcasting inputs.  (Broadcast support is not
    # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul,
    # but it is supported by the op.)

    # Find the rank of the input tensors.
    if a.shape.rank is None:
      if b.shape.rank is None:
        raise ValueError('matmul requires at least one input to have known '
                         'rank if either input is ragged.')
      rank = b.shape.rank
    else:
      if b.shape.rank is not None and a.shape.rank != b.shape.rank:
        raise ValueError('`a` and `b` must have the same rank.')
      rank = a.shape.rank

    # At least one of `a` and `b` is ragged, and ragged tensors always have
    # rank >= 2.
    if rank < 2:
      # This can happen if e.g. `a` is a 1D dense tensor and `b` is a
      # ragged tensor with unknown rank.  Since ragged tensors always have
      # `rank >= 2`, this implies that `a` and `b` have different ranks.
      raise ValueError('`a` and `b` must have the same rank.')

    # Rank > 3: We have multiple batch dimensions.  Merge them into a single
    # batch dimension, recursively call `matmul`, and then restore the original
    # batch dimension (using a.row_splits).
    if rank > 3:
      shape_err = 'Batch dimensions of `a` and `b` do not have the same size.'
      if not a_is_ragged:
        a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1)
      if not b_is_ragged:
        b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1)
      with ops.control_dependencies([
          check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err)
      ]):
        flat_result = matmul(a.values, b.values, **kwargs)
        return a.with_values(flat_result)

    if rank == 2:
      return _matmul_2d(a, b, **kwargs)

    assert rank == 3  # I.e., we have a single batch dimension.

    a_ragged_rank = a.ragged_rank if a_is_ragged else 0
    if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a):
      # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K]`, then we can compute
      # the result with a single dense `matmul`.
      return _matmul_3d_with_batch_dim_folding(a, b, **kwargs)
    else:
      # Otherwise, fall back on using `map_fn`.
      return _matmul_3d_with_map_fn(a, b, **kwargs)

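# Illustrative usage (a sketch, assuming eager execution; values chosen for
# exposition only):
#
#   a = tf.ragged.constant([[[1., 2.]], [[3., 4.], [5., 6.]]])  # [2, (I), 2]
#   b = tf.constant([[[1.], [0.]], [[0.], [1.]]])               # [2, 2, 1]
#   matmul(a, b).to_list()  ==> [[[1.0]], [[4.0], [6.0]]]
#
# Each batch's ragged matrix is multiplied by the corresponding dense matrix.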

def _matmul_2d(a, b, **kwargs):
  """Multiplies potentially ragged 2D tensors.

  Args:
    a: A 2D Tensor or RaggedTensor with `shape=[I, J]`.
    b: A 2D Tensor or RaggedTensor with `shape=[J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 2D Tensor with `shape=[I, K]`.
  """
  # Multiplying `a` and `b` is only well-defined if `a` and `b` are
  # actually uniform (and just happened to be stored as ragged tensors).
  # Check that they are uniform, and convert them to `tf.Tensor`s.
  ragged_err = ('The matrices in `a` and `b` may not be '
                'ragged in their innermost dimension.')
  checks = []
  if isinstance(a, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(a.flat_values)
    a = a.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(a), message=ragged_err))
  if isinstance(b, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(b.flat_values)
    b = b.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(b), message=ragged_err))
  with ops.control_dependencies(checks):
    return math_ops.matmul(a, b, **kwargs)


def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j(a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that the length of each row `a[n, i]` equals `b[n].nrows()` (for
  all `n` and `i`).

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I`
      and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J`
      and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """
  if isinstance(b, ragged_tensor.RaggedTensor) and b.ragged_rank == 2:
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  fn_out_shape = None  # Figure out proper shape.
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  output_type = kwargs['output_type']
  if output_type is None:
    output_type = a.dtype
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`.  (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))

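# Sketch of the folding above (illustrative): if a.row_lengths() == [2, 1],
# then `a.values` holds three rows, and `array_ops.repeat` pairs them with the
# batch matrices [b[0], b[0], b[1]], so a single dense batched matmul of
# shapes [3, 1, J] x [3, J, K] computes every output row at once.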

#===============================================================================
# ragged.softmax
#===============================================================================
def softmax(logits, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by softmax
  is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1

  with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name:
    logits_exp = math_ops.exp(logits)
    denominator = reduce_sum(logits_exp, axis=axis, keepdims=True)
    return math_ops.divide(logits_exp, denominator)
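# Illustrative ragged usage (a sketch, assuming eager execution): because the
# reduction above keeps dims, each denominator broadcasts across its own row,
# so every row sums to 1 independently:
#
#   rt = tf.ragged.constant([[0., 0.], [0., 0., 0.]])
#   softmax(rt, axis=-1).to_list()
#   ==> [[0.5, 0.5], [0.333..., 0.333..., 0.333...]]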