# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Array operations for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# Masking
#===============================================================================


@tf_export('ragged.boolean_mask')
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically or `mask` is a scalar;
      or if `mask.shape` is not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # lengths back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                        masked_splits,
                                                        validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(mask, axis=-1,
                                                dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values


#===============================================================================
# Tiling
#===============================================================================
def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
  """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

  The values of `input` are replicated `multiples[i]` times along the
  `i`th dimension (for each dimension `i`).  For every dimension `axis` in
  `input`, the length of each output element in that dimension is the
  length of corresponding input element multiplied by `multiples[axis]`.

  Args:
    input: A `RaggedTensor`.  A non-ragged input is forwarded to
      `array_ops.tile`.
    multiples: A 1-D integer `Tensor`.  Length must be the same as the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> tf.tile(rt, [3, 2]).to_list()
  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
  """
  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if not ragged_tensor.is_ragged(input):
      return array_ops.tile(input, multiples, name)
    multiples = ragged_util.convert_to_int_tensor(
        multiples, name='multiples', dtype=input.row_splits.dtype)
    multiples.shape.assert_has_rank(1)

    # If the constant value of `multiples` is available, then we can use it
    # to skip tiling dimensions where `multiples=1`.
    const_multiples = tensor_util.constant_value(multiples)

    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        _tile_ragged_values(input, multiples, const_multiples),
        _tile_ragged_splits(input, multiples, const_multiples),
        validate=False)


def _tile_ragged_values(rt_input, multiples, const_multiples=None):
  """Builds flat_values tensor for a tiled `RaggedTensor`.

  Returns a tensor that repeats the values in
  `rt_input.flat_values` in the
  appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as
  specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` whose values should be repeated.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A `Tensor` with the same type and rank as `rt_input.flat_values`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # Pointers to the values in `rt_input.flat_values`.
  inner_value_ids = math_ops.range(nested_splits[-1][-1])

  # For each ragged dimension (working from the innermost to outermost),
  # expand `inner_value_ids` as necessary to tile that dimension.
  prev_splits = None
  for axis in range(ragged_rank, 0, -1):
    # Ragged splits for this dimension.
    splits = nested_splits[axis - 1]

    # Adjust splits so they point into `inner_value_ids` (instead of just
    # pointing into the next dimension's values).
    if prev_splits is not None:  # Not the first pass through the loop.
      splits = array_ops.gather(prev_splits * multiples[axis + 1], splits)

    # Repeat each element in this ragged dimension `multiples[axis]` times.
    if const_multiples is None or const_multiples[axis] != 1:
      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
                                                  multiples[axis])

    prev_splits = splits

  # Gather the tiled inner values.
  ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids)

  # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus
  # `axis=range(ragged_rank, rank)`).
  inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]],
                                   axis=0)
  return array_ops.tile(ragged_tiled_values, inner_repeats)


def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
  """Builds nested_split tensors for a tiled `RaggedTensor`.

  Returns a list of split tensors that can be used to construct the
  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # projected_splits[src_axis, dst_axis] contains the split points that divide
  # the rows from src_axis in the list of dst_axis values.  E.g.,
  # projected_splits[i, i] = nested_splits[i], and
  # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
  projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
  for src_axis in range(ragged_rank):
    for dst_axis in range(src_axis + 1, ragged_rank - 1):
      projected_splits[src_axis][dst_axis] = array_ops.gather(
          nested_splits[dst_axis],
          projected_splits[src_axis][dst_axis - 1])

  # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
  result_splits = []
  for axis in range(ragged_rank):
    # Get the length of each row for the input tensor for this dimension.
    input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]

    # Multiply those lengths by the `multiples` of dimension axis+1, since
    # each value will be repeated that number of times.
    output_lengths = input_lengths * multiples[axis + 1]

    # Repeat ranges of the row lengths as necessary for them to be tiled in
    # each ragged dimension `d < axis`.  (Start with dimension d=axis-1, and
    # work our way up to dimension d=0.)
    repeats = 1
    for d in range(axis - 1, -1, -1):
      if const_multiples is None or const_multiples[d + 1] != 1:
        splits = projected_splits[d][axis - 1] * repeats
        output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
                                                   multiples[d + 1])
      # `repeats` is updated even when the tiling above is skipped, so that
      # the splits used for outer dimensions stay correctly scaled.
      repeats *= multiples[d + 1]

    # Tile splits for the outermost (uniform) dimension.
    output_lengths = array_ops.tile(output_lengths, multiples[:1])

    # Convert to splits.
    result_splits.append(ragged_util.lengths_to_splits(output_lengths))

  return result_splits


