# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape(object):
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

    * "Partitioned dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
      dimension must be uniform, and the innermost partitioned dimension must
      be ragged.

    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  The dimension sizes are recorded using `partitioned_dim_sizes` and
  `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.
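
  For instance (an illustrative sketch; the values are shown as Python
  literals, though the properties actually return `Tensor`s):

    rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]])
    shape = RaggedTensorDynamicShape.from_tensor(rt)
    shape.partitioned_dim_sizes  # (3, [2, 0, 3]): 3 rows, with those lengths
    shape.inner_dim_sizes        # []: the innermost values are scalars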

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
  """

  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
               dim_size_dtype=None):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set(
            p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.
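
    For example (an illustrative call; the scalar/vector convention is
    explained below), `from_dim_sizes([3, [2, 0, 3], 4])` describes a ragged
    tensor with 3 rows whose row lengths are 2, 0, and 3, and whose innermost
    values are length-4 vectors.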

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int32 or int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
                        [dim_sizes]):
      dim_sizes = tuple(
          ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
                                name='dim_sizes') for size in dim_sizes)
      # Split the dimensions into partitioned & inner dimensions.
      inner_split = 0
      for dim, dim_size in enumerate(dim_sizes):
        if dim_size.shape.ndims == 1:
          inner_split = dim + 1
        elif dim_size.shape.ndims != 0:
          raise ValueError('Each dim_size must be a scalar or a vector')
      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
                                      dim_sizes[inner_split:])

  @classmethod
  def from_tensor(cls, rt_input, dim_size_dtype=None):
    """Constructs a ragged shape for a potentially ragged tensor."""
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor',
                        [rt_input]):
      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
      if not ragged_tensor.is_ragged(rt_input):
        return cls([], array_ops.shape(rt_input))
      else:
        partitioned_dim_sizes = (
            (rt_input.nrows(),) + rt_input.nested_row_lengths())
        return RaggedTensorDynamicShape(
            partitioned_dim_sizes,
            array_ops.shape(rt_input.flat_values)[1:],
            dim_size_dtype=dim_size_dtype)

  def dimension_size(self, axis):
    """Returns the size of slices across the specified dimension."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    partitioned_ndims = len(self._partitioned_dim_sizes)
    if axis < partitioned_ndims:
      return self._partitioned_dim_sizes[axis]
    else:
      return self._inner_dim_sizes[axis - partitioned_ndims]

  def is_ragged(self, axis):
    """Returns true if the indicated dimension is ragged."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    rank = self.rank
    if axis < 0:
      raise ValueError('Negative axis values are not supported')
    elif rank is not None and axis >= rank:
      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
    else:
      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
              self._partitioned_dim_sizes[axis].shape.ndims == 1)

  @property
  def rank(self):
    """The number of dimensions in this shape, or None if unknown."""
    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
    if inner_ndims is None:
      return None
    else:
      return len(self._partitioned_dim_sizes) + inner_ndims

  @property
  def partitioned_dim_sizes(self):
    """The partitioned dimension sizes for this shape.

    Returns:
      A `list` of 0-D or 1-D integer `Tensor`.
    """
    return self._partitioned_dim_sizes

  @property
  def inner_dim_sizes(self):
    """The inner dimension sizes for this shape.

    Returns:
      A 1-D integer `Tensor`.
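      For example (illustrative): for the tensor
      `[[[1, 2], [3, 4]], [[5, 6]]]`, this is the vector `[2]`.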
    """
    return self._inner_dim_sizes

  @property
  def num_partitioned_dimensions(self):
    """The number of partitioned dimensions in this shape."""
    return len(self._partitioned_dim_sizes)

  @property
  def num_inner_dimensions(self):
    """The number of inner dimensions, or `None` if not statically known."""
    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

  @property
  def dim_size_dtype(self):
    """DType used by this shape for dimension sizes."""
    return self._inner_dim_sizes.dtype

  def broadcast_to_rank(self, rank):
    """Adds leading size-1 dimensions to broadcast `self` to the given rank.

    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
    is `[1, 1, 3, (D2), 4]`.

    Args:
      rank: The rank for the returned shape.

    Returns:
      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
      have the same size as `self` and whose outer dimensions have size `1`.

    Raises:
      ValueError: If `self.rank` is unknown or greater than `rank`.
    """
    if self.rank is None:
      raise ValueError('Unable to broadcast: self.rank is unknown')
    dims_to_add = rank - self.rank
    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than '
                       'or equal to self.rank=%d.' % (rank, self.rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
    else:
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], self.dim_size_dtype),
           self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims)

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that either lengths==1, dimension[axis]==1, or lengths==dimension[axis],
      and tile dimension[axis] with tf.where(dimension[axis]==1, lengths, 1)
      repeats.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and raggedly tile dimension[axis] with
      lengths repeats.  (We could skip tiling if we statically know that
      lengths == 1.)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.
    """
    lengths = ragged_util.convert_to_int_tensor(
        lengths, name='lengths', dtype=self.dim_size_dtype)
    # Check whether lengths is a scalar (for uniform dimensions) or
    # vector (for ragged dimensions).
    if lengths.shape.ndims is None:
      raise ValueError('lengths must have a known rank.')
    elif lengths.shape.ndims > 1:
      raise ValueError('lengths must be a scalar or vector')
    else:
      lengths_is_scalar = (lengths.shape.ndims == 0)

    # Verify that the shapes are compatible.
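    # (Illustrative walk-through: if `self` is `[3, (D2), 4]`, then
    # `broadcast_dimension(0, lengths=3)` passes the uniform check below,
    # since dimension_size(0) == lengths, while `broadcast_dimension(1,
    # lengths=1)` passes the ragged check, since a ragged dimension can only
    # be broadcast against a scalar length of 1.)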
    if self.is_ragged(axis):
      if lengths_is_scalar:
        condition = math_ops.equal(lengths, 1)
      else:
        condition = math_ops.reduce_all(
            math_ops.equal(lengths, self.dimension_size(axis)))
    else:
      axis_dim_size = self.dimension_size(axis)
      if lengths_is_scalar:
        condition = (
            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
            | math_ops.equal(axis_dim_size, lengths))
      else:
        condition = math_ops.equal(axis_dim_size, 1)
    broadcast_err = [
        'Unable to broadcast: dimension size mismatch in dimension', axis,
        'lengths=', lengths, 'dim_size=',
        self.dimension_size(axis)
    ]
    broadcast_check = control_flow_ops.Assert(
        condition, data=broadcast_err, summarize=10)

    with ops.control_dependencies([broadcast_check]):
      # Partitioned dimensions:
      if axis < self.num_partitioned_dimensions:
        if self.is_ragged(axis):
          # Use an identity op to make sure the check actually gets run.
          return RaggedTensorDynamicShape(
              self._partitioned_dim_sizes,
              array_ops.identity(self.inner_dim_sizes))
        else:
          return self._broadcast_uniform_partitioned_dimension(axis, lengths)

      # Inner dimensions:
      else:
        if lengths_is_scalar:
          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
        else:
          if axis == 0:
            raise ValueError('Unable to broadcast: '
                             'outermost dimension must be uniform.')
          return self._broadcast_inner_dimension_to_ragged(axis, lengths)

  def num_slices_in_dimension(self, axis):
    """Returns the total number of slices across the indicated dimension."""
    if axis < 0:
      return constant_op.constant(1, dtype=self.dim_size_dtype)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match `length`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat(
        [self._inner_dim_sizes[:axis_in_inner_dims],
         [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
         self._inner_dim_sizes[axis_in_inner_dims + 1:]],
        axis=0)
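    # At this point `inner_sizes` matches `self._inner_dim_sizes`, except that
    # the entry for `axis` has been replaced by `length` when the original
    # size was 1 (the only case in which a uniform dimension needs to grow).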
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    """Broadcasts the inner dimension `axis` to match the ragged `lengths`."""
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def with_dim_size_dtype(self, dtype):
    """Returns a copy of this shape with dimension sizes cast to `dtype`."""
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('dtype must be int32 or int64')
    if self.dim_size_dtype == dtype:
      return self
    return RaggedTensorDynamicShape(
        [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
        math_ops.cast(self._inner_dim_sizes, dtype))


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`.
    shape_y: A `RaggedTensorDynamicShape`.

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, RaggedTensorDynamicShape):
    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
  if not isinstance(shape_y, RaggedTensorDynamicShape):
    raise TypeError('shape_y must be a RaggedTensorDynamicShape')

  # Broadcast both shapes to have the same rank.
  if shape_x.rank is None or shape_y.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  broadcast_rank = max(shape_x.rank, shape_y.rank)
  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
  shape_y = shape_y.broadcast_to_rank(broadcast_rank)

  # Broadcast dimensions one at a time, starting from the outermost dimension.
  for axis in range(broadcast_rank):
    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))

  return shape_x


def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `rt_input` as necessary to match the given shape.

  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.

  Args:
    rt_input: The potentially ragged tensor to broadcast.
    shape: A `RaggedTensorDynamicShape`.
    broadcast_inner_dimensions: If false, then inner dimensions will not be
      tiled.

  Returns:
    A potentially ragged tensor whose values are taken from
    `rt_input`, and whose shape matches `shape`.
  """
  if not isinstance(shape, RaggedTensorDynamicShape):
    raise TypeError('shape must be a RaggedTensorDynamicShape')
  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)

  # Broadcasting to a uniform shape.
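  # (If `shape` has no partitioned dimensions it is fully uniform, so ordinary
  # dense broadcasting applies; otherwise we take the ragged path.  As an
  # illustrative example, broadcasting the dense tensor [[10], [20], [30]] to
  # the ragged shape [3, (2, 0, 3)] takes the ragged path and yields
  # [[10, 10], [], [30, 30, 30]].)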
  if shape.num_partitioned_dimensions == 0:
    return _broadcast_to_uniform_shape(rt_input, shape,
                                       broadcast_inner_dimensions)
  else:
    return _broadcast_to_ragged_shape(rt_input, shape,
                                      broadcast_inner_dimensions)


def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the uniform shape `shape`."""
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise ValueError('Incompatible with shape: ragged rank mismatch')
  if broadcast_inner_dimensions:
    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
  else:
    return rt_input


def _broadcast_to_ragged_shape(rt_input, dst_shape,
                               broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to
  # rt_input's.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank match dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes],
                                     axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
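  # (Each dimension that is uniform in both `src_shape` and `dst_shape` but
  # needs to grow contributes one entry to `multiples`, which is applied with
  # a single tile op; dimensions that become ragged are handled one at a time
  # in the loop further below.)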
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    new_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(
            rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
    rt_input = rt_input.with_flat_values(
        array_ops.broadcast_to(rt_input.flat_values, new_shape))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tiles a dimension of a RaggedTensor to match a ragged shape."""
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths, validate=False)