# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape(object):
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

    * "Partitioned dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
      dimension must be uniform, and the innermost partitioned dimension must
      be ragged.

    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  The dimension sizes are recorded using `partitioned_dim_sizes` and
  `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
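
  Example (an illustrative sketch; the values shown assume eager execution and
  use `tf.ragged.constant`, which is not imported by this module):

  >>> rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]])
  >>> shape = RaggedTensorDynamicShape.from_tensor(rt)
  >>> [t.numpy().tolist() for t in shape.partitioned_dim_sizes]
  [3, [2, 0, 3]]
  >>> shape.inner_dim_sizes.numpy().tolist()
  []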
83  """
84
85  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
86               dim_size_dtype=None):
87    """Creates a RaggedTensorDynamicShape.
88
89    Args:
90      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
91        each partitioned dimension.  If dimension `d` is uniform, then
92        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
93        size of all slices across dimension `d`.  If dimension `d` is ragged,
94        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
95        the size of each slice across dimension `d`.
96      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
97        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
98        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set(
            p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int32 or int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.
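
    Example (an illustrative sketch; the values shown assume eager execution):

    >>> shape = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
    >>> shape.num_partitioned_dimensions, shape.num_inner_dimensions
    (2, 0)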
164    """
165    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
166                        [dim_sizes]):
167      dim_sizes = tuple(
168          ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
169                                name='dim_sizes') for size in dim_sizes)
170      # Split the dimensions into partitioned & inner dimensions.
171      inner_split = 0
172      for dim, dim_size in enumerate(dim_sizes):
173        if dim_size.shape.ndims == 1:
174          inner_split = dim + 1
175        elif dim_size.shape.ndims != 0:
176          raise ValueError('Each dim_size must be a scalar or a vector')
177      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
178                                      dim_sizes[inner_split:])
179
180  @classmethod
181  def from_tensor(cls, rt_input, dim_size_dtype=None):
182    """Constructs a ragged shape for a potentially ragged tensor."""
183    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
184      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
185      if not ragged_tensor.is_ragged(rt_input):
186        return cls([], array_ops.shape(rt_input))
187      else:
188        partitioned_dim_sizes = (
189            (rt_input.nrows(),) + rt_input.nested_row_lengths())
190        return RaggedTensorDynamicShape(
191            partitioned_dim_sizes,
192            array_ops.shape(rt_input.flat_values)[1:],
193            dim_size_dtype=dim_size_dtype)
194
195  def dimension_size(self, axis):
196    """Returns the size of slices across the specified dimension."""
197    if not isinstance(axis, int):
198      raise TypeError('axis must be an integer')
199    partitioned_ndims = len(self._partitioned_dim_sizes)
200    if axis < partitioned_ndims:
201      return self._partitioned_dim_sizes[axis]
202    else:
203      return self._inner_dim_sizes[axis - partitioned_ndims]
204
205  def is_ragged(self, axis):
206    """Returns true if the indicated dimension is ragged."""
207    if not isinstance(axis, int):
208      raise TypeError('axis must be an integer')
209    rank = self.rank
210    if axis < 0:
211      raise ValueError('Negative axis values are not supported')
212    elif rank is not None and axis >= rank:
213      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
214    else:
215      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
216              self._partitioned_dim_sizes[axis].shape.ndims == 1)
217
218  @property
219  def rank(self):
220    """The number of dimensions in this shape, or None if unknown."""
221    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
222    if inner_ndims is None:
223      return None
224    else:
225      return len(self._partitioned_dim_sizes) + inner_ndims
226
227  @property
228  def partitioned_dim_sizes(self):
229    """The partitioned dimension sizes for this shape.
230
231    Returns:
232      A `list` of 0-D or 1-D integer `Tensor`.
233    """
234    return self._partitioned_dim_sizes
235
236  @property
237  def inner_dim_sizes(self):
238    """The inner dimension sizes for this shape.
239
240    Returns:
241      A 1-D integer `Tensor`.
242    """
243    return self._inner_dim_sizes
244
245  @property
246  def num_partitioned_dimensions(self):
247    """The number of partitioned dimensions in this shape."""
248    return len(self._partitioned_dim_sizes)
249
250  @property
251  def num_inner_dimensions(self):
252    """The number of inner dimensions, or `None` if not statically known."""
253    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
254
255  @property
256  def dim_size_dtype(self):
257    """DType used by this shape for dimension sizes."""
258    return self._inner_dim_sizes.dtype
259
260  def broadcast_to_rank(self, rank):
261    """Adds leading size-1 dimensions to broadcast `self` to the given rank.
262
263    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
264    is `[1, 1, 3, (D2), 4]`.
265
266    Args:
267      rank: The rank for the returned shape.
268
269    Returns:
270      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
271      have the same size as `self` and whose outer dimensions have size `1`.
272
273    Raises:
274      ValueError: If `self.rank` is unknown or greater than `rank`.
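
    Example (an illustrative sketch; the values shown assume eager execution):

    >>> shape = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
    >>> shape.rank, shape.broadcast_to_rank(4).rank
    (2, 4)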
275    """
276    if self.rank is None:
277      raise ValueError('Unable to broadcast: self.rank is unknown')
278    dims_to_add = rank - self.rank
279    if dims_to_add < 0:
280      raise ValueError('Unable to broadcast: rank=%d must be greater than '
281                       'self.rank=%d.' % (rank, self.rank))
282    elif dims_to_add == 0:
283      return self
284    elif self._partitioned_dim_sizes:
285      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
286      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
287    else:
288      inner_dims = array_ops.concat(
289          [array_ops.ones([dims_to_add], self.dim_size_dtype),
290           self.inner_dim_sizes],
291          axis=0)
292      return RaggedTensorDynamicShape([], inner_dims)
293
294  def broadcast_dimension(self, axis, lengths):
295    """Returns a shape that is broadcast-compatible with self & lengths.
296
    * If dimension[axis] is uniform and lengths is a scalar, then check
      that either lengths==1, dimension[axis]==1, or dimension[axis]==lengths,
      and tile dimension[axis] with tf.where(dimension[axis]==1, lengths, 1)
      repeats.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and raggedly tile dimension[axis] with
      lengths repeats.  (Tiling could be skipped when it is statically known
      that all slice lengths equal 1.)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.
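
    Example (an illustrative sketch; the values shown assume eager execution):

    >>> shape = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
    >>> result = shape.broadcast_dimension(1, [2, 0, 3])
    >>> [t.numpy().tolist() for t in result.partitioned_dim_sizes]
    [3, [2, 0, 3]]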
318    """
319    lengths = ragged_util.convert_to_int_tensor(
320        lengths, name='lengths', dtype=self.dim_size_dtype)
321    # Check whether lengths is a scalar (for uniform dimensions) or
322    # vector (for ragged dimensions).
323    if lengths.shape.ndims is None:
324      raise ValueError('lengths must have a known rank.')
325    elif lengths.shape.ndims > 1:
326      raise ValueError('lengths must be a scalar or vector')
327    else:
328      lengths_is_scalar = (lengths.shape.ndims == 0)
329
330    # Verify that the shapes are compatible.
331    if self.is_ragged(axis):
332      if lengths_is_scalar:
333        condition = math_ops.equal(lengths, 1)
334      else:
335        condition = math_ops.reduce_all(
336            math_ops.equal(lengths, self.dimension_size(axis)))
337    else:
338      axis_dim_size = self.dimension_size(axis)
339      if lengths_is_scalar:
340        condition = (
341            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
342            | math_ops.equal(axis_dim_size, lengths))
343      else:
344        condition = math_ops.equal(axis_dim_size, 1)
345    broadcast_err = [
346        'Unable to broadcast: dimension size mismatch in dimension', axis,
347        'lengths=', lengths, 'dim_size=',
348        self.dimension_size(axis)
349    ]
350    broadcast_check = control_flow_ops.Assert(
351        condition, data=broadcast_err, summarize=10)
352
353    with ops.control_dependencies([broadcast_check]):
354      # Partitioned dimensions:
355      if axis < self.num_partitioned_dimensions:
356        if self.is_ragged(axis):
357          # Use an identity op to make sure the check actually gets run.
358          return RaggedTensorDynamicShape(
359              self._partitioned_dim_sizes,
360              array_ops.identity(self.inner_dim_sizes))
361        else:
362          return self._broadcast_uniform_partitioned_dimension(axis, lengths)
363
364      # Inner dimensions:
365      else:
366        if lengths_is_scalar:
367          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
368        else:
369          if axis == 0:
370            raise ValueError('Unable to broadcast: '
371                             'outermost dimension must be uniform.')
372          return self._broadcast_inner_dimension_to_ragged(axis, lengths)
373
374  def num_slices_in_dimension(self, axis):
375    """Returns the total number of slices across the indicated dimension."""
376    if axis < 0:
377      return constant_op.constant(1, dtype=self.dim_size_dtype)
378    elif self.is_ragged(axis):
379      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
380    else:
381      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)
382
383  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
384    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
385    axis_dim_size = self.dimension_size(axis)
386    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])
387
388    if lengths.shape.ndims == 0:
389      lengths = array_ops.where(
390          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
391      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
392      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
393    else:
394      splits = math_ops.range(
395          array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
396      repeats = lengths
397
398    partitioned_sizes.append(lengths)
399
400    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
401      if dim_size.shape.ndims == 0:
402        partitioned_sizes.append(dim_size)
403        splits *= dim_size
404      else:
405        partitioned_sizes.append(
406            ragged_util.repeat_ranges(dim_size, splits, repeats))
407        splits = array_ops.gather(
408            ragged_util.lengths_to_splits(dim_size), splits)
409    inner_sizes = self._inner_dim_sizes
410    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
411
412  def _broadcast_inner_dimension_to_uniform(self, axis, length):
413    """Broadcasts the inner dimension `axis` to match `lengths`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ],
                                   axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def with_dim_size_dtype(self, dtype):
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('dtype must be int32 or int64')
    if self.dim_size_dtype == dtype:
      return self
    return RaggedTensorDynamicShape(
        [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
        math_ops.cast(self._inner_dim_sizes, dtype))


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
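
  Example (an illustrative sketch; the values shown assume eager execution):

  >>> x = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
  >>> y = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
  >>> result = broadcast_dynamic_shape(x, y)
  >>> [t.numpy().tolist() for t in result.partitioned_dim_sizes]
  [3, [2, 0, 3]]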
455  """
456  if not isinstance(shape_x, RaggedTensorDynamicShape):
457    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
458  if not isinstance(shape_y, RaggedTensorDynamicShape):
459    raise TypeError('shape_y must be a RaggedTensorDynamicShape')
460
461  # Broadcast both shapes to have the same rank.
462  if shape_x.rank is None or shape_y.rank is None:
463    raise ValueError('Unable to broadcast: unknown rank')
464  broadcast_rank = max(shape_x.rank, shape_y.rank)
465  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
466  shape_y = shape_y.broadcast_to_rank(broadcast_rank)
467
468  # Broadcast dimensions one at a time, starting from the outermost dimension.
469  for axis in range(broadcast_rank):
470    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
471    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))
472
473  return shape_x
474
475
476def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
477  """Broadcasts a potentially ragged tensor to a ragged shape.
478
479  Tiles `rt_input` as necessary to match the given shape.
480
481  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
482
483  Args:
484    rt_input: The potentially ragged tensor to broadcast.
485    shape: A `RaggedTensorDynamicShape`
486    broadcast_inner_dimensions: If false, then inner dimensions will not be
487      tiled.
488
489  Returns:
490    A potentially ragged tensor whose values are taken from
491    `rt_input`, and whose shape matches `shape`.
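
  Example (an illustrative sketch; the values shown assume eager execution and
  use `tf.constant`, which is not imported by this module):

  >>> x = tf.constant([[10], [20], [30]])
  >>> dst_shape = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
  >>> broadcast_to(x, dst_shape).to_list()
  [[10, 10], [], [30, 30, 30]]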
492  """
493  if not isinstance(shape, RaggedTensorDynamicShape):
494    raise TypeError('shape must be a RaggedTensorDynamicShape')
495  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
496
497  # Broadcasting to a uniform shape.
498  if shape.num_partitioned_dimensions == 0:
499    return _broadcast_to_uniform_shape(rt_input, shape,
500                                       broadcast_inner_dimensions)
501  else:
502    return _broadcast_to_ragged_shape(rt_input, shape,
503                                      broadcast_inner_dimensions)
504
505
506def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
507  """Broadcasts rt_input to the uniform shape `shape`."""
508  if isinstance(rt_input, ragged_tensor.RaggedTensor):
509    raise ValueError('Incompatible with shape: ragged rank mismatch')
510  if broadcast_inner_dimensions:
511    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
512  else:
513    return rt_input
514
515
516def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
517  """Broadcasts rt_input to the ragged shape `dst_shape`."""
518  # Check that rt_input and dst_shape have the same row_splits dtype.
519  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
520      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
521    if not ragged_config.auto_cast_partition_dtype():
522      raise ValueError('rt_input and dst_shape have different row_split '
523                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
524                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
525                       'convert to a compatible dtype.')
526    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
527    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)
528
529  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
530  if rt_input.shape.ndims is None or dst_shape.rank is None:
531    raise ValueError('Unable to broadcast: unknown rank')
532  if rt_input.shape.ndims > dst_shape.rank:
533    raise ValueError('Incompatible with shape: rank mismatch')
534  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
535      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
536    raise ValueError('Incompatible with shape: ragged rank mismatch')
537
538  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
539  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)
540
541  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
542  if dst_shape.rank > rt_input.shape.ndims:
543    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
544      rt_input = array_ops.reshape(
545          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
546    for _ in range(dst_shape.rank - rt_input.shape.ndims):
547      if ragged_tensor.is_ragged(rt_input):
548        nrows = rt_input.nrows()
549      else:
550        nrows = array_ops.shape(rt_input,
551                                out_type=dst_shape.dim_size_dtype)[0]
552      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
553                                                             validate=False)
554
555  # Add ragged dimensions to match dst_shape.
556  if ragged_tensor.is_ragged(rt_input):
557    inner_rank_diff = (
558        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
559    if inner_rank_diff > 0:
560      rt_input = rt_input.with_flat_values(
561          ragged_tensor.RaggedTensor.from_tensor(
562              rt_input.flat_values, ragged_rank=inner_rank_diff,
563              row_splits_dtype=dst_shape.dim_size_dtype))
564  else:
565    rt_input = ragged_tensor.RaggedTensor.from_tensor(
566        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
567        row_splits_dtype=dst_shape.dim_size_dtype)
568
569  # Do broadcasting for any dimensions that will remain uniform.  We can do
570  # these all at once, since they're independent of one another.
571  multiples = [1] * dst_shape.rank
572  for axis in range(dst_shape.num_partitioned_dimensions):
573    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
574      src_size = src_shape.dimension_size(axis)
575      dst_size = dst_shape.dimension_size(axis)
576      if ((tensor_util.constant_value(src_size) in (1, None)) and
577          (tensor_util.constant_value(dst_size) != 1)):
578        multiples[axis] = array_ops.where(
579            math_ops.equal(src_size, 1), dst_size, 1)
580  if not all(isinstance(v, int) and v == 1 for v in multiples):
581    multiples = array_ops.stack(multiples, axis=0)
582    rt_input = ragged_array_ops.tile(rt_input, multiples)
583
584  if broadcast_inner_dimensions:
585    new_shape = array_ops.broadcast_dynamic_shape(
586        array_ops.shape(
587            rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
588        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
589    rt_input = rt_input.with_flat_values(
590        array_ops.broadcast_to(rt_input.flat_values, new_shape))
591
592  # Do broadcasting for dimensions that become ragged.  We must do these from
593  # outermost to innermost.
594  for axis in range(dst_shape.num_partitioned_dimensions):
595    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
596      dst_size = dst_shape.dimension_size(axis)
597      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
598                                   dst_shape.dim_size_dtype)
599
600  return rt_input
601
602
603def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
604  """Tile a dimension of a RaggedTensor to match a ragged shape."""
605  assert axis > 0  # Outermost dimension may not be ragged.
606
607  if not ragged_tensor.is_ragged(rt_input):
608    rt_input = ragged_tensor.RaggedTensor.from_tensor(
609        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)
610
611  if axis > 1:
612    return rt_input.with_values(
613        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
614                          row_splits_dtype))
615  else:
616    src_row_splits = rt_input.nested_row_splits
617    src_row_lengths = rt_input.nested_row_lengths()
618    splits = src_row_splits[0]
619
620    dst_row_lengths = [repeats]
621    for i in range(1, len(src_row_lengths)):
622      dst_row_lengths.append(
623          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
624      splits = array_ops.gather(src_row_splits[i], splits)
625    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
626                                           repeats)
627    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
628        dst_values, dst_row_lengths, validate=False)
629