1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Array operations for RaggedTensors.""" 16 17from typing import Optional 18from typing import Union 19 20from tensorflow.python.framework import dtypes 21from tensorflow.python.framework import ops 22from tensorflow.python.framework import sparse_tensor 23from tensorflow.python.framework import tensor_shape 24from tensorflow.python.framework import tensor_util 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import data_flow_ops 29from tensorflow.python.ops import gen_ragged_array_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import sort_ops 32from tensorflow.python.ops.ragged import dynamic_ragged_shape 33from tensorflow.python.ops.ragged import ragged_functional_ops 34from tensorflow.python.ops.ragged import ragged_math_ops 35from tensorflow.python.ops.ragged import ragged_tensor 36from tensorflow.python.ops.ragged import ragged_util 37from tensorflow.python.ops.ragged import segment_id_ops 38from tensorflow.python.types import core as core_types 39from tensorflow.python.util import dispatch 40from tensorflow.python.util.tf_export import tf_export 41 
42#=============================================================================== 43# Masking 44#=============================================================================== 45 46 47@tf_export('ragged.boolean_mask') 48@dispatch.add_dispatch_support 49def boolean_mask(data, mask, name=None): 50 """Applies a boolean mask to `data` without flattening the mask dimensions. 51 52 Returns a potentially ragged tensor that is formed by retaining the elements 53 in `data` where the corresponding value in `mask` is `True`. 54 55 * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]` 56 57 Where `j` is the `i`th `True` entry of `mask[a1...aA]`. 58 59 Note that `output` preserves the mask dimensions `a1...aA`; this differs 60 from `tf.boolean_mask`, which flattens those dimensions. 61 62 Args: 63 data: A potentially ragged tensor. 64 mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix 65 of `data`'s shape. `rank(mask)` must be known statically. 66 name: A name prefix for the returned tensor (optional). 67 68 Returns: 69 A potentially ragged tensor that is formed by retaining the elements in 70 `data` where the corresponding value in `mask` is `True`. 71 72 * `rank(output) = rank(data)`. 73 * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`. 74 75 Raises: 76 ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is 77 not a prefix of `data.shape`. 78 79 #### Examples: 80 81 >>> # Aliases for True & False so data and mask line up. 82 >>> T, F = (True, False) 83 84 >>> tf.ragged.boolean_mask( # Mask a 2D Tensor. 85 ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], 86 ... mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list() 87 [[1, 3], [], [7]] 88 89 >>> tf.ragged.boolean_mask( # Mask a 2D RaggedTensor. 90 ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), 91 ... tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list() 92 [[3], [], [5, 6]] 93 94 >>> tf.ragged.boolean_mask( # Mask rows of a 2D RaggedTensor. 95 ... 
tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), 96 ... tf.ragged.constant([True, False, True])).to_list() 97 [[1, 2, 3], [5, 6]] 98 """ 99 with ops.name_scope(name, 'RaggedMask', [data, mask]): 100 # Convert inputs to tensors. 101 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') 102 mask = ragged_tensor.convert_to_tensor_or_ragged_tensor( 103 mask, dtypes.bool, name='mask') 104 row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes( 105 data, mask, return_dtype=True) 106 107 # Get static rank of mask. 108 if mask.shape.ndims is None: 109 raise ValueError('mask.shape.ndims must be known statically.') 110 elif mask.shape.ndims == 0: 111 raise ValueError('mask cannot be scalar.') 112 113 # If mask is ragged, then recurse with a non-ragged mask. 114 if ragged_tensor.is_ragged(mask): 115 if not ragged_tensor.is_ragged(data): 116 data = ragged_tensor.RaggedTensor.from_tensor( 117 data, 118 ragged_rank=mask.ragged_rank, 119 row_splits_dtype=mask.row_splits.dtype) 120 # Check that mask.nested_row_splits is a prefix of 121 # data.nested_row_splits. 122 splits_list = [ 123 mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank] 124 ] 125 with ops.control_dependencies( 126 ragged_util.assert_splits_match(splits_list)): 127 # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits 128 # that we strip off in `splits`, so we can add them back on after 129 # we recursively mask the non-ragged data. 130 splits = [] 131 while ragged_tensor.is_ragged(mask): 132 if mask.shape.ndims > 2: 133 splits.append(mask.row_splits) 134 else: 135 # Count the number of True mask values in each row to find the 136 # lengths of the filtered rows; then convert to splits. 
137 int_mask = ragged_functional_ops.map_flat_values( 138 math_ops.cast, mask, dtype=row_splits_dtype) 139 masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1) 140 splits.append(ragged_util.lengths_to_splits(masked_row_lengths)) 141 mask = mask.values 142 data = data.values 143 144 # Recursively apply the nested non-ragged mask to the nested data. 145 masked_values = boolean_mask(data, mask) 146 147 # Add the ragged `splits` back to the result. 148 masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits( 149 masked_values, splits, validate=False) 150 151 return masked_values 152 153 # If mask is non-ragged and has rank 1, and data is ragged, then build a 154 # ragged tensor with the indicated rows. 155 elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1: 156 # Get the masked splits: first get the length of each row, then filter 157 # out the rows that we are deleting, and convert that filtered set of 158 # masks back to a splits tensor. 159 lengths = data.row_lengths() 160 masked_lengths = array_ops.boolean_mask(lengths, mask) 161 masked_splits = ragged_util.lengths_to_splits(masked_lengths) 162 163 # Get the masked values: first get row ids corresponding to each 164 # value, then use tf.gather to build a boolean mask that's false for 165 # values that come from rows that we are deleting, and use that mask to 166 # construct the masked values tensor. 167 segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits) 168 segment_mask = array_ops.gather(mask, segment_ids) 169 masked_values = boolean_mask(data.values, segment_mask) 170 171 return ragged_tensor.RaggedTensor.from_row_splits( 172 masked_values, masked_splits, validate=False) 173 174 # If mask is non-ragged and has rank>1, then convert it to be ragged, 175 # with a ragged rank matching data. 
176 if ragged_tensor.is_ragged(data): 177 mask = ragged_tensor.RaggedTensor.from_tensor( 178 mask, 179 ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1), 180 row_splits_dtype=data.row_splits.dtype) 181 return boolean_mask(data, mask) 182 183 # Otherwise, data and mask are both `Tensor`s. 184 else: 185 # Apply `boolean_mask` to get the masked values. 186 masked_values = array_ops.boolean_mask(data, mask) 187 188 if mask.shape.ndims >= 2: 189 # Add the innermost ragged dimension. For each innermost cell, get the 190 # number of values it contains. Then flatten that to get a list of 191 # cell lengths, and convert it to splits. Finally, combine the splits 192 # and values to get the innermost ragged tensor. 193 masked_lengths = math_ops.count_nonzero( 194 mask, axis=-1, dtype=row_splits_dtype) 195 flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1]) 196 masked_values = ragged_tensor.RaggedTensor.from_row_lengths( 197 masked_values, flattened_masked_lengths, validate=False) 198 199 # Wrap remaining ragged dimensions. 200 if mask.shape.ndims > 2: 201 mask_shape = array_ops.shape(mask, out_type=row_splits_dtype) 202 split_size = math_ops.cumprod(mask_shape) + 1 203 for dim in range(mask.shape.ndims - 3, -1, -1): 204 elt_size = mask_shape[dim + 1] 205 masked_splits = math_ops.range(split_size[dim]) * elt_size 206 masked_values = ragged_tensor.RaggedTensor.from_row_splits( 207 masked_values, masked_splits, validate=False) 208 209 return masked_values 210 211 212#=============================================================================== 213# Tiling 214#=============================================================================== 215@dispatch.dispatch_for_api(array_ops.tile) 216def tile(input: ragged_tensor.Ragged, multiples, name=None): # pylint: disable=redefined-builtin 217 """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`. 
218 219 The values of `input` are replicated `multiples[i]` times along the 220 `i`th dimension (for each dimension `i`). For every dimension `axis` in 221 `input`, the length of each output element in that dimension is the 222 length of corresponding input element multiplied by `multiples[axis]`. 223 224 Args: 225 input: A `RaggedTensor`. 226 multiples: A 1-D integer `Tensor`. Length must be the same as the number of 227 dimensions in `input`. 228 name: A name for the operation (optional). 229 230 Returns: 231 A `RaggedTensor` with the same type, rank, and ragged_rank as `input`. 232 233 #### Example: 234 235 >>> rt = tf.ragged.constant([[1, 2], [3]]) 236 >>> tf.tile(rt, [3, 2]).to_list() 237 [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]] 238 """ 239 with ops.name_scope(name, 'RaggedTile', [input, multiples]): 240 input = ragged_tensor.convert_to_tensor_or_ragged_tensor( 241 input, name='input') 242 if not ragged_tensor.is_ragged(input): 243 return array_ops.tile(input, multiples, name) 244 multiples = ragged_util.convert_to_int_tensor( 245 multiples, name='multiples', dtype=input.row_splits.dtype) 246 multiples.shape.assert_has_rank(1) 247 248 # If the constant value of `multiples` is available, then we can use it 249 # to skip tiling dimensions where `multiples=1`. 250 const_multiples = tensor_util.constant_value(multiples) 251 252 return ragged_tensor.RaggedTensor.from_nested_row_splits( 253 _tile_ragged_values(input, multiples, const_multiples), 254 _tile_ragged_splits(input, multiples, const_multiples), 255 validate=False) 256 257 258def _tile_ragged_values(rt_input, multiples, const_multiples=None): 259 """Builds flat_values tensor for a tiled `RaggedTensor`. 260 261 Returns a tensor that repeats the values in 262 `rt_input.flat_values` in the 263 appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as 264 specified by `multiples`. 265 266 Args: 267 rt_input: The `RaggedTensor` whose values should be repeated. 
268 multiples: A 1-D integer `tensor`, indicating how many times each dimension 269 should be repeated. 270 const_multiples: Optional constant value for multiples. Used to skip tiling 271 dimensions where `multiples=1`. 272 273 Returns: 274 A `Tensor` with the same type and rank as `rt_input.flat_values`. 275 276 #### Example: 277 278 >>> rt = tf.ragged.constant([[1, 2], [3]]) 279 >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy() 280 array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32) 281 """ 282 ragged_rank = rt_input.ragged_rank 283 nested_splits = rt_input.nested_row_splits 284 285 # Pointers to the values in `rt_input.flat_values`. 286 inner_value_ids = math_ops.range(nested_splits[-1][-1]) 287 288 # For each ragged dimension (working from the innermost to outermost), 289 # expand `inner_value_ids` as necessary to tile that dimension. 290 prev_splits = None 291 for axis in range(ragged_rank, 0, -1): 292 # Ragged splits for this dimension. 293 splits = nested_splits[axis - 1] 294 295 # Adjust splits so they point into `inner_value_ids` (instead of just 296 # pointing into the next dimension's values). 297 if prev_splits is not None: # Not the first pass through the loop. 298 splits = array_ops.gather(prev_splits * multiples[axis + 1], splits) 299 300 # Repeat each element in this ragged dimension `multiples[axis]` times. 301 if const_multiples is None or const_multiples[axis] != 1: 302 inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits, 303 multiples[axis]) 304 305 prev_splits = splits 306 307 # Gather the tiled inner values. 308 ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids) 309 310 # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus 311 # `axis=range(ragged_rank, rank)`). 
312 inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]], 313 axis=0) 314 return array_ops.tile(ragged_tiled_values, inner_repeats) 315 316 317def _tile_ragged_splits(rt_input, multiples, const_multiples=None): 318 """Builds nested_split tensors for a tiled `RaggedTensor`. 319 320 Returns a list of split tensors that can be used to construct the 321 `RaggedTensor` that tiles `rt_input` as specified by `multiples`. 322 323 Args: 324 rt_input: The `RaggedTensor` that is being tiled. 325 multiples: A 1-D integer `tensor`, indicating how many times each dimension 326 should be repeated. 327 const_multiples: Optional constant value for multiples. Used to skip tiling 328 dimensions where `multiples=1`. 329 330 Returns: 331 A list of 1-D integer `Tensor`s (one for each ragged dimension in 332 `rt_input`). 333 334 #### Example: 335 336 >>> rt = tf.ragged.constant([[1, 2], [3]]) 337 >>> _tile_ragged_splits(rt, [3, 2]) 338 [<tf.Tensor: shape=(7,), dtype=int64, 339 numpy=array([ 0, 4, 6, 10, 12, 16, 18])>] 340 """ 341 ragged_rank = rt_input.ragged_rank 342 nested_splits = rt_input.nested_row_splits 343 344 # projected_splits[src_axis, dst_axis] contains the split points that divide 345 # the rows from src_axis in the list of dst_axis values. E.g., 346 # projected_splits[i, i] = nested_splits[i], and 347 # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]). 348 projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)] 349 for src_axis in range(ragged_rank): 350 for dst_axis in range(src_axis + 1, ragged_rank - 1): 351 projected_splits[src_axis][dst_axis] = array_ops.gather( 352 nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1]) 353 354 # For each ragged dimension: nested_splits[axis] -> result_splits[axis]. 355 result_splits = [] 356 for axis in range(ragged_rank): 357 # Get the length of each row for the input tensor for this dimension. 
358 input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1] 359 360 # Multiply those lengths by the `multiples` of dimension axis+1, since 361 # each value will be repeated that number of times. 362 output_lengths = input_lengths * multiples[axis + 1] 363 364 # Repeat ranges of the row lengths as necessary for them to be tiled in 365 # each ragged dimension `d < axis`. (Start with dimension d=axis-1, and 366 # work our way up to dimension d=0.) 367 repeats = 1 368 for d in range(axis - 1, -1, -1): 369 if const_multiples is None or const_multiples[d + 1] != 1: 370 splits = projected_splits[d][axis - 1] * repeats 371 output_lengths = ragged_util.repeat_ranges(output_lengths, splits, 372 multiples[d + 1]) 373 repeats *= multiples[d + 1] 374 375 # Tile splits for the outermost (uniform) dimension. 376 output_lengths = array_ops.tile(output_lengths, multiples[:1]) 377 378 # Convert to splits. 379 result_splits.append(ragged_util.lengths_to_splits(output_lengths)) 380 381 return result_splits 382 383 384#=============================================================================== 385# Reshaping 386#=============================================================================== 387 388 389@dispatch.dispatch_for_api(array_ops.expand_dims_v2) 390def expand_dims(input: ragged_tensor.Ragged, axis, name=None): # pylint: disable=redefined-builtin 391 """Inserts a dimension with shape 1 into a potentially ragged tensor's shape. 392 393 Given a potentially ragged tenor `input`, this operation inserts a 394 dimension with size 1 at the dimension `axis` of `input`'s shape. 395 396 The following table gives some examples showing how `ragged.expand_dims` 397 impacts the shapes of different input tensors. Ragged dimensions are 398 indicated by enclosing them in parentheses. 
399 400 input.shape | axis | result.shape 401 ----------------------- | ---- | ----------------------------- 402 `[D1, D2]` | `0` | `[1, D1, D2]` 403 `[D1, D2]` | `1` | `[D1, 1, D2]` 404 `[D1, D2]` | `2` | `[D1, D2, 1]` 405 `[D1, (D2), (D3), D4]` | `0` | `[1, D1, (D2), (D3), D4]` 406 `[D1, (D2), (D3), D4]` | `1` | `[D1, 1, (D2), (D3), D4]` 407 `[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), 1, (D3), D4]` 408 `[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]` 409 `[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]` 410 411 Args: 412 input: The potentially tensor that should be expanded with a new dimension. 413 axis: An integer constant indicating where the new dimension should be 414 inserted. 415 name: A name for the operation (optional). 416 417 Returns: 418 A tensor with the same values as `input`, with an added dimension of 419 size 1 at `axis`. 420 421 #### Examples: 422 423 >>> rt = tf.ragged.constant([[1, 2], [3]]) 424 >>> print(rt.shape) 425 (2, None) 426 427 >>> expanded = tf.expand_dims(rt, axis=0) 428 >>> print(expanded.shape, expanded) 429 (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]> 430 431 >>> expanded = tf.expand_dims(rt, axis=1) 432 >>> print(expanded.shape, expanded) 433 (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]> 434 435 >>> expanded = tf.expand_dims(rt, axis=2) 436 >>> print(expanded.shape, expanded) 437 (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]> 438 """ 439 with ops.name_scope(name, 'RaggedExpandDims', [input]): 440 input = ragged_tensor.convert_to_tensor_or_ragged_tensor( 441 input, name='input') 442 443 if not ragged_tensor.is_ragged(input): 444 return array_ops.expand_dims(input, axis) 445 446 ndims = None if input.shape.ndims is None else input.shape.ndims + 1 447 axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)') 448 449 if axis == 0: 450 return ragged_tensor.RaggedTensor.from_uniform_row_length( 451 input, uniform_row_length=input.nrows(), nrows=1, validate=False) 452 elif axis == 1: 453 return 
ragged_tensor.RaggedTensor.from_uniform_row_length( 454 input, uniform_row_length=1, nrows=input.nrows(), validate=False) 455 else: 456 if ragged_tensor.is_ragged(input.values): 457 return input.with_values(expand_dims(input.values, axis - 1)) 458 else: 459 return input.with_values(array_ops.expand_dims(input.values, axis - 1)) 460 461 462@dispatch.dispatch_for_api(array_ops.expand_dims) 463def _ragged_expand_dims_v1( 464 input: ragged_tensor.Ragged, # pylint: disable=redefined-builtin 465 axis=None, 466 name=None, 467 dim=None): 468 if dim is not None: 469 axis = dim 470 return expand_dims(input=input, axis=axis, name=name) 471 472 473#=============================================================================== 474# RaggedTensor Size 475#=============================================================================== 476 477 478@dispatch.dispatch_for_api(array_ops.size_v2) 479def size(input: ragged_tensor.Ragged, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin 480 """Returns the size of a potentially ragged tensor. 481 482 The size of a ragged tensor is the size of its inner values. 483 484 #### Example: 485 486 >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy() 487 3 488 489 Args: 490 input: A potentially ragged `Tensor`. 491 out_type: The numeric output type for the operation. 492 name: A name for the operation (optional). 493 494 Returns: 495 A Tensor of type `out_type`. 
496 """ 497 if ragged_tensor.is_ragged(input): 498 return array_ops.size(input.flat_values, out_type=out_type, name=name) 499 else: 500 return array_ops.size(input, out_type=out_type, name=name) 501 502 503@dispatch.dispatch_for_api(array_ops.size) 504def _ragged_size_v1( 505 input: ragged_tensor.Ragged, # pylint: disable=redefined-builtin 506 name=None, 507 out_type=dtypes.int32): 508 return size(input=input, out_type=out_type, name=name) 509 510 511#=============================================================================== 512# ragged.rank 513#=============================================================================== 514@dispatch.dispatch_for_api(array_ops.rank) 515def rank(input: ragged_tensor.Ragged, name=None): # pylint: disable=redefined-builtin 516 """Returns the rank of a RaggedTensor. 517 518 Returns a 0-D `int32` `Tensor` representing the rank of `input`. 519 520 #### Example: 521 522 >>> # shape of tensor 't' is [2, None, None] 523 >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]]) 524 >>> tf.rank(t).numpy() 525 3 526 527 Args: 528 input: A `RaggedTensor` 529 name: A name for the operation (optional). 530 531 Returns: 532 A `Tensor` of type `int32`. 533 """ 534 with ops.name_scope(name, 'RaggedRank', [input]) as name: 535 if not ragged_tensor.is_ragged(input): 536 return array_ops.rank(input, name) 537 538 return input.ragged_rank + array_ops.rank(input.flat_values) 539 540 541#=============================================================================== 542# ragged.one_hot 543#=============================================================================== 544@dispatch.dispatch_for_api(array_ops.one_hot) 545def ragged_one_hot(indices: ragged_tensor.Ragged, 546 depth, 547 on_value=None, 548 off_value=None, 549 axis=None, 550 dtype=None, 551 name=None): 552 """Applies tf.one_hot along the values of a RaggedTensor.""" 553 # Get the adjusted axis value for the call to array_ops.one_hot. 
554 # Note: the only negative `axis` value supported by array_ops.one_hot is -1. 555 if isinstance(axis, int) and axis >= 0: 556 if axis <= indices.ragged_rank: 557 raise ValueError('axis (%d) must be greater than indices.ragged_rank ' 558 '(%d).' % (axis, indices.ragged_rank)) 559 axis -= indices.ragged_rank 560 561 with ops.name_scope(name, 'RaggedOneHot', 562 [indices, depth, on_value, off_value, axis]): 563 indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( 564 indices, name='indices') 565 return indices.with_flat_values( 566 array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis, 567 dtype, name)) 568 569 570#=============================================================================== 571# ragged.stack_dynamic_partitions 572#=============================================================================== 573@tf_export('ragged.stack_dynamic_partitions') 574@dispatch.add_dispatch_support 575def stack_dynamic_partitions(data, partitions, num_partitions, name=None): 576 """Stacks dynamic partitions of a Tensor or RaggedTensor. 577 578 Returns a RaggedTensor `output` with `num_partitions` rows, where the row 579 `output[i]` is formed by stacking all slices `data[j1...jN]` such that 580 `partitions[j1...jN] = i`. Slices of `data` are stacked in row-major 581 order. 582 583 If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to 584 `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`. 585 586 #### Example: 587 588 >>> data = ['a', 'b', 'c', 'd', 'e'] 589 >>> partitions = [ 3, 0, 2, 2, 3] 590 >>> num_partitions = 5 591 >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions) 592 <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]> 593 594 Args: 595 data: A `Tensor` or `RaggedTensor` containing the values to stack. 596 partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the 597 partition that each slice of `data` should be added to. 
`partitions.shape` 598 must be a prefix of `data.shape`. Values must be greater than or equal to 599 zero, and less than `num_partitions`. `partitions` is not required to be 600 sorted. 601 num_partitions: An `int32` or `int64` scalar specifying the number of 602 partitions to output. This determines the number of rows in `output`. 603 name: A name prefix for the returned tensor (optional). 604 605 Returns: 606 A `RaggedTensor` containing the stacked partitions. The returned tensor 607 has the same dtype as `data`, and its shape is 608 `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a 609 ragged dimension whose length is the number of data slices stacked for 610 each `partition`. 611 """ 612 with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]): 613 # Convert inputs to tensors. 614 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') 615 row_splits_dtype = ( 616 data.row_splits.dtype 617 if isinstance(data, ragged_tensor.RaggedTensor) else None) 618 partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor( 619 partitions, name='partitions', preferred_dtype=row_splits_dtype) 620 num_partitions = ops.convert_to_tensor( 621 num_partitions, name='num_partitions', preferred_dtype=partitions.dtype) 622 if row_splits_dtype is not None: 623 partitions = math_ops.cast(partitions, row_splits_dtype) 624 num_partitions = math_ops.cast(num_partitions, partitions.dtype) 625 626 # Sanity-checks for shapes. 627 partitions_rank = partitions.shape.ndims 628 if partitions_rank is None: 629 raise ValueError('partitions must have known rank.') 630 num_partitions.shape.assert_has_rank(0) 631 partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank]) 632 633 if partitions_rank == 0: 634 # If partitions is a scalar, then just create a RaggedTensor containing 635 # that single the complete `data` value in the specified row. 
636 return ragged_tensor.RaggedTensor.from_value_rowids( 637 values=array_ops.stack([data]), 638 value_rowids=array_ops.stack([partitions]), 639 nrows=num_partitions, 640 validate=False) 641 642 elif partitions_rank == 1: 643 # If partitions is a vector (the typical case): we can just use data and 644 # partitions as the `values` and `value_rowids` for `from_value_rowids`, 645 # as long as we sort them first. 646 permutation = sort_ops.argsort(partitions, stable=True) 647 value_rowids = array_ops.gather(partitions, permutation) 648 values = array_ops.gather(data, permutation) 649 check = check_ops.assert_less( 650 value_rowids[-1:], 651 num_partitions, 652 message='partitions must be less than num_partitions') 653 with ops.control_dependencies([check]): 654 return ragged_tensor.RaggedTensor.from_value_rowids( 655 values, value_rowids, nrows=num_partitions, validate=False) 656 657 else: 658 # Handle higher-dimensional partitions via recursion. 659 if not isinstance(data, ragged_tensor.RaggedTensor): 660 data = ragged_tensor.RaggedTensor.from_tensor( 661 data, row_splits_dtype=partitions.dtype, ragged_rank=1) 662 if not isinstance(partitions, ragged_tensor.RaggedTensor): 663 partitions = ragged_tensor.RaggedTensor.from_tensor( 664 partitions, 665 row_splits_dtype=partitions.dtype, 666 ragged_rank=max(data.ragged_rank, partitions_rank - 1)) 667 check = check_ops.assert_equal( 668 data.row_splits, 669 partitions.row_splits, 670 message='data and partitions have incompatible ragged shapes') 671 with ops.control_dependencies([check]): 672 return stack_dynamic_partitions(data.values, partitions.values, 673 num_partitions) 674 675 676#=============================================================================== 677# Reverse 678#=============================================================================== 679@dispatch.dispatch_for_api(array_ops.reverse) 680def reverse(tensor: ragged_tensor.Ragged, axis, name=None): 681 """Reverses a RaggedTensor along the specified 
axes. 682 683 #### Example: 684 685 >>> data = tf.ragged.constant([ 686 ... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]]) 687 >>> tf.reverse(data, axis=[0, 2]) 688 <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]> 689 690 Args: 691 tensor: A 'RaggedTensor' to reverse. 692 axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of 693 the axes to reverse. 694 name: A name prefix for the returned tensor (optional). 695 696 Returns: 697 A 'RaggedTensor'. 698 """ 699 type_error_msg = ('`axis` must be a list of int or a constant tensor' 700 'when reversing axes in a ragged tensor') 701 702 with ops.name_scope(name, 'Reverse', [tensor, axis]): 703 if isinstance(axis, ops.Tensor): 704 axis = tensor_util.constant_value(axis) 705 if axis is None: 706 raise TypeError(type_error_msg) 707 elif not (isinstance(axis, (list, tuple)) and 708 all(isinstance(dim, int) for dim in axis)): 709 raise TypeError(type_error_msg) 710 711 tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor( 712 tensor, name='tensor') 713 714 # Allow usage of negative values to specify innermost axes. 715 axis = [ 716 array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i, 717 'rank(tensor)') 718 for i, dim in enumerate(axis) 719 ] 720 721 # We only need to slice up to the max axis. If the axis list 722 # is empty, it should be 0. 723 slices = [slice(None)] * (max(axis) + 1 if axis else 0) 724 725 for dim in axis: 726 slices[dim] = slice(None, None, -1) 727 728 return tensor[tuple(slices)] 729 730 731#=============================================================================== 732# Cross 733#=============================================================================== 734 735 736@tf_export('ragged.cross') 737@dispatch.add_dispatch_support 738def cross(inputs, name=None): 739 """Generates feature cross from a list of tensors. 740 741 The input tensors must have `rank=2`, and must all have the same number of 742 rows. 
The result is a `RaggedTensor` with the same number of rows as the 743 inputs, where `result[row]` contains a list of all combinations of values 744 formed by taking a single value from each input's corresponding row 745 (`inputs[i][row]`). Values are combined by joining their strings with '_X_'. 746 E.g.: 747 748 >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]), 749 ... tf.ragged.constant([['d'], ['e']]), 750 ... tf.ragged.constant([['f'], ['g']])]) 751 <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]> 752 753 Args: 754 inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`. 755 name: Optional name for the op. 756 757 Returns: 758 A 2D `RaggedTensor` of type `string`. 759 """ 760 return _cross_internal(inputs=inputs, hashed_output=False, name=name) 761 762 763@tf_export('ragged.cross_hashed') 764@dispatch.add_dispatch_support 765def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None): 766 """Generates hashed feature cross from a list of tensors. 767 768 The input tensors must have `rank=2`, and must all have the same number of 769 rows. The result is a `RaggedTensor` with the same number of rows as the 770 inputs, where `result[row]` contains a list of all combinations of values 771 formed by taking a single value from each input's corresponding row 772 (`inputs[i][row]`). Values are combined by hashing together their 773 fingerprints. E.g.: 774 775 >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]), 776 ... tf.ragged.constant([['d'], ['e']]), 777 ... tf.ragged.constant([['f'], ['g']])], 778 ... num_buckets=100) 779 <tf.RaggedTensor [[78], [66, 74]]> 780 781 Args: 782 inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`. 783 num_buckets: A non-negative `int` that used to bucket the hashed values. If 784 `num_buckets != 0`, then `output = hashed_value % num_buckets`. 785 hash_key: Integer hash_key that will be used by the `FingerprintCat64` 786 function. If not given, a default key is used. 
_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE


def _cross_internal(inputs,
                    hashed_output=False,
                    num_buckets=0,
                    hash_key=None,
                    name=None):
  """Generates feature cross from a list of ragged and dense tensors.

  Shared implementation for the string and the hashed feature cross: the
  inputs are sorted into ragged, sparse and dense groups and handed to the
  `ragged_cross` op together with a string recording their original order.

  Args:
    inputs: A list or tuple of `RaggedTensor`, `Tensor` or `SparseTensor`
      values.  Integer inputs are cast to `int64`; every other input must have
      dtype `string`.  Ragged inputs must have `ragged_rank == 1`.
    hashed_output: If true, the output values have dtype `int64`; otherwise
      they have dtype `string`.  Also forwarded to the op.
    num_buckets: Forwarded to the op unchanged.
    hash_key: Integer hash key in the uint64 range.  Defaults to
      `_DEFAULT_CROSS_HASH_KEY` when `None`.
    name: Optional name for the op.

  Returns:
    A `RaggedTensor` built from the values and row splits returned by the op.

  Raises:
    TypeError: If `inputs` is not a list or tuple.
    ValueError: If an input has a dtype that is neither integer nor string, or
      if a ragged input does not have `ragged_rank == 1`.
  """
  if not isinstance(inputs, (tuple, list)):
    raise TypeError('Inputs must be a list')

  if hash_key is None:
    hash_key = _DEFAULT_CROSS_HASH_KEY

  ragged_inputs = []
  sparse_inputs = []
  dense_inputs = []
  # One letter per input ('R', 'S' or 'D'), in the order the inputs were
  # given, so that the op can interleave the three groups again.
  input_order = []
  with ops.name_scope(name, 'RaggedCross', inputs):
    for i, t in enumerate(inputs):
      if sparse_tensor.is_sparse(t):
        t = sparse_tensor.SparseTensor.from_value(t)
      else:
        t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
      if t.dtype.is_integer:
        t = math_ops.cast(t, dtypes.int64)
      elif t.dtype != dtypes.string:
        raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
      if isinstance(t, ragged_tensor.RaggedTensor):
        if t.ragged_rank != 1:
          raise ValueError('tf.ragged.cross only supports inputs with rank=2')
        ragged_inputs.append(t)
        input_order.append('R')
      elif isinstance(t, sparse_tensor.SparseTensor):
        sparse_inputs.append(t)
        input_order.append('S')
      else:
        dense_inputs.append(t)
        input_order.append('D')

    out_values_type = dtypes.int64 if hashed_output else dtypes.string
    # Use int32 row splits only when there is at least one ragged input and
    # every ragged input already uses int32 row splits.
    if ragged_inputs and all(
        t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
      out_row_splits_type = dtypes.int32
    else:
      out_row_splits_type = dtypes.int64

    # Convert hash_key from uint64 -> int64, since we need to pass it via
    # an int64 attr.  The largest int64 is 2**63 - 1, so 2**63 itself must
    # wrap around as well (hence `>=` rather than `>`).
    if hash_key >= 2**63:
      hash_key -= 2**64

    values_out, splits_out = gen_ragged_array_ops.ragged_cross(
        ragged_values=[rt.values for rt in ragged_inputs],
        ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
        sparse_indices=[st.indices for st in sparse_inputs],
        sparse_values=[st.values for st in sparse_inputs],
        sparse_shape=[st.dense_shape for st in sparse_inputs],
        dense_inputs=dense_inputs,
        input_order=''.join(input_order),
        hashed_output=hashed_output,
        num_buckets=num_buckets,
        hash_key=hash_key,
        out_values_type=out_values_type.as_datatype_enum,
        out_row_splits_type=out_row_splits_type.as_datatype_enum,
        name=name)

    return ragged_tensor.RaggedTensor.from_row_splits(
        values_out, splits_out, validate=False)
@dispatch.dispatch_for_api(data_flow_ops.dynamic_partition)
def dynamic_partition(data: ragged_tensor.RaggedOrDense,
                      partitions: ragged_tensor.RaggedOrDense,
                      num_partitions,
                      name=None):
  """RaggedTensor dispatch override for tf.dynamic_partition.

  Args:
    data: The potentially ragged tensor to partition.
    partitions: A potentially ragged tensor passed to
      `stack_dynamic_partitions` together with `data`.
    num_partitions: A non-negative Python `int`; the number of partitions.
    name: A name prefix for the returned tensors (optional).

  Returns:
    A list with `num_partitions` entries, where entry `i` is row `i` of the
    result of `stack_dynamic_partitions`.

  Raises:
    TypeError: If `num_partitions` is not a non-negative `int`.
  """
  is_valid_count = isinstance(num_partitions, int) and num_partitions >= 0
  if not is_valid_count:
    raise TypeError('num_partitions must be a non-negative integer')
  stacked = stack_dynamic_partitions(data, partitions, num_partitions, name)
  return [stacked[index] for index in range(num_partitions)]
@dispatch.dispatch_for_api(array_ops.split)
def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`,  then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th element has the
  same size as the `value` except along dimension `axis` where the size is
  `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...      np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits
      along `axis` or a 1-D integer `Tensor` or Python list containing the sizes
      of each output tensor along `axis`. If a Python int, then it must evenly
      divide `value.shape[axis]`; otherwise the sum of sizes along the split
      axis must match that of the `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
      to split. Must be in the range `[-rank(value), rank(value))`. Defaults to
      0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
      statically unknown, e.g., specifying `tf.TensorSpec(None)` with
      the `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    if `num_or_size_splits` is an `int` returns a list of `num_or_size_splits`
    `RaggedTensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
    `num_or_size_splits.get_shape[0]` `RaggedTensor` objects resulting from
    splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D
      list or 1-D `Tensor`.
    InvalidArgumentError: If the `axis` of `value` cannot be exactly split
      by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown and
      its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown and
      `axis` is a negative integer.
  """
  with ops.name_scope(name, 'RaggedSplit'):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, name='value')
    # Splitting into a single piece is the identity.
    if isinstance(num_or_size_splits, int) and num_or_size_splits == 1:
      return [value]

    # static assert
    check_ops.assert_integer_v2(
        num_or_size_splits,
        message=('`num_or_size_splits` must be an `int` or 1-D list or '
                 '`Tensor` of integers.'))
    value_shape = dynamic_ragged_shape.DynamicRaggedShape.from_tensor(value)
    axis = array_ops.get_positive_axis(axis, value_shape.rank)
    try:
      dim_size = value_shape[axis]
    except ValueError as err:
      raise ValueError('Cannot split a ragged dimension. Got `value` with '
                       f'shape {value_shape} and `axis` {axis}.') from err
    if isinstance(num_or_size_splits, int):
      # Uniform split
      num_splits = num_or_size_splits
      if num_splits < 1:
        raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`. '
                         f'Received {num_or_size_splits}.')
      # Divisibility is not checked here; the `assert_equal_v2` on the last
      # split boundary below catches a remainder.
      split_length = math_ops.floordiv(dim_size, num_splits)
      split_lengths = array_ops.repeat(split_length, num_splits)
    else:
      # Ragged split
      num_splits = None
      split_lengths = ops.convert_to_tensor(num_or_size_splits)
      if split_lengths.shape.ndims is not None:
        if split_lengths.shape.ndims != 1:
          raise TypeError('`num_or_size_splits` must be an `int` or 1-D list '
                          f'or `Tensor`. Received {num_or_size_splits}.')
        num_splits = tensor_shape.dimension_value(split_lengths.shape[0])

      # The number of outputs must be known while building the Python list of
      # results, so fall back to `num` when it cannot be read statically.
      if num_splits is None:
        if num is None:
          raise ValueError('`num` must be specified as an `int` when the '
                           'size of `num_or_size_split` is statically '
                           f'unknown. Received `num`: {num} and '
                           f'`num_or_size_split`: {num_or_size_splits}.')
        num_splits = num
      else:
        if num is not None and num != num_splits:
          raise ValueError('`num` does not match the size of '
                           f'`num_or_size_split`. Received `num`: {num} and '
                           f'size of `num_or_size_split`: {num_splits}.')

    # `splits[i]:splits[i + 1]` is the range taken by the i-th output.
    splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0)
    checks = []
    checks.append(
        check_ops.assert_non_negative_v2(
            num_or_size_splits,
            message='`num_or_size_splits` must be non-negative.'))
    checks.append(
        check_ops.assert_equal_v2(
            num_splits,
            array_ops.shape(split_lengths)[0],
            message='`num` is inconsistent with `num_or_size_split.shape[0]`.'))
    checks.append(
        check_ops.assert_equal_v2(
            math_ops.cast(dim_size, splits.dtype),
            splits[-1],
            message=('Cannot exactly split the `axis` dimension of `value` '
                     'with the given `num_or_size_split`.')))
    splits = control_flow_ops.with_dependencies(checks, splits)
    splited_rts = []
    # Full slices for every leading dimension; only the entry for `axis`
    # (the last one) changes between outputs.
    slices = [slice(None)] * (axis + 1)
    for i in range(num_splits):
      slices[-1] = slice(splits[i], splits[i + 1])
      splited_rts.append(value[tuple(slices)])
    return splited_rts
@dispatch.dispatch_for_api(array_ops.reshape)
def ragged_reshape(
    tensor: ragged_tensor.RaggedOrDense,
    shape: dynamic_ragged_shape.DenseOrRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Reshapes a tensor or ragged tensor.

  Args:
    tensor: The potentially ragged tensor to reshape.  For a `RaggedTensor`,
      it is its `values` that get reshaped.
    shape: Either a `DynamicRaggedShape` or a dense shape.

  Returns:
    If `shape` is a `DynamicRaggedShape`: a `RaggedTensor` whose flat values
    are the input reshaped to `shape.inner_shape`, partitioned by
    `shape.row_partitions`.  Otherwise: the result of a dense reshape to
    `shape`.
  """
  tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      tensor, name='tensor')
  if isinstance(tensor, ragged_tensor.RaggedTensor):
    tensor = tensor.values

  # Dense target shape: an ordinary reshape is all that is needed.
  if not isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    dense_shape = ops.convert_to_tensor(shape, name='shape')
    return array_ops.reshape(tensor, dense_shape)

  # Ragged target shape: reshape to the inner shape, then re-attach the
  # row partitions.
  reshaped_flat_values = array_ops.reshape(tensor, shape.inner_shape)
  return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
      reshaped_flat_values,
      shape.row_partitions,
      validate=False)


