1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Array operations for RaggedTensors.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import ops 23from tensorflow.python.framework import sparse_tensor 24from tensorflow.python.framework import tensor_util 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import gen_ragged_array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import sort_ops 30from tensorflow.python.ops.ragged import ragged_functional_ops 31from tensorflow.python.ops.ragged import ragged_math_ops 32from tensorflow.python.ops.ragged import ragged_tensor 33from tensorflow.python.ops.ragged import ragged_util 34from tensorflow.python.ops.ragged import segment_id_ops 35from tensorflow.python.util import dispatch 36from tensorflow.python.util.tf_export import tf_export 37 38#=============================================================================== 39# Masking 40#=============================================================================== 41 42 43@tf_export('ragged.boolean_mask') 44@dispatch.add_dispatch_support 45def boolean_mask(data, mask, name=None): 46 """Applies a boolean mask to `data` without flattening the mask dimensions. 47 48 Returns a potentially ragged tensor that is formed by retaining the elements 49 in `data` where the corresponding value in `mask` is `True`. 50 51 * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]` 52 53 Where `j` is the `i`th `True` entry of `mask[a1...aA]`. 54 55 Note that `output` preserves the mask dimensions `a1...aA`; this differs 56 from `tf.boolean_mask`, which flattens those dimensions. 57 58 Args: 59 data: A potentially ragged tensor. 60 mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix 61 of `data`'s shape. `rank(mask)` must be known statically. 62 name: A name prefix for the returned tensor (optional). 63 64 Returns: 65 A potentially ragged tensor that is formed by retaining the elements in 66 `data` where the corresponding value in `mask` is `True`. 67 68 * `rank(output) = rank(data)`. 69 * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`. 70 71 Raises: 72 ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is 73 not a prefix of `data.shape`. 74 75 #### Examples: 76 77 >>> # Aliases for True & False so data and mask line up. 78 >>> T, F = (True, False) 79 80 >>> tf.ragged.boolean_mask( # Mask a 2D Tensor. 81 ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], 82 ... mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list() 83 [[1, 3], [], [7]] 84 85 >>> tf.ragged.boolean_mask( # Mask a 2D RaggedTensor. 86 ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), 87 ... tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list() 88 [[3], [], [5, 6]] 89 90 >>> tf.ragged.boolean_mask( # Mask rows of a 2D RaggedTensor. 91 ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), 92 ... tf.ragged.constant([True, False, True])).to_list() 93 [[1, 2, 3], [5, 6]] 94 """ 95 with ops.name_scope(name, 'RaggedMask', [data, mask]): 96 # Convert inputs to tensors. 97 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') 98 mask = ragged_tensor.convert_to_tensor_or_ragged_tensor( 99 mask, dtypes.bool, name='mask') 100 row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes( 101 data, mask, return_dtype=True) 102 103 # Get static rank of mask. 104 if mask.shape.ndims is None: 105 raise ValueError('mask.shape.ndims must be known statically.') 106 elif mask.shape.ndims == 0: 107 raise ValueError('mask cannot be scalar.') 108 109 # If mask is ragged, then recurse with a non-ragged mask. 110 if ragged_tensor.is_ragged(mask): 111 if not ragged_tensor.is_ragged(data): 112 data = ragged_tensor.RaggedTensor.from_tensor( 113 data, 114 ragged_rank=mask.ragged_rank, 115 row_splits_dtype=mask.row_splits.dtype) 116 # Check that mask.nested_row_splits is a prefix of 117 # data.nested_row_splits. 118 splits_list = [ 119 mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank] 120 ] 121 with ops.control_dependencies( 122 ragged_util.assert_splits_match(splits_list)): 123 # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits 124 # that we strip off in `splits`, so we can add them back on after 125 # we recursively mask the non-ragged data. 126 splits = [] 127 while ragged_tensor.is_ragged(mask): 128 if mask.shape.ndims > 2: 129 splits.append(mask.row_splits) 130 else: 131 # Count the number of True mask values in each row to find the 132 # lengths of the filtered rows; then convert to splits. 133 int_mask = ragged_functional_ops.map_flat_values( 134 math_ops.cast, mask, dtype=row_splits_dtype) 135 masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1) 136 splits.append(ragged_util.lengths_to_splits(masked_row_lengths)) 137 mask = mask.values 138 data = data.values 139 140 # Recursively apply the nested non-ragged mask to the nested data. 141 masked_values = boolean_mask(data, mask) 142 143 # Add the ragged `splits` back to the result. 144 masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits( 145 masked_values, splits, validate=False) 146 147 return masked_values 148 149 # If mask is non-ragged and has rank 1, and data is ragged, then build a 150 # ragged tensor with the indicated rows. 151 elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1: 152 # Get the masked splits: first get the length of each row, then filter 153 # out the rows that we are deleting, and convert that filtered set of 154 # masks back to a splits tensor. 155 lengths = data.row_lengths() 156 masked_lengths = array_ops.boolean_mask(lengths, mask) 157 masked_splits = ragged_util.lengths_to_splits(masked_lengths) 158 159 # Get the masked values: first get row ids corresponding to each 160 # value, then use tf.gather to build a boolean mask that's false for 161 # values that come from rows that we are deleting, and use that mask to 162 # construct the masked values tensor. 163 segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits) 164 segment_mask = array_ops.gather(mask, segment_ids) 165 masked_values = boolean_mask(data.values, segment_mask) 166 167 return ragged_tensor.RaggedTensor.from_row_splits( 168 masked_values, masked_splits, validate=False) 169 170 # If mask is non-ragged and has rank>1, then convert it to be ragged, 171 # with a ragged rank matching data. 172 if ragged_tensor.is_ragged(data): 173 mask = ragged_tensor.RaggedTensor.from_tensor( 174 mask, 175 ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1), 176 row_splits_dtype=data.row_splits.dtype) 177 return boolean_mask(data, mask) 178 179 # Otherwise, data and mask are both `Tensor`s. 180 else: 181 # Apply `boolean_mask` to get the masked values. 182 masked_values = array_ops.boolean_mask(data, mask) 183 184 if mask.shape.ndims >= 2: 185 # Add the innermost ragged dimension. For each innermost cell, get the 186 # number of values it contains. Then flatten that to get a list of 187 # cell lengths, and convert it to splits. Finally, combine the splits 188 # and values to get the innermost ragged tensor. 189 masked_lengths = math_ops.count_nonzero( 190 mask, axis=-1, dtype=row_splits_dtype) 191 flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1]) 192 masked_values = ragged_tensor.RaggedTensor.from_row_lengths( 193 masked_values, flattened_masked_lengths, validate=False) 194 195 # Wrap remaining ragged dimensions. 196 if mask.shape.ndims > 2: 197 mask_shape = array_ops.shape(mask, out_type=row_splits_dtype) 198 split_size = math_ops.cumprod(mask_shape) + 1 199 for dim in range(mask.shape.ndims - 3, -1, -1): 200 elt_size = mask_shape[dim + 1] 201 masked_splits = math_ops.range(split_size[dim]) * elt_size 202 masked_values = ragged_tensor.RaggedTensor.from_row_splits( 203 masked_values, masked_splits, validate=False) 204 205 return masked_values 206 207 208#=============================================================================== 209# Tiling 210#=============================================================================== 211def tile(input, multiples, name=None): # pylint: disable=redefined-builtin 212 """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`. 213 214 The values of `input` are replicated `multiples[i]` times along the 215 `i`th dimension (for each dimension `i`). For every dimension `axis` in 216 `input`, the length of each output element in that dimension is the 217 length of corresponding input element multiplied by `multiples[axis]`. 218 219 Args: 220 input: A `RaggedTensor`. 221 multiples: A 1-D integer `Tensor`. Length must be the same as the number of 222 dimensions in `input`. 223 name: A name for the operation (optional). 224 225 Returns: 226 A `RaggedTensor` with the same type, rank, and ragged_rank as `input`. 227 228 #### Example: 229 230 >>> rt = tf.ragged.constant([[1, 2], [3]]) 231 >>> tf.tile(rt, [3, 2]).to_list() 232 [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]] 233 """ 234 with ops.name_scope(name, 'RaggedTile', [input, multiples]): 235 input = ragged_tensor.convert_to_tensor_or_ragged_tensor( 236 input, name='input') 237 if not ragged_tensor.is_ragged(input): 238 return array_ops.tile(input, multiples, name) 239 multiples = ragged_util.convert_to_int_tensor( 240 multiples, name='multiples', dtype=input.row_splits.dtype) 241 multiples.shape.assert_has_rank(1) 242 243 # If the constant value of `multiples` is available, then we can use it 244 # to skip tiling dimensions where `multiples=1`. 245 const_multiples = tensor_util.constant_value(multiples) 246 247 return ragged_tensor.RaggedTensor.from_nested_row_splits( 248 _tile_ragged_values(input, multiples, const_multiples), 249 _tile_ragged_splits(input, multiples, const_multiples), 250 validate=False) 251 252 253def _tile_ragged_values(rt_input, multiples, const_multiples=None): 254 """Builds flat_values tensor for a tiled `RaggedTensor`. 255 256 Returns a tensor that repeats the values in 257 `rt_input.flat_values` in the 258 appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as 259 specified by `multiples`. 260 261 Args: 262 rt_input: The `RaggedTensor` whose values should be repeated. 263 multiples: A 1-D integer `tensor`, indicating how many times each dimension 264 should be repeated. 265 const_multiples: Optional constant value for multiples. Used to skip tiling 266 dimensions where `multiples=1`. 267 268 Returns: 269 A `Tensor` with the same type and rank as `rt_input.flat_values`. 270 271 #### Example: 272 273 >>> rt = tf.ragged.constant([[1, 2], [3]]) 274 >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy() 275 array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32) 276 """ 277 ragged_rank = rt_input.ragged_rank 278 nested_splits = rt_input.nested_row_splits 279 280 # Pointers to the values in `rt_input.flat_values`. 281 inner_value_ids = math_ops.range(nested_splits[-1][-1]) 282 283 # For each ragged dimension (working from the innermost to outermost), 284 # expand `inner_value_ids` as necessary to tile that dimension. 285 prev_splits = None 286 for axis in range(ragged_rank, 0, -1): 287 # Ragged splits for this dimension. 288 splits = nested_splits[axis - 1] 289 290 # Adjust splits so they point into `inner_value_ids` (instead of just 291 # pointing into the next dimension's values). 292 if prev_splits is not None: # Not the first pass through the loop. 293 splits = array_ops.gather(prev_splits * multiples[axis + 1], splits) 294 295 # Repeat each element in this ragged dimension `multiples[axis]` times. 296 if const_multiples is None or const_multiples[axis] != 1: 297 inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits, 298 multiples[axis]) 299 300 prev_splits = splits 301 302 # Gather the tiled inner values. 303 ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids) 304 305 # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus 306 # `axis=range(ragged_rank, rank)`). 307 inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]], 308 axis=0) 309 return array_ops.tile(ragged_tiled_values, inner_repeats) 310 311 312def _tile_ragged_splits(rt_input, multiples, const_multiples=None): 313 """Builds nested_split tensors for a tiled `RaggedTensor`. 314 315 Returns a list of split tensors that can be used to construct the 316 `RaggedTensor` that tiles `rt_input` as specified by `multiples`. 317 318 Args: 319 rt_input: The `RaggedTensor` that is being tiled. 320 multiples: A 1-D integer `tensor`, indicating how many times each dimension 321 should be repeated. 322 const_multiples: Optional constant value for multiples. Used to skip tiling 323 dimensions where `multiples=1`. 324 325 Returns: 326 A list of 1-D integer `Tensor`s (one for each ragged dimension in 327 `rt_input`). 328 329 #### Example: 330 331 >>> rt = tf.ragged.constant([[1, 2], [3]]) 332 >>> _tile_ragged_splits(rt, [3, 2]) 333 [<tf.Tensor: shape=(7,), dtype=int64, 334 numpy=array([ 0, 4, 6, 10, 12, 16, 18])>] 335 """ 336 ragged_rank = rt_input.ragged_rank 337 nested_splits = rt_input.nested_row_splits 338 339 # projected_splits[src_axis, dst_axis] contains the split points that divide 340 # the rows from src_axis in the list of dst_axis values. E.g., 341 # projected_splits[i, i] = nested_splits[i], and 342 # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]). 343 projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)] 344 for src_axis in range(ragged_rank): 345 for dst_axis in range(src_axis + 1, ragged_rank - 1): 346 projected_splits[src_axis][dst_axis] = array_ops.gather( 347 nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1]) 348 349 # For each ragged dimension: nested_splits[axis] -> result_splits[axis]. 350 result_splits = [] 351 for axis in range(ragged_rank): 352 # Get the length of each row for the input tensor for this dimension. 353 input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1] 354 355 # Multiply those lengths by the `multiples` of dimension axis+1, since 356 # each value will be repeated that number of times. 357 output_lengths = input_lengths * multiples[axis + 1] 358 359 # Repeat ranges of the row lengths as necessary for them to be tiled in 360 # each ragged dimension `d < axis`. (Start with dimension d=axis-1, and 361 # work our way up to dimension d=0.) 362 repeats = 1 363 for d in range(axis - 1, -1, -1): 364 if const_multiples is None or const_multiples[d + 1] != 1: 365 splits = projected_splits[d][axis - 1] * repeats 366 output_lengths = ragged_util.repeat_ranges(output_lengths, splits, 367 multiples[d + 1]) 368 repeats *= multiples[d + 1] 369 370 # Tile splits for the outermost (uniform) dimension. 371 output_lengths = array_ops.tile(output_lengths, multiples[:1]) 372 373 # Convert to splits. 374 result_splits.append(ragged_util.lengths_to_splits(output_lengths)) 375 376 return result_splits 377 378 379#=============================================================================== 380# Reshaping 381#=============================================================================== 382 383 384def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin 385 """Inserts a dimension with shape 1 into a potentially ragged tensor's shape. 386 387 Given a potentially ragged tenor `input`, this operation inserts a 388 dimension with size 1 at the dimension `axis` of `input`'s shape. 389 390 The following table gives some examples showing how `ragged.expand_dims` 391 impacts the shapes of different input tensors. Ragged dimensions are 392 indicated by enclosing them in parentheses. 393 394 input.shape | axis | result.shape 395 ----------------------- | ---- | ----------------------------- 396 `[D1, D2]` | `0` | `[1, D1, D2]` 397 `[D1, D2]` | `1` | `[D1, 1, D2]` 398 `[D1, D2]` | `2` | `[D1, D2, 1]` 399 `[D1, (D2), (D3), D4]` | `0` | `[1, D1, (D2), (D3), D4]` 400 `[D1, (D2), (D3), D4]` | `1` | `[D1, 1, (D2), (D3), D4]` 401 `[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), 1, (D3), D4]` 402 `[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]` 403 `[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]` 404 405 Args: 406 input: The potentially tensor that should be expanded with a new dimension. 407 axis: An integer constant indicating where the new dimension should be 408 inserted. 409 name: A name for the operation (optional). 410 411 Returns: 412 A tensor with the same values as `input`, with an added dimension of 413 size 1 at `axis`. 414 415 #### Examples: 416 417 >>> rt = tf.ragged.constant([[1, 2], [3]]) 418 >>> print(rt.shape) 419 (2, None) 420 421 >>> expanded = tf.expand_dims(rt, axis=0) 422 >>> print(expanded.shape, expanded) 423 (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]> 424 425 >>> expanded = tf.expand_dims(rt, axis=1) 426 >>> print(expanded.shape, expanded) 427 (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]> 428 429 >>> expanded = tf.expand_dims(rt, axis=2) 430 >>> print(expanded.shape, expanded) 431 (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]> 432 """ 433 with ops.name_scope(name, 'RaggedExpandDims', [input]): 434 input = ragged_tensor.convert_to_tensor_or_ragged_tensor( 435 input, name='input') 436 437 if not ragged_tensor.is_ragged(input): 438 return array_ops.expand_dims(input, axis) 439 440 ndims = None if input.shape.ndims is None else input.shape.ndims + 1 441 axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)') 442 443 if axis == 0: 444 return ragged_tensor.RaggedTensor.from_uniform_row_length( 445 input, uniform_row_length=input.nrows(), nrows=1, validate=False) 446 elif axis == 1: 447 return ragged_tensor.RaggedTensor.from_uniform_row_length( 448 input, uniform_row_length=1, nrows=input.nrows(), validate=False) 449 else: 450 if ragged_tensor.is_ragged(input.values): 451 return input.with_values(expand_dims(input.values, axis - 1)) 452 else: 453 return input.with_values(array_ops.expand_dims(input.values, axis - 1)) 454 455 456#=============================================================================== 457# RaggedTensor Size 458#=============================================================================== 459 460 461def size(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin 462 """Returns the size of a potentially ragged tensor. 463 464 The size of a ragged tensor is the size of its inner values. 465 466 #### Example: 467 468 >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy() 469 3 470 471 Args: 472 input: A potentially ragged `Tensor`. 473 out_type: The numeric output type for the operation. 474 name: A name for the operation (optional). 475 476 Returns: 477 A Tensor of type `out_type`. 478 """ 479 if ragged_tensor.is_ragged(input): 480 return array_ops.size(input.flat_values, out_type=out_type, name=name) 481 else: 482 return array_ops.size(input, out_type=out_type, name=name) 483 484 485#=============================================================================== 486# ragged.rank 487#=============================================================================== 488def rank(input, name=None): # pylint: disable=redefined-builtin 489 """Returns the rank of a RaggedTensor. 490 491 Returns a 0-D `int32` `Tensor` representing the rank of `input`. 492 493 #### Example: 494 495 >>> # shape of tensor 't' is [2, None, None] 496 >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]]) 497 >>> tf.rank(t).numpy() 498 3 499 500 Args: 501 input: A `RaggedTensor` 502 name: A name for the operation (optional). 503 504 Returns: 505 A `Tensor` of type `int32`. 506 """ 507 with ops.name_scope(name, 'RaggedRank', [input]) as name: 508 if not ragged_tensor.is_ragged(input): 509 return array_ops.rank(input, name) 510 511 return input.ragged_rank + array_ops.rank(input.flat_values) 512 513 514#=============================================================================== 515# ragged.one_hot 516#=============================================================================== 517def ragged_one_hot(indices, 518 depth, 519 on_value=None, 520 off_value=None, 521 axis=None, 522 dtype=None, 523 name=None): 524 """Applies tf.one_hot along the values of a RaggedTensor.""" 525 # Get the adjusted axis value for the call to array_ops.one_hot. 526 # Note: the only negative `axis` value supported by array_ops.one_hot is -1. 527 if isinstance(axis, int) and axis >= 0: 528 if axis <= indices.ragged_rank: 529 raise ValueError('axis (%d) must be greater than indices.ragged_rank ' 530 '(%d).' % (axis, indices.ragged_rank)) 531 axis -= indices.ragged_rank 532 533 with ops.name_scope(name, 'RaggedOneHot', 534 [indices, depth, on_value, off_value, axis]): 535 indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( 536 indices, name='indices') 537 return indices.with_flat_values( 538 array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis, 539 dtype, name)) 540 541 542#=============================================================================== 543# ragged.stack_dynamic_partitions 544#=============================================================================== 545@tf_export('ragged.stack_dynamic_partitions') 546@dispatch.add_dispatch_support 547def stack_dynamic_partitions(data, partitions, num_partitions, name=None): 548 """Stacks dynamic partitions of a Tensor or RaggedTensor. 549 550 Returns a RaggedTensor `output` with `num_partitions` rows, where the row 551 `output[i]` is formed by stacking all slices `data[j1...jN]` such that 552 `partitions[j1...jN] = i`. Slices of `data` are stacked in row-major 553 order. 554 555 If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to 556 `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`. 557 558 #### Example: 559 560 >>> data = ['a', 'b', 'c', 'd', 'e'] 561 >>> partitions = [ 3, 0, 2, 2, 3] 562 >>> num_partitions = 5 563 >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions) 564 <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]> 565 566 Args: 567 data: A `Tensor` or `RaggedTensor` containing the values to stack. 568 partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the 569 partition that each slice of `data` should be added to. `partitions.shape` 570 must be a prefix of `data.shape`. Values must be greater than or equal to 571 zero, and less than `num_partitions`. `partitions` is not required to be 572 sorted. 573 num_partitions: An `int32` or `int64` scalar specifying the number of 574 partitions to output. This determines the number of rows in `output`. 575 name: A name prefix for the returned tensor (optional). 576 577 Returns: 578 A `RaggedTensor` containing the stacked partitions. The returned tensor 579 has the same dtype as `data`, and its shape is 580 `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a 581 ragged dimension whose length is the number of data slices stacked for 582 each `partition`. 583 """ 584 with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]): 585 # Convert inputs to tensors. 586 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') 587 row_splits_dtype = ( 588 data.row_splits.dtype 589 if isinstance(data, ragged_tensor.RaggedTensor) else None) 590 partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor( 591 partitions, name='partitions', preferred_dtype=row_splits_dtype) 592 num_partitions = ops.convert_to_tensor( 593 num_partitions, name='num_partitions', preferred_dtype=partitions.dtype) 594 if row_splits_dtype is not None: 595 partitions = math_ops.cast(partitions, row_splits_dtype) 596 num_partitions = math_ops.cast(num_partitions, partitions.dtype) 597 598 # Sanity-checks for shapes. 599 partitions_rank = partitions.shape.ndims 600 if partitions_rank is None: 601 raise ValueError('partitions must have known rank.') 602 num_partitions.shape.assert_has_rank(0) 603 partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank]) 604 605 if partitions_rank == 0: 606 # If partitions is a scalar, then just create a RaggedTensor containing 607 # that single the complete `data` value in the specified row. 608 return ragged_tensor.RaggedTensor.from_value_rowids( 609 values=array_ops.stack([data]), 610 value_rowids=array_ops.stack([partitions]), 611 nrows=num_partitions, 612 validate=False) 613 614 elif partitions_rank == 1: 615 # If partitions is a vector (the typical case): we can just use data and 616 # partitions as the `values` and `value_rowids` for `from_value_rowids`, 617 # as long as we sort them first. 618 permutation = sort_ops.argsort(partitions, stable=True) 619 value_rowids = array_ops.gather(partitions, permutation) 620 values = array_ops.gather(data, permutation) 621 check = check_ops.assert_less( 622 value_rowids[-1:], 623 num_partitions, 624 message='partitions must be less than num_partitions') 625 with ops.control_dependencies([check]): 626 return ragged_tensor.RaggedTensor.from_value_rowids( 627 values, value_rowids, nrows=num_partitions, validate=False) 628 629 else: 630 # Handle higher-dimensional partitions via recursion. 631 if not isinstance(data, ragged_tensor.RaggedTensor): 632 data = ragged_tensor.RaggedTensor.from_tensor( 633 data, row_splits_dtype=partitions.dtype, ragged_rank=1) 634 if not isinstance(partitions, ragged_tensor.RaggedTensor): 635 partitions = ragged_tensor.RaggedTensor.from_tensor( 636 partitions, 637 row_splits_dtype=partitions.dtype, 638 ragged_rank=max(data.ragged_rank, partitions_rank - 1)) 639 check = check_ops.assert_equal( 640 data.row_splits, 641 partitions.row_splits, 642 message='data and partitions have incompatible ragged shapes') 643 with ops.control_dependencies([check]): 644 return stack_dynamic_partitions(data.values, partitions.values, 645 num_partitions) 646 647 648#=============================================================================== 649# Reverse 650#=============================================================================== 651def reverse(tensor, axis, name=None): 652 """Reverses a RaggedTensor along the specified axes. 653 654 #### Example: 655 656 >>> data = tf.ragged.constant([ 657 ... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]]) 658 >>> tf.reverse(data, axis=[0, 2]) 659 <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]> 660 661 Args: 662 tensor: A 'RaggedTensor' to reverse. 663 axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of 664 the axes to reverse. 665 name: A name prefix for the returned tensor (optional). 666 667 Returns: 668 A 'RaggedTensor'. 669 """ 670 type_error_msg = ('`axis` must be a list of int or a constant tensor' 671 'when reversing axes in a ragged tensor') 672 673 with ops.name_scope(name, 'Reverse', [tensor, axis]): 674 if isinstance(axis, ops.Tensor): 675 axis = tensor_util.constant_value(axis) 676 if axis is None: 677 raise TypeError(type_error_msg) 678 elif not (isinstance(axis, (list, tuple)) and 679 all(isinstance(dim, int) for dim in axis)): 680 raise TypeError(type_error_msg) 681 682 tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor( 683 tensor, name='tensor') 684 685 # Allow usage of negative values to specify innermost axes. 686 axis = [ 687 array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i, 688 'rank(tensor)') 689 for i, dim in enumerate(axis) 690 ] 691 692 # We only need to slice up to the max axis. If the axis list 693 # is empty, it should be 0. 694 slices = [slice(None)] * (max(axis) + 1 if axis else 0) 695 696 for dim in axis: 697 slices[dim] = slice(None, None, -1) 698 699 return tensor[tuple(slices)] 700 701 702#=============================================================================== 703# Cross 704#=============================================================================== 705 706 707@tf_export('ragged.cross') 708@dispatch.add_dispatch_support 709def cross(inputs, name=None): 710 """Generates feature cross from a list of tensors. 711 712 The input tensors must have `rank=2`, and must all have the same number of 713 rows. The result is a `RaggedTensor` with the same number of rows as the 714 inputs, where `result[row]` contains a list of all combinations of values 715 formed by taking a single value from each input's corresponding row 716 (`inputs[i][row]`). Values are combined by joining their strings with '_X_'. 717 E.g.: 718 719 >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]), 720 ... tf.ragged.constant([['d'], ['e']]), 721 ... tf.ragged.constant([['f'], ['g']])]) 722 <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]> 723 724 Args: 725 inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`. 726 name: Optional name for the op. 727 728 Returns: 729 A 2D `RaggedTensor` of type `string`. 730 """ 731 return _cross_internal(inputs=inputs, hashed_output=False, name=name) 732 733 734@tf_export('ragged.cross_hashed') 735@dispatch.add_dispatch_support 736def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None): 737 """Generates hashed feature cross from a list of tensors. 738 739 The input tensors must have `rank=2`, and must all have the same number of 740 rows. The result is a `RaggedTensor` with the same number of rows as the 741 inputs, where `result[row]` contains a list of all combinations of values 742 formed by taking a single value from each input's corresponding row 743 (`inputs[i][row]`). Values are combined by hashing together their 744 fingerprints. E.g.: 745 746 >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]), 747 ... tf.ragged.constant([['d'], ['e']]), 748 ... tf.ragged.constant([['f'], ['g']])], 749 ... num_buckets=100) 750 <tf.RaggedTensor [[78], [66, 74]]> 751 752 Args: 753 inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`. 754 num_buckets: A non-negative `int` that used to bucket the hashed values. If 755 `num_buckets != 0`, then `output = hashed_value % num_buckets`. 756 hash_key: Integer hash_key that will be used by the `FingerprintCat64` 757 function. If not given, a default key is used. 758 name: Optional name for the op. 759 760 Returns: 761 A 2D `RaggedTensor` of type `int64`. 762 """ 763 return _cross_internal( 764 inputs=inputs, 765 hashed_output=True, 766 num_buckets=num_buckets, 767 hash_key=hash_key, 768 name=name) 769 770 771_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE 772 773 774def _cross_internal(inputs, 775 hashed_output=False, 776 num_buckets=0, 777 hash_key=None, 778 name=None): 779 """Generates feature cross from a list of ragged and dense tensors.""" 780 if not isinstance(inputs, (tuple, list)): 781 raise TypeError('Inputs must be a list') 782 783 if hash_key is None: 784 hash_key = _DEFAULT_CROSS_HASH_KEY 785 786 ragged_inputs = [] 787 sparse_inputs = [] 788 dense_inputs = [] 789 input_order = [] 790 with ops.name_scope(name, 'RaggedCross', inputs): 791 for i, t in enumerate(inputs): 792 if sparse_tensor.is_sparse(t): 793 t = sparse_tensor.SparseTensor.from_value(t) 794 else: 795 t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t) 796 if t.dtype.is_integer: 797 t = math_ops.cast(t, dtypes.int64) 798 elif t.dtype != dtypes.string: 799 raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype)) 800 if isinstance(t, ragged_tensor.RaggedTensor): 801 if t.ragged_rank != 1: 802 raise ValueError('tf.ragged.cross only supports inputs with rank=2') 803 ragged_inputs.append(t) 804 input_order.append('R') 805 elif isinstance(t, sparse_tensor.SparseTensor): 806 sparse_inputs.append(t) 807 input_order.append('S') 808 else: 809 dense_inputs.append(t) 810 input_order.append('D') 811 812 out_values_type = dtypes.int64 if hashed_output else dtypes.string 813 if ragged_inputs and all( 814 t.row_splits.dtype == dtypes.int32 for t in ragged_inputs): 815 out_row_splits_type = dtypes.int32 816 else: 817 out_row_splits_type = dtypes.int64 818 819 # Convert hash_key from uint64 -> int64, since we need to pass it via 820 # an int64 attr. 821 if hash_key > 2**63: 822 hash_key -= 2**64 823 824 values_out, splits_out = gen_ragged_array_ops.ragged_cross( 825 ragged_values=[rt.values for rt in ragged_inputs], 826 ragged_row_splits=[rt.row_splits for rt in ragged_inputs], 827 sparse_indices=[st.indices for st in sparse_inputs], 828 sparse_values=[st.values for st in sparse_inputs], 829 sparse_shape=[st.dense_shape for st in sparse_inputs], 830 dense_inputs=dense_inputs, 831 input_order=''.join(input_order), 832 hashed_output=hashed_output, 833 num_buckets=num_buckets, 834 hash_key=hash_key, 835 out_values_type=out_values_type.as_datatype_enum, 836 out_row_splits_type=out_row_splits_type.as_datatype_enum, 837 name=name) 838 839 return ragged_tensor.RaggedTensor.from_row_splits( 840 values_out, splits_out, validate=False) 841