#===============================================================================
# Reshaping
#===============================================================================


def expand_dims(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

  Given a potentially ragged tensor `input`, this operation inserts a
  dimension with size 1 at the dimension `axis` of `input`'s shape.

  * If `input` is a `Tensor`, then this is equivalent to
    `tf.expand_dims`.
  * If `input` is ragged, and `axis=0`, then the new dimension will be
    uniform; but the previously outermost dimension will become ragged.
  * If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
    new dimension will be ragged.
  * If `input` is ragged, and `axis >= input.ragged_rank`, then the new
    dimension will be uniform.

  The following table gives some examples showing how `ragged.expand_dims`
  impacts the shapes of different input tensors.  Ragged dimensions are
  indicated by enclosing them in parentheses.

  input.shape             | axis | result.shape
  ----------------------- | ---- | -----------------------------
  `[D1, D2]`              |  `0` | `[1, D1, D2]`
  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
  `[D1, (D2), (D3), D4]`  |  `0` | `[1, (D1), (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, (1), (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), (1), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`

  Args:
    input: The potentially ragged tensor that should be expanded with a new
      dimension.
    axis: An integer constant indicating where the new dimension should be
      inserted.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same values as `input`, with an added dimension of
    size 1 at `axis`.

  #### Examples:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> print(rt.shape)
  (2, None)

  >>> expanded = tf.expand_dims(rt, axis=0)
  >>> print(expanded.shape, expanded)
  (1, None, None) <tf.RaggedTensor [[[1, 2], [3]]]>

  >>> expanded = tf.expand_dims(rt, axis=1)
  >>> print(expanded.shape, expanded)
  (2, None, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>

  >>> expanded = tf.expand_dims(rt, axis=2)
  >>> print(expanded.shape, expanded)
  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
  """
  with ops.name_scope(name, 'RaggedExpandDims', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')

    if not ragged_tensor.is_ragged(input):
      return array_ops.expand_dims(input, axis)

    # The result has one more dimension than `input`, so `axis` is resolved
    # against `input.shape.ndims + 1`.
    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
    axis = ragged_util.get_positive_axis(axis, ndims)
    if axis == 0:
      # Wrap all of `input`'s rows in a single outer row.
      values = input
      splits = array_ops.stack([0, input.nrows()])
    elif axis == 1:
      # Wrap each row of `input` in its own length-1 row.
      values = input
      splits = math_ops.range(input.nrows() + 1)
    else:
      # Keep the outer splits and recurse into the values.
      values = expand_dims(input.values, axis - 1)
      splits = input.row_splits

    return ragged_tensor.RaggedTensor.from_row_splits(values, splits,
                                                      validate=False)


#===============================================================================
# RaggedTensor Size
#===============================================================================


def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a potentially ragged tensor.

  The size of a ragged tensor is the size of its inner values.

  #### Example:

  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
  3

  Args:
    input: A potentially ragged `Tensor`.
    out_type: The numeric output type for the operation.
    name: A name for the operation (optional).

  Returns:
    A Tensor of type `out_type`.
  """
  if ragged_tensor.is_ragged(input):
    return array_ops.size(input.flat_values, out_type=out_type, name=name)
  else:
    return array_ops.size(input, out_type=out_type, name=name)


#===============================================================================
# ragged.rank
#===============================================================================
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  #### Example:

  >>> # shape of tensor 't' is [2, None, None]
  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  >>> tf.rank(t).numpy()
  3

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as name:
    if not ragged_tensor.is_ragged(input):
      return array_ops.rank(input, name)

    # Each ragged dimension adds one to the rank of the flat values.
    return input.ragged_rank + array_ops.rank(input.flat_values)


#===============================================================================
# ragged.one_hot
#===============================================================================
def ragged_one_hot(indices,
                   depth,
                   on_value=None,
                   off_value=None,
                   axis=None,
                   dtype=None,
                   name=None):
  """Applies tf.one_hot along the values of a RaggedTensor.

  Args:
    indices: A potentially ragged tensor of indices.  If it is not ragged
      after conversion, the call is forwarded to `array_ops.one_hot`.
    depth: Forwarded unchanged to `array_ops.one_hot`.
    on_value: Forwarded unchanged to `array_ops.one_hot`.
    off_value: Forwarded unchanged to `array_ops.one_hot`.
    axis: Optional axis forwarded to `array_ops.one_hot`.  For ragged
      `indices` it is first made non-negative using `indices.shape.ndims`.
    dtype: Forwarded unchanged to `array_ops.one_hot`.
    name: A name for the operation (optional).

  Returns:
    For ragged `indices`, a `RaggedTensor` with the same row partitions as
    `indices` whose flat values are the one-hot encoding of
    `indices.flat_values`.  Otherwise, the result of `array_ops.one_hot`.

  Raises:
    ValueError: If `indices` is ragged and `axis` is less than
      `indices.ragged_rank`.
  """
  with ops.name_scope(name, 'RaggedOneHot', [indices]):
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    # A dense `Tensor` has neither `ragged_rank` nor `with_flat_values`; fall
    # back to the dense op, as the other functions in this file do.
    if not ragged_tensor.is_ragged(indices):
      return array_ops.one_hot(indices, depth, on_value, off_value, axis,
                               dtype, name)
    if axis is not None:
      axis = ragged_util.get_positive_axis(axis, indices.shape.ndims)
      if axis < indices.ragged_rank:
        raise ValueError('axis may not be less than indices.ragged_rank.')
    # NOTE(review): `axis` is resolved against the rank of `indices` but is
    # forwarded as-is to `one_hot` on `indices.flat_values`, which has fewer
    # dimensions -- confirm this is the intended axis convention.
    return indices.with_flat_values(
        array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis,
                          dtype, name))


#===============================================================================
# ragged.stack_dynamic_partitions
#===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to.
      `partitions.shape` must be a prefix of `data.shape`.  Values must be
      greater than or equal to zero, and less than `num_partitions`.
      `partitions` is not required to be sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output.  This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.

  Raises:
    ValueError: If the rank of `partitions` is not known statically.
  """
  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    row_splits_dtype = (
        data.row_splits.dtype
        if isinstance(data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-checks for shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor containing
      # the complete `data` value as a single slice in the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops.stack([data]),
          value_rowids=array_ops.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data and
      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
      # as long as we sort them first.
      permutation = sort_ops.argsort(partitions, stable=True)
      value_rowids = array_ops.gather(partitions, permutation)
      values = array_ops.gather(data, permutation)
      # `value_rowids` is sorted, so only its last entry (the largest) needs
      # to be checked; the `[-1:]` slice is empty when there are no values.
      check = check_ops.assert_less(
          value_rowids[-1:],
          num_partitions,
          message='partitions must be less than num_partitions')
      with ops.control_dependencies([check]):
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values, value_rowids, nrows=num_partitions, validate=False)

    else:
      # Handle higher-dimensional partitions via recursion.
      if not isinstance(data, ragged_tensor.RaggedTensor):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
      if not isinstance(partitions, ragged_tensor.RaggedTensor):
        partitions = ragged_tensor.RaggedTensor.from_tensor(
            partitions,
            row_splits_dtype=partitions.dtype,
            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
      check = check_ops.assert_equal(
          data.row_splits,
          partitions.row_splits,
          message='data and partitions have incompatible ragged shapes')
      with ops.control_dependencies([check]):
        return stack_dynamic_partitions(data.values, partitions.values,
                                        num_partitions)


#===============================================================================
# Reverse
#===============================================================================
def reverse(tensor, axis, name=None):
  """Reverses a RaggedTensor along the specified axes.

  #### Example:

  >>> data = tf.ragged.constant([
  ...   [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
  >>> tf.reverse(data, axis=[0, 2])
  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>

  Args:
    tensor: A 'RaggedTensor' to reverse.
    axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
      of the axes to reverse.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 'RaggedTensor'.

  Raises:
    TypeError: If `axis` is a tensor whose value is not statically known, or
      is neither a tensor nor a list/tuple of ints.
  """
  # The two literals are concatenated, so the first must end with a space.
  type_error_msg = ('`axis` must be a list of int or a constant tensor '
                    'when reversing axes in a ragged tensor')

  with ops.name_scope(name, 'Reverse', [tensor, axis]):
    if isinstance(axis, ops.Tensor):
      axis = tensor_util.constant_value(axis)
      if axis is None:
        raise TypeError(type_error_msg)
    elif not (isinstance(axis, (list, tuple)) and
              all(isinstance(dim, int) for dim in axis)):
      raise TypeError(type_error_msg)

    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        tensor, name='tensor')

    # Allow usage of negative values to specify innermost axes.
    axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
            for dim in axis]

    # We only need to slice up to the max axis. If the axis list
    # is empty, it should be 0.
    slices = [slice(None)] * (max(axis) + 1 if axis else 0)

    for dim in axis:
      slices[dim] = slice(None, None, -1)

    return tensor[tuple(slices)]