• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Array operations for RaggedTensors."""
16
17from typing import Optional
18from typing import Union
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import data_flow_ops
29from tensorflow.python.ops import gen_ragged_array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import sort_ops
32from tensorflow.python.ops.ragged import dynamic_ragged_shape
33from tensorflow.python.ops.ragged import ragged_functional_ops
34from tensorflow.python.ops.ragged import ragged_math_ops
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.ops.ragged import ragged_util
37from tensorflow.python.ops.ragged import segment_id_ops
38from tensorflow.python.types import core as core_types
39from tensorflow.python.util import dispatch
40from tensorflow.python.util.tf_export import tf_export
41
42#===============================================================================
43# Masking
44#===============================================================================
45
46
47@tf_export('ragged.boolean_mask')
48@dispatch.add_dispatch_support
49def boolean_mask(data, mask, name=None):
50  """Applies a boolean mask to `data` without flattening the mask dimensions.
51
52  Returns a potentially ragged tensor that is formed by retaining the elements
53  in `data` where the corresponding value in `mask` is `True`.
54
55  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`
56
57     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.
58
59  Note that `output` preserves the mask dimensions `a1...aA`; this differs
60  from `tf.boolean_mask`, which flattens those dimensions.
61
62  Args:
63    data: A potentially ragged tensor.
64    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
65      of `data`'s shape.  `rank(mask)` must be known statically.
66    name: A name prefix for the returned tensor (optional).
67
68  Returns:
69    A potentially ragged tensor that is formed by retaining the elements in
70    `data` where the corresponding value in `mask` is `True`.
71
72    * `rank(output) = rank(data)`.
73    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.
74
75  Raises:
76    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
77      not a prefix of `data.shape`.
78
79  #### Examples:
80
81  >>> # Aliases for True & False so data and mask line up.
82  >>> T, F = (True, False)
83
84  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
85  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
86  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
87  [[1, 3], [], [7]]
88
89  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
90  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
91  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
92  [[3], [], [5, 6]]
93
94  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
95  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
96  ...     tf.ragged.constant([True, False, True])).to_list()
97  [[1, 2, 3], [5, 6]]
98  """
99  with ops.name_scope(name, 'RaggedMask', [data, mask]):
100    # Convert inputs to tensors.
101    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
102    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
103        mask, dtypes.bool, name='mask')
104    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
105        data, mask, return_dtype=True)
106
107    # Get static rank of mask.
108    if mask.shape.ndims is None:
109      raise ValueError('mask.shape.ndims must be known statically.')
110    elif mask.shape.ndims == 0:
111      raise ValueError('mask cannot be scalar.')
112
113    # If mask is ragged, then recurse with a non-ragged mask.
114    if ragged_tensor.is_ragged(mask):
115      if not ragged_tensor.is_ragged(data):
116        data = ragged_tensor.RaggedTensor.from_tensor(
117            data,
118            ragged_rank=mask.ragged_rank,
119            row_splits_dtype=mask.row_splits.dtype)
120      # Check that mask.nested_row_splits is a prefix of
121      # data.nested_row_splits.
122      splits_list = [
123          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
124      ]
125      with ops.control_dependencies(
126          ragged_util.assert_splits_match(splits_list)):
127        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
128        # that we strip off in `splits`, so we can add them back on after
129        # we recursively mask the non-ragged data.
130        splits = []
131        while ragged_tensor.is_ragged(mask):
132          if mask.shape.ndims > 2:
133            splits.append(mask.row_splits)
134          else:
135            # Count the number of True mask values in each row to find the
136            # lengths of the filtered rows; then convert to splits.
137            int_mask = ragged_functional_ops.map_flat_values(
138                math_ops.cast, mask, dtype=row_splits_dtype)
139            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
140            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
141          mask = mask.values
142          data = data.values
143
144        # Recursively apply the nested non-ragged mask to the nested data.
145        masked_values = boolean_mask(data, mask)
146
147        # Add the ragged `splits` back to the result.
148        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
149            masked_values, splits, validate=False)
150
151        return masked_values
152
153    # If mask is non-ragged and has rank 1, and data is ragged, then build a
154    # ragged tensor with the indicated rows.
155    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
156      # Get the masked splits: first get the length of each row, then filter
157      # out the rows that we are deleting, and convert that filtered set of
158      # masks back to a splits tensor.
159      lengths = data.row_lengths()
160      masked_lengths = array_ops.boolean_mask(lengths, mask)
161      masked_splits = ragged_util.lengths_to_splits(masked_lengths)
162
163      # Get the masked values: first get row ids corresponding to each
164      # value, then use tf.gather to build a boolean mask that's false for
165      # values that come from rows that we are deleting, and use that mask to
166      # construct the masked values tensor.
167      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
168      segment_mask = array_ops.gather(mask, segment_ids)
169      masked_values = boolean_mask(data.values, segment_mask)
170
171      return ragged_tensor.RaggedTensor.from_row_splits(
172          masked_values, masked_splits, validate=False)
173
174    # If mask is non-ragged and has rank>1, then convert it to be ragged,
175    # with a ragged rank matching data.
176    if ragged_tensor.is_ragged(data):
177      mask = ragged_tensor.RaggedTensor.from_tensor(
178          mask,
179          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
180          row_splits_dtype=data.row_splits.dtype)
181      return boolean_mask(data, mask)
182
183    # Otherwise, data and mask are both `Tensor`s.
184    else:
185      # Apply `boolean_mask` to get the masked values.
186      masked_values = array_ops.boolean_mask(data, mask)
187
188      if mask.shape.ndims >= 2:
189        # Add the innermost ragged dimension.  For each innermost cell, get the
190        # number of values it contains.  Then flatten that to get a list of
191        # cell lengths, and convert it to splits.  Finally, combine the splits
192        # and values to get the innermost ragged tensor.
193        masked_lengths = math_ops.count_nonzero(
194            mask, axis=-1, dtype=row_splits_dtype)
195        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
196        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
197            masked_values, flattened_masked_lengths, validate=False)
198
199        # Wrap remaining ragged dimensions.
200        if mask.shape.ndims > 2:
201          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
202          split_size = math_ops.cumprod(mask_shape) + 1
203          for dim in range(mask.shape.ndims - 3, -1, -1):
204            elt_size = mask_shape[dim + 1]
205            masked_splits = math_ops.range(split_size[dim]) * elt_size
206            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
207                masked_values, masked_splits, validate=False)
208
209      return masked_values
210
211
212#===============================================================================
213# Tiling
214#===============================================================================
215@dispatch.dispatch_for_api(array_ops.tile)
216def tile(input: ragged_tensor.Ragged, multiples, name=None):  # pylint: disable=redefined-builtin
217  """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.
218
219  The values of `input` are replicated `multiples[i]` times along the
220  `i`th dimension (for each dimension `i`).  For every dimension `axis` in
221  `input`, the length of each output element in that dimension is the
222  length of corresponding input element multiplied by `multiples[axis]`.
223
224  Args:
225    input: A `RaggedTensor`.
226    multiples: A 1-D integer `Tensor`.  Length must be the same as the number of
227      dimensions in `input`.
228    name: A name for the operation (optional).
229
230  Returns:
231    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.
232
233  #### Example:
234
235  >>> rt = tf.ragged.constant([[1, 2], [3]])
236  >>> tf.tile(rt, [3, 2]).to_list()
237  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
238  """
239  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
240    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
241        input, name='input')
242    if not ragged_tensor.is_ragged(input):
243      return array_ops.tile(input, multiples, name)
244    multiples = ragged_util.convert_to_int_tensor(
245        multiples, name='multiples', dtype=input.row_splits.dtype)
246    multiples.shape.assert_has_rank(1)
247
248    # If the constant value of `multiples` is available, then we can use it
249    # to skip tiling dimensions where `multiples=1`.
250    const_multiples = tensor_util.constant_value(multiples)
251
252    return ragged_tensor.RaggedTensor.from_nested_row_splits(
253        _tile_ragged_values(input, multiples, const_multiples),
254        _tile_ragged_splits(input, multiples, const_multiples),
255        validate=False)
256
257
258def _tile_ragged_values(rt_input, multiples, const_multiples=None):
259  """Builds flat_values tensor for a tiled `RaggedTensor`.
260
261  Returns a tensor that repeats the values in
262  `rt_input.flat_values` in the
263  appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as
264  specified by `multiples`.
265
266  Args:
267    rt_input: The `RaggedTensor` whose values should be repeated.
268    multiples: A 1-D integer `tensor`, indicating how many times each dimension
269      should be repeated.
270    const_multiples: Optional constant value for multiples.  Used to skip tiling
271      dimensions where `multiples=1`.
272
273  Returns:
274    A `Tensor` with the same type and rank as `rt_input.flat_values`.
275
276  #### Example:
277
278  >>> rt = tf.ragged.constant([[1, 2], [3]])
279  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
280  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
281  """
282  ragged_rank = rt_input.ragged_rank
283  nested_splits = rt_input.nested_row_splits
284
285  # Pointers to the values in `rt_input.flat_values`.
286  inner_value_ids = math_ops.range(nested_splits[-1][-1])
287
288  # For each ragged dimension (working from the innermost to outermost),
289  # expand `inner_value_ids` as necessary to tile that dimension.
290  prev_splits = None
291  for axis in range(ragged_rank, 0, -1):
292    # Ragged splits for this dimension.
293    splits = nested_splits[axis - 1]
294
295    # Adjust splits so they point into `inner_value_ids` (instead of just
296    # pointing into the next dimension's values).
297    if prev_splits is not None:  # Not the first pass through the loop.
298      splits = array_ops.gather(prev_splits * multiples[axis + 1], splits)
299
300    # Repeat each element in this ragged dimension `multiples[axis]` times.
301    if const_multiples is None or const_multiples[axis] != 1:
302      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
303                                                  multiples[axis])
304
305    prev_splits = splits
306
307  # Gather the tiled inner values.
308  ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids)
309
310  # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus
311  # `axis=range(ragged_rank, rank)`).
312  inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]],
313                                   axis=0)
314  return array_ops.tile(ragged_tiled_values, inner_repeats)
315
316
317def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
318  """Builds nested_split tensors for a tiled `RaggedTensor`.
319
320  Returns a list of split tensors that can be used to construct the
321  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.
322
323  Args:
324    rt_input: The `RaggedTensor` that is being tiled.
325    multiples: A 1-D integer `tensor`, indicating how many times each dimension
326      should be repeated.
327    const_multiples: Optional constant value for multiples.  Used to skip tiling
328      dimensions where `multiples=1`.
329
330  Returns:
331    A list of 1-D integer `Tensor`s (one for each ragged dimension in
332    `rt_input`).
333
334  #### Example:
335
336  >>> rt = tf.ragged.constant([[1, 2], [3]])
337  >>> _tile_ragged_splits(rt, [3, 2])
338  [<tf.Tensor: shape=(7,), dtype=int64,
339  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
340  """
341  ragged_rank = rt_input.ragged_rank
342  nested_splits = rt_input.nested_row_splits
343
344  # projected_splits[src_axis, dst_axis] contains the split points that divide
345  # the rows from src_axis in the list of dst_axis values.  E.g.,
346  # projected_splits[i, i] = nested_splits[i], and
347  # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
348  projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
349  for src_axis in range(ragged_rank):
350    for dst_axis in range(src_axis + 1, ragged_rank - 1):
351      projected_splits[src_axis][dst_axis] = array_ops.gather(
352          nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])
353
354  # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
355  result_splits = []
356  for axis in range(ragged_rank):
357    # Get the length of each row for the input tensor for this dimension.
358    input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]
359
360    # Multiply those lengths by the `multiples` of dimension axis+1, since
361    # each value will be repeated that number of times.
362    output_lengths = input_lengths * multiples[axis + 1]
363
364    # Repeat ranges of the row lengths as necessary for them to be tiled in
365    # each ragged dimension `d < axis`.  (Start with dimension d=axis-1, and
366    # work our way up to dimension d=0.)
367    repeats = 1
368    for d in range(axis - 1, -1, -1):
369      if const_multiples is None or const_multiples[d + 1] != 1:
370        splits = projected_splits[d][axis - 1] * repeats
371        output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
372                                                   multiples[d + 1])
373      repeats *= multiples[d + 1]
374
375    # Tile splits for the outermost (uniform) dimension.
376    output_lengths = array_ops.tile(output_lengths, multiples[:1])
377
378    # Convert to splits.
379    result_splits.append(ragged_util.lengths_to_splits(output_lengths))
380
381  return result_splits
382
383
384#===============================================================================
385# Reshaping
386#===============================================================================
387
388
389@dispatch.dispatch_for_api(array_ops.expand_dims_v2)
390def expand_dims(input: ragged_tensor.Ragged, axis, name=None):  # pylint: disable=redefined-builtin
391  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.
392
393  Given a potentially ragged tenor `input`, this operation inserts a
394  dimension with size 1 at the dimension `axis` of `input`'s shape.
395
396  The following table gives some examples showing how `ragged.expand_dims`
397  impacts the shapes of different input tensors.  Ragged dimensions are
398  indicated by enclosing them in parentheses.
399
400  input.shape             | axis | result.shape
401  ----------------------- | ---- | -----------------------------
402  `[D1, D2]`              |  `0` | `[1, D1, D2]`
403  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
404  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
405  `[D1, (D2), (D3), D4]`  |  `0` | `[1, D1, (D2), (D3), D4]`
406  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, 1, (D2), (D3), D4]`
407  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), 1, (D3), D4]`
408  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
409  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`
410
411  Args:
412    input: The potentially tensor that should be expanded with a new dimension.
413    axis: An integer constant indicating where the new dimension should be
414      inserted.
415    name: A name for the operation (optional).
416
417  Returns:
418    A tensor with the same values as `input`, with an added dimension of
419    size 1 at `axis`.
420
421  #### Examples:
422
423  >>> rt = tf.ragged.constant([[1, 2], [3]])
424  >>> print(rt.shape)
425  (2, None)
426
427  >>> expanded = tf.expand_dims(rt, axis=0)
428  >>> print(expanded.shape, expanded)
429  (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>
430
431  >>> expanded = tf.expand_dims(rt, axis=1)
432  >>> print(expanded.shape, expanded)
433  (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
434
435  >>> expanded = tf.expand_dims(rt, axis=2)
436  >>> print(expanded.shape, expanded)
437  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
438  """
439  with ops.name_scope(name, 'RaggedExpandDims', [input]):
440    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
441        input, name='input')
442
443    if not ragged_tensor.is_ragged(input):
444      return array_ops.expand_dims(input, axis)
445
446    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
447    axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')
448
449    if axis == 0:
450      return ragged_tensor.RaggedTensor.from_uniform_row_length(
451          input, uniform_row_length=input.nrows(), nrows=1, validate=False)
452    elif axis == 1:
453      return ragged_tensor.RaggedTensor.from_uniform_row_length(
454          input, uniform_row_length=1, nrows=input.nrows(), validate=False)
455    else:
456      if ragged_tensor.is_ragged(input.values):
457        return input.with_values(expand_dims(input.values, axis - 1))
458      else:
459        return input.with_values(array_ops.expand_dims(input.values, axis - 1))
460
461
462@dispatch.dispatch_for_api(array_ops.expand_dims)
463def _ragged_expand_dims_v1(
464    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
465    axis=None,
466    name=None,
467    dim=None):
468  if dim is not None:
469    axis = dim
470  return expand_dims(input=input, axis=axis, name=name)
471
472
473#===============================================================================
474# RaggedTensor Size
475#===============================================================================
476
477
478@dispatch.dispatch_for_api(array_ops.size_v2)
479def size(input: ragged_tensor.Ragged, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
480  """Returns the size of a potentially ragged tensor.
481
482  The size of a ragged tensor is the size of its inner values.
483
484  #### Example:
485
486  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
487  3
488
489  Args:
490    input: A potentially ragged `Tensor`.
491    out_type: The numeric output type for the operation.
492    name: A name for the operation (optional).
493
494  Returns:
495    A Tensor of type `out_type`.
496  """
497  if ragged_tensor.is_ragged(input):
498    return array_ops.size(input.flat_values, out_type=out_type, name=name)
499  else:
500    return array_ops.size(input, out_type=out_type, name=name)
501
502
503@dispatch.dispatch_for_api(array_ops.size)
504def _ragged_size_v1(
505    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
506    name=None,
507    out_type=dtypes.int32):
508  return size(input=input, out_type=out_type, name=name)
509
510
511#===============================================================================
512# ragged.rank
513#===============================================================================
514@dispatch.dispatch_for_api(array_ops.rank)
515def rank(input: ragged_tensor.Ragged, name=None):  # pylint: disable=redefined-builtin
516  """Returns the rank of a RaggedTensor.
517
518  Returns a 0-D `int32` `Tensor` representing the rank of `input`.
519
520  #### Example:
521
522  >>> # shape of tensor 't' is [2, None, None]
523  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
524  >>> tf.rank(t).numpy()
525  3
526
527  Args:
528    input: A `RaggedTensor`
529    name: A name for the operation (optional).
530
531  Returns:
532    A `Tensor` of type `int32`.
533  """
534  with ops.name_scope(name, 'RaggedRank', [input]) as name:
535    if not ragged_tensor.is_ragged(input):
536      return array_ops.rank(input, name)
537
538    return input.ragged_rank + array_ops.rank(input.flat_values)
539
540
541#===============================================================================
542# ragged.one_hot
543#===============================================================================
544@dispatch.dispatch_for_api(array_ops.one_hot)
545def ragged_one_hot(indices: ragged_tensor.Ragged,
546                   depth,
547                   on_value=None,
548                   off_value=None,
549                   axis=None,
550                   dtype=None,
551                   name=None):
552  """Applies tf.one_hot along the values of a RaggedTensor."""
553  # Get the adjusted axis value for the call to array_ops.one_hot.
554  # Note: the only negative `axis` value supported by array_ops.one_hot is -1.
555  if isinstance(axis, int) and axis >= 0:
556    if axis <= indices.ragged_rank:
557      raise ValueError('axis (%d) must be greater than indices.ragged_rank '
558                       '(%d).' % (axis, indices.ragged_rank))
559    axis -= indices.ragged_rank
560
561  with ops.name_scope(name, 'RaggedOneHot',
562                      [indices, depth, on_value, off_value, axis]):
563    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
564        indices, name='indices')
565    return indices.with_flat_values(
566        array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis,
567                          dtype, name))
568
569
570#===============================================================================
571# ragged.stack_dynamic_partitions
572#===============================================================================
573@tf_export('ragged.stack_dynamic_partitions')
574@dispatch.add_dispatch_support
575def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
576  """Stacks dynamic partitions of a Tensor or RaggedTensor.
577
578  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
579  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
580  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
581  order.
582
583  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
584  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.
585
586  #### Example:
587
588  >>> data           = ['a', 'b', 'c', 'd', 'e']
589  >>> partitions     = [  3,   0,   2,   2,   3]
590  >>> num_partitions = 5
591  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
592  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>
593
594  Args:
595    data: A `Tensor` or `RaggedTensor` containing the values to stack.
596    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
597      partition that each slice of `data` should be added to. `partitions.shape`
598      must be a prefix of `data.shape`.  Values must be greater than or equal to
599      zero, and less than `num_partitions`. `partitions` is not required to be
600      sorted.
601    num_partitions: An `int32` or `int64` scalar specifying the number of
602      partitions to output.  This determines the number of rows in `output`.
603    name: A name prefix for the returned tensor (optional).
604
605  Returns:
606    A `RaggedTensor` containing the stacked partitions.  The returned tensor
607    has the same dtype as `data`, and its shape is
608    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
609    ragged dimension whose length is the number of data slices stacked for
610    each `partition`.
611  """
612  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
613    # Convert inputs to tensors.
614    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
615    row_splits_dtype = (
616        data.row_splits.dtype
617        if isinstance(data, ragged_tensor.RaggedTensor) else None)
618    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
619        partitions, name='partitions', preferred_dtype=row_splits_dtype)
620    num_partitions = ops.convert_to_tensor(
621        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
622    if row_splits_dtype is not None:
623      partitions = math_ops.cast(partitions, row_splits_dtype)
624    num_partitions = math_ops.cast(num_partitions, partitions.dtype)
625
626    # Sanity-checks for shapes.
627    partitions_rank = partitions.shape.ndims
628    if partitions_rank is None:
629      raise ValueError('partitions must have known rank.')
630    num_partitions.shape.assert_has_rank(0)
631    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])
632
633    if partitions_rank == 0:
634      # If partitions is a scalar, then just create a RaggedTensor containing
635      # that single the complete `data` value in the specified row.
636      return ragged_tensor.RaggedTensor.from_value_rowids(
637          values=array_ops.stack([data]),
638          value_rowids=array_ops.stack([partitions]),
639          nrows=num_partitions,
640          validate=False)
641
642    elif partitions_rank == 1:
643      # If partitions is a vector (the typical case): we can just use data and
644      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
645      # as long as we sort them first.
646      permutation = sort_ops.argsort(partitions, stable=True)
647      value_rowids = array_ops.gather(partitions, permutation)
648      values = array_ops.gather(data, permutation)
649      check = check_ops.assert_less(
650          value_rowids[-1:],
651          num_partitions,
652          message='partitions must be less than num_partitions')
653      with ops.control_dependencies([check]):
654        return ragged_tensor.RaggedTensor.from_value_rowids(
655            values, value_rowids, nrows=num_partitions, validate=False)
656
657    else:
658      # Handle higher-dimensional partitions via recursion.
659      if not isinstance(data, ragged_tensor.RaggedTensor):
660        data = ragged_tensor.RaggedTensor.from_tensor(
661            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
662      if not isinstance(partitions, ragged_tensor.RaggedTensor):
663        partitions = ragged_tensor.RaggedTensor.from_tensor(
664            partitions,
665            row_splits_dtype=partitions.dtype,
666            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
667      check = check_ops.assert_equal(
668          data.row_splits,
669          partitions.row_splits,
670          message='data and partitions have incompatible ragged shapes')
671      with ops.control_dependencies([check]):
672        return stack_dynamic_partitions(data.values, partitions.values,
673                                        num_partitions)
674
675
676#===============================================================================
677# Reverse
678#===============================================================================
679@dispatch.dispatch_for_api(array_ops.reverse)
680def reverse(tensor: ragged_tensor.Ragged, axis, name=None):
681  """Reverses a RaggedTensor along the specified axes.
682
683  #### Example:
684
685  >>> data = tf.ragged.constant([
686  ...   [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
687  >>> tf.reverse(data, axis=[0, 2])
688  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>
689
690  Args:
691    tensor: A 'RaggedTensor' to reverse.
692    axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
693      the axes to reverse.
694    name: A name prefix for the returned tensor (optional).
695
696  Returns:
697    A 'RaggedTensor'.
698  """
699  type_error_msg = ('`axis` must be a list of int or a constant tensor'
700                    'when reversing axes in a ragged tensor')
701
702  with ops.name_scope(name, 'Reverse', [tensor, axis]):
703    if isinstance(axis, ops.Tensor):
704      axis = tensor_util.constant_value(axis)
705      if axis is None:
706        raise TypeError(type_error_msg)
707    elif not (isinstance(axis, (list, tuple)) and
708              all(isinstance(dim, int) for dim in axis)):
709      raise TypeError(type_error_msg)
710
711    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
712        tensor, name='tensor')
713
714    # Allow usage of negative values to specify innermost axes.
715    axis = [
716        array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
717                                    'rank(tensor)')
718        for i, dim in enumerate(axis)
719    ]
720
721    # We only need to slice up to the max axis. If the axis list
722    # is empty, it should be 0.
723    slices = [slice(None)] * (max(axis) + 1 if axis else 0)
724
725    for dim in axis:
726      slices[dim] = slice(None, None, -1)
727
728    return tensor[tuple(slices)]
729
730
731#===============================================================================
732# Cross
733#===============================================================================
734
735
736@tf_export('ragged.cross')
737@dispatch.add_dispatch_support
738def cross(inputs, name=None):
739  """Generates feature cross from a list of tensors.
740
741  The input tensors must have `rank=2`, and must all have the same number of
742  rows.  The result is a `RaggedTensor` with the same number of rows as the
743  inputs, where `result[row]` contains a list of all combinations of values
744  formed by taking a single value from each input's corresponding row
745  (`inputs[i][row]`).  Values are combined by joining their strings with '_X_'.
746  E.g.:
747
748  >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
749  ...                  tf.ragged.constant([['d'], ['e']]),
750  ...                  tf.ragged.constant([['f'], ['g']])])
751  <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>
752
753  Args:
754    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
755    name: Optional name for the op.
756
757  Returns:
758    A 2D `RaggedTensor` of type `string`.
759  """
760  return _cross_internal(inputs=inputs, hashed_output=False, name=name)
761
762
763@tf_export('ragged.cross_hashed')
764@dispatch.add_dispatch_support
765def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
766  """Generates hashed feature cross from a list of tensors.
767
768  The input tensors must have `rank=2`, and must all have the same number of
769  rows.  The result is a `RaggedTensor` with the same number of rows as the
770  inputs, where `result[row]` contains a list of all combinations of values
771  formed by taking a single value from each input's corresponding row
772  (`inputs[i][row]`).  Values are combined by hashing together their
773  fingerprints. E.g.:
774
775  >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
776  ...                         tf.ragged.constant([['d'], ['e']]),
777  ...                         tf.ragged.constant([['f'], ['g']])],
778  ...                        num_buckets=100)
779  <tf.RaggedTensor [[78], [66, 74]]>
780
781  Args:
782    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
783    num_buckets: A non-negative `int` that used to bucket the hashed values. If
784      `num_buckets != 0`, then `output = hashed_value % num_buckets`.
785    hash_key: Integer hash_key that will be used by the `FingerprintCat64`
786      function. If not given, a default key is used.
787    name: Optional name for the op.
788
789  Returns:
790    A 2D `RaggedTensor` of type `int64`.
791  """
792  return _cross_internal(
793      inputs=inputs,
794      hashed_output=True,
795      num_buckets=num_buckets,
796      hash_key=hash_key,
797      name=name)
798
799
800_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE
801
802
803def _cross_internal(inputs,
804                    hashed_output=False,
805                    num_buckets=0,
806                    hash_key=None,
807                    name=None):
808  """Generates feature cross from a list of ragged and dense tensors."""
809  if not isinstance(inputs, (tuple, list)):
810    raise TypeError('Inputs must be a list')
811
812  if hash_key is None:
813    hash_key = _DEFAULT_CROSS_HASH_KEY
814
815  ragged_inputs = []
816  sparse_inputs = []
817  dense_inputs = []
818  input_order = []
819  with ops.name_scope(name, 'RaggedCross', inputs):
820    for i, t in enumerate(inputs):
821      if sparse_tensor.is_sparse(t):
822        t = sparse_tensor.SparseTensor.from_value(t)
823      else:
824        t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
825      if t.dtype.is_integer:
826        t = math_ops.cast(t, dtypes.int64)
827      elif t.dtype != dtypes.string:
828        raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
829      if isinstance(t, ragged_tensor.RaggedTensor):
830        if t.ragged_rank != 1:
831          raise ValueError('tf.ragged.cross only supports inputs with rank=2')
832        ragged_inputs.append(t)
833        input_order.append('R')
834      elif isinstance(t, sparse_tensor.SparseTensor):
835        sparse_inputs.append(t)
836        input_order.append('S')
837      else:
838        dense_inputs.append(t)
839        input_order.append('D')
840
841    out_values_type = dtypes.int64 if hashed_output else dtypes.string
842    if ragged_inputs and all(
843        t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
844      out_row_splits_type = dtypes.int32
845    else:
846      out_row_splits_type = dtypes.int64
847
848    # Convert hash_key from uint64 -> int64, since we need to pass it via
849    # an int64 attr.
850    if hash_key > 2**63:
851      hash_key -= 2**64
852
853    values_out, splits_out = gen_ragged_array_ops.ragged_cross(
854        ragged_values=[rt.values for rt in ragged_inputs],
855        ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
856        sparse_indices=[st.indices for st in sparse_inputs],
857        sparse_values=[st.values for st in sparse_inputs],
858        sparse_shape=[st.dense_shape for st in sparse_inputs],
859        dense_inputs=dense_inputs,
860        input_order=''.join(input_order),
861        hashed_output=hashed_output,
862        num_buckets=num_buckets,
863        hash_key=hash_key,
864        out_values_type=out_values_type.as_datatype_enum,
865        out_row_splits_type=out_row_splits_type.as_datatype_enum,
866        name=name)
867
868    return ragged_tensor.RaggedTensor.from_row_splits(
869        values_out, splits_out, validate=False)
870
871
872#===============================================================================
873# dynamic_partition
874#===============================================================================
875@dispatch.dispatch_for_api(data_flow_ops.dynamic_partition)
876def dynamic_partition(data: ragged_tensor.RaggedOrDense,
877                      partitions: ragged_tensor.RaggedOrDense,
878                      num_partitions,
879                      name=None):
880  """RaggedTensor dispatch override for tf.dynamic_partition."""
881  if not isinstance(num_partitions, int) or num_partitions < 0:
882    raise TypeError('num_partitions must be a non-negative integer')
883  result = stack_dynamic_partitions(data, partitions, num_partitions, name)
884  return [result[i] for i in range(num_partitions)]
885
886
887#===============================================================================
888# split
889#===============================================================================
890@dispatch.dispatch_for_api(array_ops.split)
891def split(value: ragged_tensor.Ragged,
892          num_or_size_splits,
893          axis=0,
894          num=None,
895          name=None):
896  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.
897
898  If `num_or_size_splits` is an `int`,  then it splits `value` along the
899  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
900  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.
901
902  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
903  `len(num_or_size_splits)` elements. The shape of the `i`-th element has the
904  same size as the `value` except along dimension `axis` where the size is
905  `num_or_size_splits[i]`.
906
907  Splits along a ragged dimension is not allowed.
908
909  For example:
910
911  >>> rt = tf.RaggedTensor.from_row_lengths(
912  ...      np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
913  >>> rt.shape
914  TensorShape([4, None, 3])
915  >>>
916  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
917  >>> rt1.shape
918  TensorShape([2, None, 3])
919  >>> rt2.shape
920  TensorShape([2, None, 3])
921  >>>
922  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
923  >>> rt3.shape
924  TensorShape([1, None, 3])
925  >>> rt4.shape
926  TensorShape([2, None, 3])
927  >>> rt5.shape
928  TensorShape([1, None, 3])
929  >>>
930  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
931  >>> rt6.shape
932  TensorShape([4, None, 1])
933  >>> rt7.shape
934  TensorShape([4, None, 2])
935
936  Args:
937    value: The `RaggedTensor` to split.
938    num_or_size_splits: Either an `int` indicating the number of splits
939      along `axis` or a 1-D integer `Tensor` or Python list containing the sizes
940      of each output tensor along `axis`. If a Python int, then it must evenly
941      divide `value.shape[axis]`; otherwise the sum of sizes along the split
942      axis must match that of the `value`.
943    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
944      to split. Must be in the range `[-rank(value), rank(value))`. Defaults to
945      0.
946    num: An `int` used to specify the number of outputs when
947      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
948      statically unknown, e.g., specifying `tf.TensorSepc(None)` with
949      the `input_signature` argument of `tf.function` (optional).
950    name: A name for the operation (optional).
951
952  Returns:
953    if `num_or_size_splits` is an `int` returns a list of `num_or_size_splits`
954    `RaggedTensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
955    `num_or_size_splits.get_shape[0]` `RaggedTensor` objects resulting from
956    splitting `value`.
957
958  Raises:
959    ValueError: If the dimension `axis` of `value` is a ragged dimension.
960    ValueError: If `num` is unspecified and cannot be inferred.
961    ValueError: If `num` is specified but doesn't match the length of
962      `num_or_size_splits`.
963    ValueError: If `num_or_size_splits` is an `int` and less than 1.
964    TypeError: If `num_or_size_splits` is not an `int` or 1-D
965      list or 1-D `Tensor`.
966    InvalidArgumentError: If the `axis` of `value` cannot be exactly splitted
967      by `num_or_size_splits`.
968    InvalidArgumentError: If `num_or_size_splits` is contains negative integers.
969    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown and
970      its dynamic shape is inconsistent `num`.
971    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown and
972      `axis` is a negative integer.
973  """
974  with ops.name_scope(name, 'RaggedSplit'):
975    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
976        value, name='value')
977    if isinstance(num_or_size_splits, int) and num_or_size_splits == 1:
978      return [value]
979
980    # static assert
981    check_ops.assert_integer_v2(
982        num_or_size_splits,
983        message=('`num_or_size_splits` must be an `int` or 1-D list or '
984                 '`Tensor` of integers.'))
985    value_shape = dynamic_ragged_shape.DynamicRaggedShape.from_tensor(value)
986    axis = array_ops.get_positive_axis(axis, value_shape.rank)
987    try:
988      dim_size = value_shape[axis]
989    except ValueError:
990      raise ValueError('Cannot split a ragged dimension. Got `value` with '
991                       f'shape {value_shape} and `axis` {axis}.')
992    if isinstance(num_or_size_splits, int):
993      # Uniform split
994      num_splits = num_or_size_splits
995      if num_splits < 1:
996        raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`.'
997                         f'Received {num_or_size_splits}.')
998      split_length = math_ops.floordiv(dim_size, num_splits)
999      split_lengths = array_ops.repeat(split_length, num_splits)
1000    else:
1001      # Ragged split
1002      num_splits = None
1003      split_lengths = ops.convert_to_tensor(num_or_size_splits)
1004      if split_lengths.shape.ndims is not None:
1005        if split_lengths.shape.ndims != 1:
1006          raise TypeError('`num_or_size_splits` must be an `int` or 1-D list '
1007                          f'or `Tensor`. Received {num_or_size_splits}.')
1008        num_splits = tensor_shape.dimension_value(split_lengths.shape[0])
1009
1010      if num_splits is None:
1011        if num is None:
1012          raise ValueError('`num` must be specified as an `int` when the '
1013                           'size of `num_or_size_split` is statically '
1014                           f'unknown. Received `num`: {num} and '
1015                           f'`num_or_size_split`: {num_or_size_splits}.')
1016        num_splits = num
1017      else:
1018        if num is not None and num != num_splits:
1019          raise ValueError('`num` does not match the size of '
1020                           f'`num_or_size_split`. Received `num`: {num} and '
1021                           f'size of `num_or_size_split`: {num_splits}.')
1022
1023    splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0)
1024    checks = []
1025    checks.append(
1026        check_ops.assert_non_negative_v2(
1027            num_or_size_splits,
1028            message='`num_or_size_splits` must be non-negative.'))
1029    checks.append(
1030        check_ops.assert_equal_v2(
1031            num_splits,
1032            array_ops.shape(split_lengths)[0],
1033            message='`num` is inconsistent with `num_or_size_split.shape[0]`.'))
1034    checks.append(
1035        check_ops.assert_equal_v2(
1036            math_ops.cast(dim_size, splits.dtype),
1037            splits[-1],
1038            message=('Cannot exactly split the `axis` dimension of `value` '
1039                     'with the given `num_or_size_split`.')))
1040    splits = control_flow_ops.with_dependencies(checks, splits)
1041    splited_rts = []
1042    slices = [slice(None)] * (axis + 1)
1043    for i in range(num_splits):
1044      slices[-1] = slice(splits[i], splits[i + 1])
1045      splited_rts.append(value[tuple(slices)])
1046    return splited_rts
1047
1048
1049#===============================================================================
1050# RaggedTensor shape operations
1051#===============================================================================
1052
1053
1054@dispatch.dispatch_for_api(array_ops.reshape)
1055def ragged_reshape(
1056    tensor: ragged_tensor.RaggedOrDense,
1057    shape: dynamic_ragged_shape.DenseOrRaggedShape
1058) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
1059  """Reshapes a tensor or ragged tensor."""
1060  tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1061      tensor, name='tensor')
1062  if isinstance(tensor, ragged_tensor.RaggedTensor):
1063    tensor = tensor.values
1064
1065  if isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
1066    flat_values = array_ops.reshape(tensor, shape.inner_shape)
1067    return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
1068        flat_values,
1069        shape.row_partitions,
1070        validate=False)
1071  else:
1072    shape = ops.convert_to_tensor(shape, name='shape')
1073    return array_ops.reshape(tensor, shape)
1074
1075
1076@dispatch.dispatch_for_api(array_ops.broadcast_to)
1077def broadcast_to(
1078    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
1079    shape: dynamic_ragged_shape.DynamicRaggedShape
1080) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
1081  """Broadcasts a potentially ragged tensor to a ragged shape.
1082
1083  Tiles `input` as necessary to match the given shape.
1084
1085  Behavior is undefined if `input` is not broadcast-compatible with `shape`.
1086
1087  Args:
1088    input: The potentially ragged tensor to broadcast.
1089    shape: A `DynamicRaggedShape`
1090
1091  Returns:
1092    A potentially ragged tensor whose values are taken from
1093    `input`, and whose shape matches `shape`.
1094  """
1095  return dynamic_ragged_shape.broadcast_to(input, shape)
1096
1097
1098# Note: default value for out_type needs to be int32, to match the
1099# default for tf.shape's out_type parameter.
1100@dispatch.dispatch_for_api(array_ops.shape)
1101def ragged_shape(
1102    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
1103    name: Optional[str] = None,
1104    out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
1105  """Returns the shape of a RaggedTensor.
1106
1107  Args:
1108    input: A `RaggedTensor`
1109    name: A name for the operation (optional).
1110    out_type: dtype used to encode the shape.
1111
1112  Returns:
1113    A `tf.experimental.DynamicRaggedShape`
1114  """
1115  with ops.name_scope(name, 'RaggedShape', [input]):
1116    return dynamic_ragged_shape.DynamicRaggedShape.from_tensor(input, out_type)
1117
1118
1119@dispatch.dispatch_for_api(array_ops.broadcast_dynamic_shape)
1120def broadcast_dynamic_shape(
1121    shape_x: dynamic_ragged_shape.DenseOrRaggedShape,
1122    shape_y: dynamic_ragged_shape.DenseOrRaggedShape
1123) -> dynamic_ragged_shape.DynamicRaggedShape:
1124  """Returns the shape formed by broadcasting two shapes to be compatible.
1125
1126  1. If shape_x and shape_y both have row_partitions, then fail if their dtypes
1127     don't match.
1128  2. If neither has row_partitions and they have different dtypes,
1129     go with int64.
1130  3. If one has row_partitions, go with that dtype.
1131
1132  Args:
1133    shape_x: A `DynamicRaggedShape`
1134    shape_y: A `DynamicRaggedShape`
1135
1136  Returns:
1137    A `DynamicRaggedShape`.
1138  Raises:
1139    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
1140  """
1141  if not isinstance(shape_x, dynamic_ragged_shape.DynamicRaggedShape):
1142    shape_x = dynamic_ragged_shape.DynamicRaggedShape([], shape_x)
1143  if not isinstance(shape_y, dynamic_ragged_shape.DynamicRaggedShape):
1144    shape_y = dynamic_ragged_shape.DynamicRaggedShape([], shape_y)
1145  return dynamic_ragged_shape.broadcast_dynamic_shape(shape_x, shape_y)
1146
1147
1148@dispatch.dispatch_for_api(array_ops.ones)
1149def ones(shape: dynamic_ragged_shape.DynamicRaggedShape,
1150         dtype=dtypes.float32,
1151         name=None) -> ragged_tensor.RaggedOrDense:
1152  """Returns ones shaped like x."""
1153  flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
1154  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access
1155
1156
1157@dispatch.dispatch_for_api(array_ops.zeros)
1158def zeros(shape: dynamic_ragged_shape.DynamicRaggedShape,
1159          dtype=dtypes.float32,
1160          name=None) -> ragged_tensor.RaggedOrDense:
1161  """Returns ones shaped like x."""
1162  flat_values = array_ops.zeros(shape.inner_shape, dtype=dtype, name=name)
1163  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access
1164
1165
1166@dispatch.dispatch_for_api(array_ops.fill)
1167def fill(dims: dynamic_ragged_shape.DynamicRaggedShape,
1168         value: core_types.TensorLike,
1169         name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
1170  """Creates a tensor with shape `dims` and fills it with `value`."""
1171  flat_values = array_ops.fill(dims.inner_shape, value, name=name)
1172  return dims._add_row_partitions(flat_values)  # pylint: disable=protected-access
1173
1174
1175#===============================================================================
1176# bitcast
1177#===============================================================================
1178@dispatch.dispatch_for_api(array_ops.bitcast)
1179def bitcast(
1180    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
1181    type,  # pylint: disable=redefined-builtin
1182    name=None) -> ragged_tensor.RaggedOrDense:
1183  """RaggedTensor dispatch override for tf.bitcast."""
1184  type = dtypes.as_dtype(type)
1185  with ops.name_scope(name, 'Bitcast', [input]):
1186    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1187        input, name='input')
1188    if (input.dtype.size < type.size and input.flat_values.shape.rank < 2):
1189      raise ValueError('`input.flat_values` is required to have rank >= 2 when '
1190                       'input.dtype.size < type.size. Actual rank: '
1191                       f'{input.flat_values.shape.rank}')
1192    return input.with_flat_values(array_ops.bitcast(input.flat_values, type))
1193