@dispatch.dispatch_for_api(array_ops.broadcast_to)
def broadcast_to(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    shape: dynamic_ragged_shape.DynamicRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Broadcasts a potentially ragged tensor to a ragged shape.

  `input` is tiled as necessary so that it matches the given shape.

  Behavior is undefined if `input` is not broadcast-compatible with `shape`.

  Args:
    input: The potentially ragged tensor to broadcast.
    shape: A `DynamicRaggedShape`

  Returns:
    A potentially ragged tensor whose values are taken from
    `input`, and whose shape matches `shape`.
  """
  broadcasted = dynamic_ragged_shape.broadcast_to(input, shape)
  return broadcasted


# Note: default value for out_type needs to be int32, to match the
# default for tf.shape's out_type parameter.
@dispatch.dispatch_for_api(array_ops.shape)
def ragged_shape(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    name: Optional[str] = None,
    out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape of a RaggedTensor.

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).
    out_type: dtype used to encode the shape.

  Returns:
    A `tf.experimental.DynamicRaggedShape`
  """
  shape_cls = dynamic_ragged_shape.DynamicRaggedShape
  with ops.name_scope(name, 'RaggedShape', [input]):
    return shape_cls.from_tensor(input, out_type)
@dispatch.dispatch_for_api(array_ops.broadcast_dynamic_shape)
def broadcast_dynamic_shape(
    shape_x: dynamic_ragged_shape.DenseOrRaggedShape,
    shape_y: dynamic_ragged_shape.DenseOrRaggedShape
) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape formed by broadcasting two shapes to be compatible.

  1. If shape_x and shape_y both have row_partitions, then fail if their dtypes
     don't match.
  2. If neither has row_partitions and they have different dtypes,
     go with int64.
  3. If one has row_partitions, go with that dtype.

  Args:
    shape_x: A `DynamicRaggedShape`
    shape_y: A `DynamicRaggedShape`

  Returns:
    A `DynamicRaggedShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  shape_cls = dynamic_ragged_shape.DynamicRaggedShape
  # Anything that is not already a `DynamicRaggedShape` is wrapped in one,
  # using an empty list as the first constructor argument.
  shape_x, shape_y = (
      s if isinstance(s, shape_cls) else shape_cls([], s)
      for s in (shape_x, shape_y))
  return dynamic_ragged_shape.broadcast_dynamic_shape(shape_x, shape_y)
@dispatch.dispatch_for_api(array_ops.ones)
def ones(shape: dynamic_ragged_shape.DynamicRaggedShape,
         dtype=dtypes.float32,
         name=None) -> ragged_tensor.RaggedOrDense:
  """Creates a potentially ragged tensor of ones with the given shape.

  Args:
    shape: A `DynamicRaggedShape` giving the shape of the result.
    dtype: The dtype of the result.  Defaults to `float32`.
    name: A name for the op that creates the flat values (optional).

  Returns:
    A tensor of ones created with `shape.inner_shape`, with the row partitions
    of `shape` added to it.
  """
  flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.zeros)
def zeros(shape: dynamic_ragged_shape.DynamicRaggedShape,
          dtype=dtypes.float32,
          name=None) -> ragged_tensor.RaggedOrDense:
  """Creates a potentially ragged tensor of zeros with the given shape.

  Args:
    shape: A `DynamicRaggedShape` giving the shape of the result.
    dtype: The dtype of the result.  Defaults to `float32`.
    name: A name for the op that creates the flat values (optional).

  Returns:
    A tensor of zeros created with `shape.inner_shape`, with the row
    partitions of `shape` added to it.
  """
  flat_values = array_ops.zeros(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.fill)
def fill(dims: dynamic_ragged_shape.DynamicRaggedShape,
         value: core_types.TensorLike,
         name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
  """Creates a tensor with shape `dims` and fills it with `value`.

  Args:
    dims: A `DynamicRaggedShape` giving the shape of the result.
    value: The value every element of the result is set to.
    name: A name for the op that creates the flat values (optional).

  Returns:
    A tensor filled with `value` created with `dims.inner_shape`, with the row
    partitions of `dims` added to it.
  """
  flat_values = array_ops.fill(dims.inner_shape, value, name=name)
  return dims._add_row_partitions(flat_values)  # pylint: disable=protected-access
@dispatch.dispatch_for_api(array_ops.bitcast)
def bitcast(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    type,  # pylint: disable=redefined-builtin
    name=None) -> ragged_tensor.RaggedOrDense:
  """RaggedTensor dispatch override for tf.bitcast.

  The bitcast is applied to `input.flat_values`; the row partitions of `input`
  are kept as they are.

  Args:
    input: The potentially ragged tensor to bitcast.
    type: The dtype to cast to; anything accepted by `dtypes.as_dtype`.
    name: A name for the operation (optional).

  Returns:
    `input` with its flat values bitcast to `type`.  A dense `input` is bitcast
    directly.

  Raises:
    ValueError: If `type` is larger than `input.dtype` and the rank of
      `input.flat_values` is statically known to be less than 2.
  """
  type = dtypes.as_dtype(type)
  with ops.name_scope(name, 'Bitcast', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if not isinstance(input, ragged_tensor.RaggedTensor):
      # Dense tensors have no `flat_values`; bitcast them as they are.
      return array_ops.bitcast(input, type)

    flat_values = input.flat_values
    flat_rank = flat_values.shape.rank
    # Only reject when the rank is statically known: a `None` rank cannot be
    # compared with an int, and is left to the bitcast op to validate.
    if (input.dtype.size < type.size and flat_rank is not None and
        flat_rank < 2):
      raise ValueError('`input.flat_values` is required to have rank >= 2 when '
                       'input.dtype.size < type.size. Actual rank: '
                       f'{flat_rank}')
    return input.with_flat_values(array_ops.bitcast(flat_values, type))