1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for storing ragged tensors and their values.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import operator 23 24import numpy as np 25 26from tensorflow.python import tf2 27from tensorflow.python.client import session 28from tensorflow.python.framework import composite_tensor 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_spec 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.framework import type_spec 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import check_ops 39from tensorflow.python.ops import control_flow_ops 40from tensorflow.python.ops import gen_ragged_conversion_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops.ragged import ragged_config 43from tensorflow.python.ops.ragged import ragged_tensor_value 44from tensorflow.python.ops.ragged import ragged_util 45from tensorflow.python.ops.ragged.row_partition import 
RowPartition 46from tensorflow.python.types import internal as internal_types 47from tensorflow.python.util import dispatch 48from tensorflow.python.util.tf_export import tf_export 49from tensorflow.tools.docs import doc_controls 50 51# pylint: disable=protected-access 52_convert_row_partition = RowPartition._convert_row_partition 53# pylint: enable=protected-access 54 55#=============================================================================== 56# RaggedTensor 57#=============================================================================== 58 59 60@tf_export("RaggedTensor") 61class RaggedTensor(composite_tensor.CompositeTensor, 62 internal_types.NativeObject): 63 """Represents a ragged tensor. 64 65 A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are 66 dimensions whose slices may have different lengths. For example, the inner 67 (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, 68 since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. 69 Dimensions whose slices all have the same length are called *uniform 70 dimensions*. The outermost dimension of a `RaggedTensor` is always uniform, 71 since it consists of a single slice (and so there is no possibility for 72 differing slice lengths). 73 74 The total number of dimensions in a `RaggedTensor` is called its *rank*, 75 and the number of ragged dimensions in a `RaggedTensor` is called its 76 *ragged-rank*. A `RaggedTensor`'s ragged-rank is fixed at graph creation 77 time: it can't depend on the runtime values of `Tensor`s, and can't vary 78 dynamically for different session runs. 79 80 Note that the `__init__` constructor is private. 
Please use one of the 81 following methods to construct a `RaggedTensor`: 82 83 * `tf.RaggedTensor.from_row_lengths` 84 * `tf.RaggedTensor.from_value_rowids` 85 * `tf.RaggedTensor.from_row_splits` 86 * `tf.RaggedTensor.from_row_starts` 87 * `tf.RaggedTensor.from_row_limits` 88 * `tf.RaggedTensor.from_nested_row_splits` 89 * `tf.RaggedTensor.from_nested_row_lengths` 90 * `tf.RaggedTensor.from_nested_value_rowids` 91 92 ### Potentially Ragged Tensors 93 94 Many ops support both `Tensor`s and `RaggedTensor`s 95 (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a 96 full listing). The term "potentially ragged tensor" may be used to refer to a 97 tensor that might be either a `Tensor` or a `RaggedTensor`. The ragged-rank 98 of a `Tensor` is zero. 99 100 ### Documenting RaggedTensor Shapes 101 102 When documenting the shape of a RaggedTensor, ragged dimensions can be 103 indicated by enclosing them in parentheses. For example, the shape of 104 a 3-D `RaggedTensor` that stores the fixed-size word embedding for each 105 word in a sentence, for each sentence in a batch, could be written as 106 `[num_sentences, (num_words), embedding_size]`. The parentheses around 107 `(num_words)` indicate that dimension is ragged, and that the length 108 of each element list in that dimension may vary for each item. 109 110 ### Component Tensors 111 112 Internally, a `RaggedTensor` consists of a concatenated list of values that 113 are partitioned into variable-length rows. In particular, each `RaggedTensor` 114 consists of: 115 116 * A `values` tensor, which concatenates the variable-length rows into a 117 flattened list. For example, the `values` tensor for 118 `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`. 119 120 * A `row_splits` vector, which indicates how those flattened values are 121 divided into rows. In particular, the values for row `rt[i]` are stored 122 in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 
123 124 Example: 125 126 >>> print(tf.RaggedTensor.from_row_splits( 127 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 128 ... row_splits=[0, 4, 4, 7, 8, 8])) 129 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 130 131 ### Alternative Row-Partitioning Schemes 132 133 In addition to `row_splits`, ragged tensors provide support for five other 134 row-partitioning schemes: 135 136 * `row_lengths`: a vector with shape `[nrows]`, which specifies the length 137 of each row. 138 139 * `value_rowids` and `nrows`: `value_rowids` is a vector with shape 140 `[nvals]`, corresponding one-to-one with `values`, which specifies 141 each value's row index. In particular, the row `rt[row]` consists of the 142 values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an 143 integer scalar that specifies the number of rows in the 144 `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.) 145 146 * `row_starts`: a vector with shape `[nrows]`, which specifies the start 147 offset of each row. Equivalent to `row_splits[:-1]`. 148 149 * `row_limits`: a vector with shape `[nrows]`, which specifies the stop 150 offset of each row. Equivalent to `row_splits[1:]`. 151 152 * `uniform_row_length`: A scalar tensor, specifying the length of every 153 row. This row-partitioning scheme may only be used if all rows have 154 the same length. 155 156 Example: The following ragged tensors are equivalent, and all represent the 157 nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`. 158 159 >>> values = [3, 1, 4, 1, 5, 9, 2, 6] 160 >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8]) 161 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 162 >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0]) 163 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 164 >>> RaggedTensor.from_value_rowids( 165 ... 
values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) 166 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 167 >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8]) 168 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 169 >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8]) 170 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 171 >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2) 172 <tf.RaggedTensor [[3, 1], [4, 1], [5, 9], [2, 6]]> 173 174 ### Multiple Ragged Dimensions 175 176 `RaggedTensor`s with multiple ragged dimensions can be defined by using 177 a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor` 178 adds a single ragged dimension. 179 180 >>> inner_rt = RaggedTensor.from_row_splits( # =rt1 from above 181 ... values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) 182 >>> outer_rt = RaggedTensor.from_row_splits( 183 ... values=inner_rt, row_splits=[0, 3, 3, 5]) 184 >>> print(outer_rt.to_list()) 185 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 186 >>> print(outer_rt.ragged_rank) 187 2 188 189 The factory function `RaggedTensor.from_nested_row_splits` may be used to 190 construct a `RaggedTensor` with multiple ragged dimensions directly, by 191 providing a list of `row_splits` tensors: 192 193 >>> RaggedTensor.from_nested_row_splits( 194 ... flat_values=[3, 1, 4, 1, 5, 9, 2, 6], 195 ... nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list() 196 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 197 198 ### Uniform Inner Dimensions 199 200 `RaggedTensor`s with uniform inner dimensions can be defined 201 by using a multidimensional `Tensor` for `values`. 202 203 >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32), 204 ... 
row_splits=[0, 2, 5]) 205 >>> print(rt.to_list()) 206 [[[1, 1, 1], [1, 1, 1]], 207 [[1, 1, 1], [1, 1, 1], [1, 1, 1]]] 208 >>> print(rt.shape) 209 (2, None, 3) 210 211 ### Uniform Outer Dimensions 212 213 `RaggedTensor`s with uniform outer dimensions can be defined by using 214 one or more `RaggedTensor` with a `uniform_row_length` row-partitioning 215 tensor. For example, a `RaggedTensor` with shape `[2, 2, None]` can be 216 constructed with this method from a `RaggedTensor` values with shape 217 `[4, None]`: 218 219 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 220 >>> print(values.shape) 221 (4, None) 222 >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2) 223 >>> print(rt6) 224 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 225 >>> print(rt6.shape) 226 (2, 2, None) 227 228 Note that `rt6` only contains one ragged dimension (the innermost 229 dimension). In contrast, if `from_row_splits` is used to construct a similar 230 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: 231 232 >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) 233 >>> print(rt7.shape) 234 (2, None, None) 235 236 Uniform and ragged outer dimensions may be interleaved, meaning that a 237 tensor with any combination of ragged and uniform dimensions may be created. 
238 For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could 239 be constructed as follows: 240 241 ```python 242 t0 = tf.zeros([1000, 2]) # Shape: [1000, 2] 243 t1 = RaggedTensor.from_row_lengths(t0, [...]) # [160, None, 2] 244 t2 = RaggedTensor.from_uniform_row_length(t1, 8) # [20, 8, None, 2] 245 t3 = RaggedTensor.from_uniform_row_length(t2, 4) # [5, 4, 8, None, 2] 246 t4 = RaggedTensor.from_row_lengths(t3, [...]) # [3, None, 4, 8, None, 2] 247 ``` 248 249 """ 250 251 #============================================================================= 252 # Constructor (private) 253 #============================================================================= 254 @doc_controls.do_not_generate_docs 255 def __init__(self, values, row_partition, internal=False): 256 """Creates a `RaggedTensor` with a specified partitioning for `values`. 257 258 This constructor is private -- please use one of the following ops to 259 build `RaggedTensor`s: 260 261 * `tf.RaggedTensor.from_row_lengths` 262 * `tf.RaggedTensor.from_value_rowids` 263 * `tf.RaggedTensor.from_row_splits` 264 * `tf.RaggedTensor.from_row_starts` 265 * `tf.RaggedTensor.from_row_limits` 266 * `tf.RaggedTensor.from_nested_row_splits` 267 * `tf.RaggedTensor.from_nested_row_lengths` 268 * `tf.RaggedTensor.from_nested_value_rowids` 269 270 Args: 271 values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. 272 row_partition: A `RowPartition` object, representing the arrangement of 273 the lists at the top level. 274 internal: True if the constructor is being called by one of the factory 275 methods. If false, an exception will be raised. 276 277 Raises: 278 ValueError: If internal = False. Note that this method is intended only 279 for internal use. 280 TypeError: If values is not a `RaggedTensor` or `Tensor`, or 281 row_partition is not a `RowPartition`. 
282 """ 283 284 if not internal: 285 raise ValueError("RaggedTensor constructor is private; please use one " 286 "of the factory methods instead (e.g., " 287 "RaggedTensor.from_row_lengths())") 288 _assert_is_supported_ragged_values_type(values) 289 if not isinstance(row_partition, RowPartition): 290 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 291 f"Received {row_partition}.") 292 293 # Validate shapes. 294 values.shape.with_rank_at_least(1) 295 if isinstance(values, RaggedTensor): 296 # pylint: disable=protected-access 297 assert row_partition.dtype == values._row_partition.dtype 298 299 self._values = values 300 self._row_partition = row_partition 301 302 #============================================================================= 303 # Factory Methods 304 #============================================================================= 305 306 @classmethod 307 def _from_row_partition(cls, values, row_partition, validate=True): 308 """Creates a `RaggedTensor` with a row partition. 309 310 This is used as a way for RaggedTensors to share row partitions. 311 312 The outer dimension of values must be equal to `partition.nvals()`. 313 314 Args: 315 values: A potentially ragged tensor. 316 row_partition: a `RowPartition`: can be shared between tensors. 317 validate: If true, then use assertions to check that the arguments form a 318 valid `RaggedTensor`. 319 320 Returns: 321 A `RaggedTensor`. `result.rank = values.rank + 1`. 322 `result.ragged_rank = values.ragged_rank + 1`. 323 324 Raises: 325 ValueError: If partition.nvals() != _nrows(values) 326 """ 327 if not isinstance(row_partition, RowPartition): 328 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 329 f"Received {row_partition}.") 330 if not isinstance(validate, bool): 331 raise TypeError(f"Argument `validate` must have type bool. 
" 332 f"Received {validate}.") 333 values, row_partition = cls._convert_values_and_partition( 334 values, row_partition, "partition") 335 if row_partition.has_precomputed_value_rowids(): 336 value_rowids_shape = row_partition.value_rowids().shape 337 values.shape[:1].assert_is_compatible_with(value_rowids_shape) 338 if validate: 339 msg = "Arguments to _from_row_partition do not form a valid RaggedTensor" 340 nvals = _nrows(values, row_partition.dtype) 341 checks = [ 342 check_ops.assert_equal( 343 row_partition.nvals(out_type=row_partition.dtype), 344 nvals, 345 message=msg), 346 ] 347 if not isinstance(values, RaggedTensor): 348 checks.append(check_ops.assert_rank_at_least(values, 1)) 349 row_partition = row_partition.with_dependencies(checks) 350 return cls(values=values, internal=True, row_partition=row_partition) 351 352 @classmethod 353 @dispatch.add_dispatch_support 354 def from_value_rowids(cls, 355 values, 356 value_rowids, 357 nrows=None, 358 name=None, 359 validate=True): 360 """Creates a `RaggedTensor` with rows partitioned by `value_rowids`. 361 362 The returned `RaggedTensor` corresponds with the python list defined by: 363 364 ```python 365 result = [[values[i] for i in range(len(values)) if value_rowids[i] == row] 366 for row in range(nrows)] 367 ``` 368 369 Args: 370 values: A potentially ragged tensor with shape `[nvals, ...]`. 371 value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds 372 one-to-one with `values`, and specifies each value's row index. Must be 373 nonnegative, and must be sorted in ascending order. 374 nrows: An integer scalar specifying the number of rows. This should be 375 specified if the `RaggedTensor` may contain empty trailing rows. Must 376 be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). 377 Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty). 378 name: A name prefix for the RaggedTensor (optional). 
379 validate: If true, then use assertions to check that the arguments form 380 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 381 since they must be checked for each tensor value. 382 383 Returns: 384 A `RaggedTensor`. `result.rank = values.rank + 1`. 385 `result.ragged_rank = values.ragged_rank + 1`. 386 387 Raises: 388 ValueError: If `nrows` is incompatible with `value_rowids`. 389 390 #### Example: 391 392 >>> print(tf.RaggedTensor.from_value_rowids( 393 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 394 ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], 395 ... nrows=5)) 396 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 397 398 """ 399 if not isinstance(validate, bool): 400 raise TypeError(f"Argument `validate` must have type bool. " 401 f"Received {validate}.") 402 403 with ops.name_scope(name, "RaggedFromValueRowIds", 404 [values, value_rowids, nrows]): 405 row_partition = RowPartition.from_value_rowids( 406 value_rowids=value_rowids, 407 nrows=nrows, 408 validate=validate, 409 preferred_dtype=_get_optional_partition_dtype(values)) 410 return cls._from_row_partition(values, row_partition, validate=validate) 411 412 @classmethod 413 @dispatch.add_dispatch_support 414 def from_row_splits(cls, values, row_splits, name=None, validate=True): 415 """Creates a `RaggedTensor` with rows partitioned by `row_splits`. 416 417 The returned `RaggedTensor` corresponds with the python list defined by: 418 419 ```python 420 result = [values[row_splits[i]:row_splits[i + 1]] 421 for i in range(len(row_splits) - 1)] 422 ``` 423 424 Args: 425 values: A potentially ragged tensor with shape `[nvals, ...]`. 426 row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be 427 empty, and must be sorted in ascending order. `row_splits[0]` must be 428 zero and `row_splits[-1]` must be `nvals`. 429 name: A name prefix for the RaggedTensor (optional). 430 validate: If true, then use assertions to check that the arguments form 431 a valid `RaggedTensor`. 
Note: these assertions incur a runtime cost, 432 since they must be checked for each tensor value. 433 434 Returns: 435 A `RaggedTensor`. `result.rank = values.rank + 1`. 436 `result.ragged_rank = values.ragged_rank + 1`. 437 438 Raises: 439 ValueError: If `row_splits` is an empty list. 440 441 #### Example: 442 443 >>> print(tf.RaggedTensor.from_row_splits( 444 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 445 ... row_splits=[0, 4, 4, 7, 8, 8])) 446 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 447 448 """ 449 if not isinstance(validate, bool): 450 raise TypeError(f"Argument `validate` must have type bool. " 451 f"Received {validate}.") 452 453 with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): 454 row_partition = RowPartition.from_row_splits( 455 row_splits=row_splits, 456 validate=validate, 457 preferred_dtype=_get_optional_partition_dtype(values)) 458 return cls._from_row_partition(values, row_partition, validate=validate) 459 460 @classmethod 461 @dispatch.add_dispatch_support 462 def from_row_lengths(cls, values, row_lengths, name=None, validate=True): 463 """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. 464 465 The returned `RaggedTensor` corresponds with the python list defined by: 466 467 ```python 468 result = [[values.pop(0) for i in range(length)] 469 for length in row_lengths] 470 ``` 471 472 Args: 473 values: A potentially ragged tensor with shape `[nvals, ...]`. 474 row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be 475 nonnegative. `sum(row_lengths)` must be `nvals`. 476 name: A name prefix for the RaggedTensor (optional). 477 validate: If true, then use assertions to check that the arguments form 478 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 479 since they must be checked for each tensor value. 480 481 Returns: 482 A `RaggedTensor`. `result.rank = values.rank + 1`. 483 `result.ragged_rank = values.ragged_rank + 1`. 
484 485 #### Example: 486 487 >>> print(tf.RaggedTensor.from_row_lengths( 488 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 489 ... row_lengths=[4, 0, 3, 1, 0])) 490 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 491 492 """ 493 if not isinstance(validate, bool): 494 raise TypeError(f"Argument `validate` must have type bool. " 495 f"Received {validate}.") 496 497 with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): 498 row_partition = RowPartition.from_row_lengths( 499 row_lengths=row_lengths, 500 validate=validate, 501 preferred_dtype=_get_optional_partition_dtype(values)) 502 return cls._from_row_partition(values, row_partition, validate=validate) 503 504 @classmethod 505 @dispatch.add_dispatch_support 506 def from_row_starts(cls, values, row_starts, name=None, validate=True): 507 """Creates a `RaggedTensor` with rows partitioned by `row_starts`. 508 509 Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`. 510 511 Args: 512 values: A potentially ragged tensor with shape `[nvals, ...]`. 513 row_starts: A 1-D integer tensor with shape `[nrows]`. Must be 514 nonnegative and sorted in ascending order. If `nrows>0`, then 515 `row_starts[0]` must be zero. 516 name: A name prefix for the RaggedTensor (optional). 517 validate: If true, then use assertions to check that the arguments form 518 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 519 since they must be checked for each tensor value. 520 521 Returns: 522 A `RaggedTensor`. `result.rank = values.rank + 1`. 523 `result.ragged_rank = values.ragged_rank + 1`. 524 525 #### Example: 526 527 >>> print(tf.RaggedTensor.from_row_starts( 528 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 529 ... row_starts=[0, 4, 4, 7, 8])) 530 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 531 532 """ 533 if not isinstance(validate, bool): 534 raise TypeError(f"Argument `validate` must have type bool. 
" 535 f"Received {validate}.") 536 with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): 537 values = _convert_to_ragged_tensor_values(values) 538 row_partition = RowPartition.from_row_starts( 539 row_starts=row_starts, 540 nvals=_nrows(values), 541 validate=validate, 542 preferred_dtype=_get_optional_partition_dtype(values)) 543 return cls._from_row_partition(values, row_partition, validate=validate) 544 545 @classmethod 546 @dispatch.add_dispatch_support 547 def from_row_limits(cls, values, row_limits, name=None, validate=True): 548 """Creates a `RaggedTensor` with rows partitioned by `row_limits`. 549 550 Equivalent to: `from_row_splits(values, concat([0, row_limits]))`. 551 552 Args: 553 values: A potentially ragged tensor with shape `[nvals, ...]`. 554 row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in 555 ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`. 556 name: A name prefix for the RaggedTensor (optional). 557 validate: If true, then use assertions to check that the arguments form 558 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 559 since they must be checked for each tensor value. 560 561 Returns: 562 A `RaggedTensor`. `result.rank = values.rank + 1`. 563 `result.ragged_rank = values.ragged_rank + 1`. 564 565 #### Example: 566 567 >>> print(tf.RaggedTensor.from_row_limits( 568 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 569 ... row_limits=[4, 4, 7, 8, 8])) 570 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 571 572 """ 573 if not isinstance(validate, bool): 574 raise TypeError(f"Argument `validate` must have type bool. 
" 575 f"Received {validate}.") 576 with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): 577 row_partition = RowPartition.from_row_limits( 578 row_limits=row_limits, 579 validate=validate, 580 preferred_dtype=_get_optional_partition_dtype(values)) 581 return cls._from_row_partition(values, row_partition, validate=validate) 582 583 @classmethod 584 @dispatch.add_dispatch_support 585 def from_uniform_row_length(cls, 586 values, 587 uniform_row_length, 588 nrows=None, 589 validate=True, 590 name=None): 591 """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`. 592 593 This method can be used to create `RaggedTensor`s with multiple uniform 594 outer dimensions. For example, a `RaggedTensor` with shape `[2, 2, None]` 595 can be constructed with this method from a `RaggedTensor` values with shape 596 `[4, None]`: 597 598 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 599 >>> print(values.shape) 600 (4, None) 601 >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2) 602 >>> print(rt1) 603 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 604 >>> print(rt1.shape) 605 (2, 2, None) 606 607 Note that `rt1` only contains one ragged dimension (the innermost 608 dimension). In contrast, if `from_row_splits` is used to construct a similar 609 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: 610 611 >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) 612 >>> print(rt2.shape) 613 (2, None, None) 614 615 Args: 616 values: A potentially ragged tensor with shape `[nvals, ...]`. 617 uniform_row_length: A scalar integer tensor. Must be nonnegative. The 618 size of the outer axis of `values` must be evenly divisible by 619 `uniform_row_length`. 620 nrows: The number of rows in the constructed RaggedTensor. If not 621 specified, then it defaults to `nvals/uniform_row_length` (or `0` if 622 `uniform_row_length==0`). 
`nrows` only needs to be specified if 623 `uniform_row_length` might be zero. `uniform_row_length*nrows` must be 624 `nvals`. 625 validate: If true, then use assertions to check that the arguments form 626 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 627 since they must be checked for each tensor value. 628 name: A name prefix for the RaggedTensor (optional). 629 630 Returns: 631 A `RaggedTensor` that corresponds with the python list defined by: 632 633 ```python 634 result = [[values.pop(0) for i in range(uniform_row_length)] 635 for _ in range(nrows)] 636 ``` 637 638 `result.rank = values.rank + 1`. 639 `result.ragged_rank = values.ragged_rank + 1`. 640 """ 641 if not isinstance(validate, bool): 642 raise TypeError(f"Argument `validate` must have type bool. " 643 f"Received {validate}.") 644 with ops.name_scope(name, "RaggedFromUniformRowLength", 645 [values, uniform_row_length, nrows]): 646 values = _convert_to_ragged_tensor_values(values) 647 uniform_row_length = _convert_row_partition( 648 uniform_row_length, "UniformRowLength", 649 _get_optional_partition_dtype(values)) 650 nvals = _nvals_uniform_row_length(values, uniform_row_length) 651 row_partition = RowPartition.from_uniform_row_length( 652 uniform_row_length=uniform_row_length, 653 nvals=nvals, 654 nrows=nrows, 655 validate=validate, 656 preferred_dtype=_get_optional_partition_dtype(values)) 657 return cls._from_row_partition(values, row_partition, validate=validate) 658 659 @classmethod 660 @dispatch.add_dispatch_support 661 def from_nested_value_rowids(cls, 662 flat_values, 663 nested_value_rowids, 664 nested_nrows=None, 665 name=None, 666 validate=True): 667 """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors. 
668 669 Equivalent to: 670 671 ```python 672 result = flat_values 673 for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)): 674 result = from_value_rowids(result, rowids, nrows) 675 ``` 676 677 Args: 678 flat_values: A potentially ragged tensor. 679 nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is 680 used as the `value_rowids` for the `i`th ragged dimension. 681 nested_nrows: A list of integer scalars. The `i`th scalar is used as the 682 `nrows` for the `i`th ragged dimension. 683 name: A name prefix for the RaggedTensor (optional). 684 validate: If true, then use assertions to check that the arguments form 685 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 686 since they must be checked for each tensor value. 687 688 Returns: 689 A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty). 690 691 Raises: 692 ValueError: If `len(nested_value_rowids) != len(nested_nrows)`. 693 """ 694 if not isinstance(validate, bool): 695 raise TypeError(f"Argument `validate` must have type bool. " 696 f"Received {validate}.") 697 if isinstance(nested_value_rowids, ops.Tensor): 698 raise TypeError(f"Argument `nested_value_rowids` must be a list of " 699 f"Tensors. Received {nested_value_rowids}.") 700 if nested_nrows is None: 701 nested_nrows = [None] * len(nested_value_rowids) 702 else: 703 if isinstance(nested_nrows, ops.Tensor): 704 raise TypeError(f"Argument `nested_nrows` must be a list of " 705 f"Tensors. Received {nested_nrows}.") 706 if len(nested_nrows) != len(nested_value_rowids): 707 raise ValueError( 708 f"Argument `nested_nrows` must have the same length as " 709 f"argument `nested_value_rowids`. len(nested_nrows) = " 710 f"{len(nested_nrows)} vs. 
len(nested_values_rowids) = " 711 f"{len(nested_value_rowids)}.") 712 713 with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] + 714 list(nested_value_rowids) + list(nested_nrows)): 715 result = flat_values 716 for value_rowids, nrows in reversed( 717 list(zip(nested_value_rowids, nested_nrows))): 718 result = cls.from_value_rowids( 719 result, value_rowids, nrows, validate=validate) 720 return result 721 722 @classmethod 723 @dispatch.add_dispatch_support 724 def from_nested_row_splits(cls, 725 flat_values, 726 nested_row_splits, 727 name=None, 728 validate=True): 729 """Creates a `RaggedTensor` from a nested list of `row_splits` tensors. 730 731 Equivalent to: 732 733 ```python 734 result = flat_values 735 for row_splits in reversed(nested_row_splits): 736 result = from_row_splits(result, row_splits) 737 ``` 738 739 Args: 740 flat_values: A potentially ragged tensor. 741 nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is 742 used as the `row_splits` for the `i`th ragged dimension. 743 name: A name prefix for the RaggedTensor (optional). 744 validate: If true, then use assertions to check that the arguments form 745 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 746 since they must be checked for each tensor value. 747 748 Returns: 749 A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty). 750 """ 751 if not isinstance(validate, bool): 752 raise TypeError(f"Argument `validate` must have type bool. " 753 f"Received {validate}.") 754 if isinstance(nested_row_splits, ops.Tensor): 755 raise TypeError(f"Argument `nested_row_splits` must be a list of " 756 f"Tensors. 
Received {nested_row_splits}.") 757 with ops.name_scope(name, "RaggedFromNestedRowSplits", 758 [flat_values] + list(nested_row_splits)): 759 result = flat_values 760 for splits in reversed(nested_row_splits): 761 result = cls.from_row_splits(result, splits, validate=validate) 762 return result 763 764 @classmethod 765 @dispatch.add_dispatch_support 766 def from_nested_row_lengths(cls, 767 flat_values, 768 nested_row_lengths, 769 name=None, 770 validate=True): 771 """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors. 772 773 Equivalent to: 774 775 ```python 776 result = flat_values 777 for row_lengths in reversed(nested_row_lengths): 778 result = from_row_lengths(result, row_lengths) 779 ``` 780 781 Args: 782 flat_values: A potentially ragged tensor. 783 nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is 784 used as the `row_lengths` for the `i`th ragged dimension. 785 name: A name prefix for the RaggedTensor (optional). 786 validate: If true, then use assertions to check that the arguments form 787 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 788 since they must be checked for each tensor value. 789 790 Returns: 791 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). 792 """ 793 if not isinstance(validate, bool): 794 raise TypeError(f"Argument `validate` must have type bool. " 795 f"Received {validate}.") 796 if isinstance(nested_row_lengths, ops.Tensor): 797 raise TypeError(f"Argument `nested_row_lengths` must be a list of " 798 f"Tensors. 
Received {nested_row_lengths}.") 799 with ops.name_scope(name, "RaggedFromNestedRowlengths", 800 [flat_values] + list(nested_row_lengths)): 801 result = flat_values 802 for lengths in reversed(nested_row_lengths): 803 result = cls.from_row_lengths(result, lengths, validate=validate) 804 return result 805 806 @classmethod 807 def _from_nested_row_partitions(cls, 808 flat_values, 809 nested_row_partitions, 810 name=None, 811 validate=True): 812 """Creates a `RaggedTensor` from a nested list of row partitions. 813 814 Equivalent to: 815 816 ```python 817 result = flat_values 818 for row_partition in reversed(nested_row_partitions): 819 result = _from_row_partition(result, row_partition) 820 ``` 821 822 Args: 823 flat_values: A potentially ragged tensor. 824 nested_row_partitions: A list of row partitions. The `i`th element is 825 used as the row partition for the `i`th ragged dimension. 826 name: A name prefix for the RaggedTensor (optional). 827 validate: If true, then use assertions to check that the arguments form 828 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 829 since they must be checked for each tensor value. 830 831 Returns: 832 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). 833 """ 834 if not isinstance(validate, bool): 835 raise TypeError(f"Argument `validate` must have type bool. " 836 f"Received {validate}.") 837 if isinstance(nested_row_partitions, RowPartition): 838 raise TypeError(f"Argument `nested_row_partitions` must be a list of " 839 f"RowPartitions. Received {nested_row_partitions}.") 840 if isinstance(nested_row_partitions, ops.Tensor): 841 raise TypeError(f"Argument `nested_row_partitions` must be a list of " 842 f"RowPartitions. 
Received {nested_row_partitions}.") 843 with ops.name_scope(name, "RaggedFromNestedRowPartitions", 844 [flat_values] + list(nested_row_partitions)): 845 result = flat_values 846 for partition in reversed(nested_row_partitions): 847 result = cls._from_row_partition(result, partition, validate=validate) 848 return result 849 850 @classmethod 851 def _convert_values_and_partition(cls, values, row_partition, name): 852 """Converts `values` and `partition` to Tensors. 853 854 If `values` is a `RaggedTensor`, then converts `values` and `partition` 855 to have compatible row-partitioning dtypes. In particular, if any of the 856 row partitioning tensors are `int64`, then all of the other row 857 partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype() 858 is true) or an error will be raised (if auto_cast_partition_dtype() is 859 false). 860 861 Args: 862 values: The `values` for the `RaggedTensor` being constructed. 863 row_partition: A RowPartition object for the `RaggedTensor` being 864 constructed. 865 name: The name of the RowPartition object. 866 867 Returns: 868 A tuple (values, partition). 869 """ 870 if not isinstance(row_partition, RowPartition): 871 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 872 f"Received {row_partition}.") 873 if isinstance(values, RaggedTensor): 874 # pylint: disable=protected-access 875 if values._row_partition.dtype != row_partition.dtype: 876 if not ragged_config.auto_cast_partition_dtype(): 877 # pylint: disable=protected-access 878 # TODO(edloper): get rid of the `name` parameter. 879 raise ValueError( 880 f"Argument `row_partition` of RaggedTensor with name: {name} " 881 f"must have same dtype as Argument `values`. " 882 f"({row_partition.dtype} vs. 
{values._row_partition.dtype}).") 883 values = values.with_row_splits_dtype(row_partition.dtype) 884 else: 885 values = _convert_to_ragged_tensor_values(values) 886 887 return (values, row_partition) 888 889 #============================================================================= 890 # Accessors 891 #============================================================================= 892 893 @property 894 def dtype(self): 895 """The `DType` of values in this tensor.""" 896 return self._values.dtype 897 898 @property 899 def shape(self): 900 """The statically known shape of this ragged tensor. 901 902 Returns: 903 A `TensorShape` containing the statically known shape of this ragged 904 tensor. Ragged dimensions have a size of `None`. 905 906 Examples: 907 908 >>> tf.ragged.constant([[0], [1, 2]]).shape 909 TensorShape([2, None]) 910 911 >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape 912 TensorShape([2, None, 2]) 913 914 """ 915 nrows = self._row_partition.static_nrows 916 ncols = self._row_partition.static_uniform_row_length 917 value_shape = self._values.shape[1:] 918 return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape) 919 920 def get_shape(self): 921 """The statically known shape of this ragged tensor. 922 923 Returns: 924 A `TensorShape` containing the statically known shape of this ragged 925 tensor. Ragged dimensions have a size of `None`. 926 927 Alias for `shape` property. 928 929 Examples: 930 931 >>> tf.ragged.constant([[0], [1, 2]]).get_shape() 932 TensorShape([2, None]) 933 934 >>> tf.ragged.constant( 935 ... [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape() 936 TensorShape([2, None, 2]) 937 938 """ 939 return self.shape 940 941 @property 942 def ragged_rank(self): 943 """The number of times the RaggedTensor's flat_values is partitioned. 
944 945 Examples: 946 947 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 948 >>> values.ragged_rank 949 1 950 951 >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2) 952 >>> rt.ragged_rank 953 2 954 955 Returns: 956 A Python `int` indicating the number of times the underlying `flat_values` 957 Tensor has been partitioned to add a new dimension. 958 I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`. 959 """ 960 values_is_ragged = isinstance(self._values, RaggedTensor) 961 return self._values.ragged_rank + 1 if values_is_ragged else 1 962 963 @property 964 def values(self): 965 """The concatenated rows for this ragged tensor. 966 967 `rt.values` is a potentially ragged tensor formed by flattening the two 968 outermost dimensions of `rt` into a single dimension. 969 970 `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the 971 number of items in the outer two dimensions of `rt`). 972 973 `rt.ragged_rank = self.ragged_rank - 1` 974 975 Returns: 976 A potentially ragged tensor. 977 978 #### Example: 979 980 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 981 >>> print(rt.values) 982 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 983 984 """ 985 return self._values 986 987 @property 988 def _nested_row_partitions(self): 989 """Returns the row partitions for this `RaggedTensor`.""" 990 partitions = [self._row_partition] 991 rt_values = self.values 992 while isinstance(rt_values, RaggedTensor): 993 # pylint: disable=protected-access 994 partitions.append(rt_values._row_partition) 995 rt_values = rt_values.values 996 return tuple(partitions) 997 998 @property 999 def row_splits(self): 1000 """The row-split indices for this ragged tensor's `values`. 1001 1002 `rt.row_splits` specifies where the values for each row begin and end in 1003 `rt.values`. In particular, the values for row `rt[i]` are stored in 1004 the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 
1005 1006 Returns: 1007 A 1-D integer `Tensor` with shape `[self.nrows+1]`. 1008 The returned tensor is non-empty, and is sorted in ascending order. 1009 `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to 1010 `self.values.shape[0]`. 1011 1012 #### Example: 1013 1014 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1015 >>> print(rt.row_splits) # indices of row splits in rt.values 1016 tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64) 1017 1018 """ 1019 return self._row_partition.row_splits() 1020 1021 @property 1022 def uniform_row_length(self): 1023 """The length of each row in this ragged tensor, or None if rows are ragged. 1024 1025 >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 1026 >>> print(rt1.uniform_row_length) # rows are ragged. 1027 None 1028 1029 >>> rt2 = tf.RaggedTensor.from_uniform_row_length( 1030 ... values=rt1, uniform_row_length=2) 1031 >>> print(rt2) 1032 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 1033 >>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2). 1034 tf.Tensor(2, shape=(), dtype=int64) 1035 1036 A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged) 1037 if it can be determined statically (at graph construction time) that the 1038 rows all have the same length. 1039 1040 Returns: 1041 A scalar integer `Tensor`, specifying the length of every row in this 1042 ragged tensor (for ragged tensors whose rows are uniform); or `None` 1043 (for ragged tensors whose rows are ragged). 1044 """ 1045 return self._row_partition.uniform_row_length() 1046 1047 @property 1048 def flat_values(self): 1049 """The innermost `values` tensor for this ragged tensor. 1050 1051 Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is 1052 `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`. 
1053 1054 Conceptually, `flat_values` is the tensor formed by flattening the 1055 outermost dimension and all of the ragged dimensions into a single 1056 dimension. 1057 1058 `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]` 1059 (where `nvals` is the number of items in the flattened dimensions). 1060 1061 Returns: 1062 A `Tensor`. 1063 1064 #### Example: 1065 1066 >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) 1067 >>> print(rt.flat_values) 1068 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1069 1070 """ 1071 rt_values = self.values 1072 while isinstance(rt_values, RaggedTensor): 1073 rt_values = rt_values.values 1074 return rt_values 1075 1076 @property 1077 def nested_row_splits(self): 1078 """A tuple containing the row_splits for all ragged dimensions. 1079 1080 `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for 1081 all ragged dimensions in `rt`, ordered from outermost to innermost. In 1082 particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where: 1083 1084 * `value_splits = ()` if `rt.values` is a `Tensor`. 1085 * `value_splits = rt.values.nested_row_splits` otherwise. 1086 1087 Returns: 1088 A `tuple` of 1-D integer `Tensor`s. 1089 1090 #### Example: 1091 1092 >>> rt = tf.ragged.constant( 1093 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) 1094 >>> for i, splits in enumerate(rt.nested_row_splits): 1095 ... print('Splits for dimension %d: %s' % (i+1, splits.numpy())) 1096 Splits for dimension 1: [0 3] 1097 Splits for dimension 2: [0 3 3 5] 1098 Splits for dimension 3: [0 4 4 7 8 8] 1099 1100 """ 1101 rt_nested_splits = [self.row_splits] 1102 rt_values = self.values 1103 while isinstance(rt_values, RaggedTensor): 1104 rt_nested_splits.append(rt_values.row_splits) 1105 rt_values = rt_values.values 1106 return tuple(rt_nested_splits) 1107 1108 def value_rowids(self, name=None): 1109 """Returns the row indices for the `values` in this ragged tensor. 
1110 1111 `rt.value_rowids()` corresponds one-to-one with the outermost dimension of 1112 `rt.values`, and specifies the row containing each value. In particular, 1113 the row `rt[row]` consists of the values `rt.values[j]` where 1114 `rt.value_rowids()[j] == row`. 1115 1116 Args: 1117 name: A name prefix for the returned tensor (optional). 1118 1119 Returns: 1120 A 1-D integer `Tensor` with shape `self.values.shape[:1]`. 1121 The returned tensor is nonnegative, and is sorted in ascending order. 1122 1123 #### Example: 1124 1125 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1126 >>> print(rt.values) 1127 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1128 >>> print(rt.value_rowids()) # corresponds 1:1 with rt.values 1129 tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64) 1130 1131 """ 1132 with ops.name_scope(name, "RaggedValueRowIds", [self]): 1133 return self._row_partition.value_rowids() 1134 1135 def nested_value_rowids(self, name=None): 1136 """Returns a tuple containing the value_rowids for all ragged dimensions. 1137 1138 `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors 1139 for 1140 all ragged dimensions in `rt`, ordered from outermost to innermost. In 1141 particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids` 1142 where: 1143 1144 * `value_ids = ()` if `rt.values` is a `Tensor`. 1145 * `value_ids = rt.values.nested_value_rowids` otherwise. 1146 1147 Args: 1148 name: A name prefix for the returned tensors (optional). 1149 1150 Returns: 1151 A `tuple` of 1-D integer `Tensor`s. 1152 1153 #### Example: 1154 1155 >>> rt = tf.ragged.constant( 1156 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) 1157 >>> for i, ids in enumerate(rt.nested_value_rowids()): 1158 ... 
print('row ids for dimension %d: %s' % (i+1, ids.numpy())) 1159 row ids for dimension 1: [0 0 0] 1160 row ids for dimension 2: [0 0 0 2 2] 1161 row ids for dimension 3: [0 0 0 0 2 2 2 3] 1162 1163 """ 1164 with ops.name_scope(name, "RaggedNestedValueRowIds", [self]): 1165 rt_nested_ids = [self.value_rowids()] 1166 rt_values = self.values 1167 while isinstance(rt_values, RaggedTensor): 1168 rt_nested_ids.append(rt_values.value_rowids()) 1169 rt_values = rt_values.values 1170 return tuple(rt_nested_ids) 1171 1172 def nrows(self, out_type=None, name=None): 1173 """Returns the number of rows in this ragged tensor. 1174 1175 I.e., the size of the outermost dimension of the tensor. 1176 1177 Args: 1178 out_type: `dtype` for the returned tensor. Defaults to 1179 `self.row_splits.dtype`. 1180 name: A name prefix for the returned tensor (optional). 1181 1182 Returns: 1183 A scalar `Tensor` with dtype `out_type`. 1184 1185 #### Example: 1186 1187 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1188 >>> print(rt.nrows()) # rt has 5 rows. 1189 tf.Tensor(5, shape=(), dtype=int64) 1190 1191 """ 1192 with ops.name_scope(name, "RaggedNRows", [self]): 1193 return self._row_partition.nrows(out_type=out_type) 1194 1195 def row_starts(self, name=None): 1196 """Returns the start indices for rows in this ragged tensor. 1197 1198 These indices specify where the values for each row begin in 1199 `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`. 1200 1201 Args: 1202 name: A name prefix for the returned tensor (optional). 1203 1204 Returns: 1205 A 1-D integer Tensor with shape `[nrows]`. 1206 The returned tensor is nonnegative, and is sorted in ascending order. 
1207 1208 #### Example: 1209 1210 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1211 >>> print(rt.values) 1212 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1213 >>> print(rt.row_starts()) # indices of row starts in rt.values 1214 tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64) 1215 1216 """ 1217 with ops.name_scope(name, "RaggedRowStarts", [self]): 1218 return self._row_partition.row_starts() 1219 1220 def row_limits(self, name=None): 1221 """Returns the limit indices for rows in this ragged tensor. 1222 1223 These indices specify where the values for each row end in 1224 `self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[1:]`. 1225 1226 Args: 1227 name: A name prefix for the returned tensor (optional). 1228 1229 Returns: 1230 A 1-D integer Tensor with shape `[nrows]`. 1231 The returned tensor is nonnegative, and is sorted in ascending order. 1232 1233 #### Example: 1234 1235 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1236 >>> print(rt.values) 1237 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1238 >>> print(rt.row_limits()) # indices of row limits in rt.values 1239 tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64) 1240 1241 """ 1242 with ops.name_scope(name, "RaggedRowLimits", [self]): 1243 return self._row_partition.row_limits() 1244 1245 def row_lengths(self, axis=1, name=None): 1246 """Returns the lengths of the rows in this ragged tensor. 1247 1248 `rt.row_lengths()[i]` indicates the number of values in the 1249 `i`th row of `rt`. 1250 1251 Args: 1252 axis: An integer constant indicating the axis whose row lengths should be 1253 returned. 1254 name: A name prefix for the returned tensor (optional). 1255 1256 Returns: 1257 A potentially ragged integer Tensor with shape `self.shape[:axis]`. 1258 1259 Raises: 1260 ValueError: If `axis` is out of bounds. 1261 1262 #### Example: 1263 1264 >>> rt = tf.ragged.constant( 1265 ... 
[[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []]) 1266 >>> print(rt.row_lengths()) # lengths of rows in rt 1267 tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64) 1268 >>> print(rt.row_lengths(axis=2)) # lengths of axis=2 rows. 1269 <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]> 1270 1271 """ 1272 if axis == 0: 1273 return self._row_partition.nrows() 1274 1275 if axis == 1: 1276 return self._row_partition.row_lengths() 1277 1278 with ops.name_scope(name, "RaggedRowLengths", [self]): 1279 axis = array_ops.get_positive_axis( 1280 axis, self.shape.rank, ndims_name="rank(self)") 1281 if axis == 0: 1282 return self.nrows() 1283 elif axis == 1: 1284 splits = self.row_splits 1285 return splits[1:] - splits[:-1] 1286 elif isinstance(self.values, RaggedTensor): 1287 return self.with_values(self.values.row_lengths(axis - 1)) 1288 else: 1289 shape = array_ops.shape(self.values, out_type=self._row_partition.dtype) 1290 return self.with_values( 1291 array_ops.ones(shape[:axis - 1], self._row_partition.dtype) * 1292 shape[axis - 1]) 1293 1294 def nested_row_lengths(self, name=None): 1295 """Returns a tuple containing the row_lengths for all ragged dimensions. 1296 1297 `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors 1298 for all ragged dimensions in `rt`, ordered from outermost to innermost. 1299 1300 Args: 1301 name: A name prefix for the returned tensors (optional). 1302 1303 Returns: 1304 A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to 1305 `self.ragged_rank`. 1306 """ 1307 with ops.name_scope(name, "RaggedNestedRowLengths", [self]): 1308 rt_nested_row_lengths = [] 1309 rt = self 1310 while isinstance(rt, RaggedTensor): 1311 rt_nested_row_lengths.append(rt.row_lengths()) 1312 rt = rt.values 1313 return tuple(rt_nested_row_lengths) 1314 1315 def bounding_shape(self, axis=None, name=None, out_type=None): 1316 """Returns the tight bounding box shape for this `RaggedTensor`. 
1317 1318 Args: 1319 axis: An integer scalar or vector indicating which axes to return the 1320 bounding box for. If not specified, then the full bounding box is 1321 returned. 1322 name: A name prefix for the returned tensor (optional). 1323 out_type: `dtype` for the returned tensor. Defaults to 1324 `self.row_splits.dtype`. 1325 1326 Returns: 1327 An integer `Tensor` (`dtype=self.row_splits.dtype`). If `axis` is not 1328 specified, then `output` is a vector with 1329 `output.shape=[self.shape.ndims]`. If `axis` is a scalar, then the 1330 `output` is a scalar. If `axis` is a vector, then `output` is a vector, 1331 where `output[i]` is the bounding size for dimension `axis[i]`. 1332 1333 #### Example: 1334 1335 >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]]) 1336 >>> rt.bounding_shape().numpy() 1337 array([5, 4]) 1338 1339 """ 1340 if out_type is None: 1341 out_type = self._row_partition.dtype 1342 else: 1343 out_type = dtypes.as_dtype(out_type) 1344 with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): 1345 nested_splits = self.nested_row_splits 1346 rt_flat_values = self.flat_values 1347 1348 # Optimized special cases for when axis=0 or axis=1: 1349 if isinstance(axis, int): 1350 if axis == 0: 1351 return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1 1352 elif axis == 1: 1353 result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0) 1354 if out_type != self._row_partition.dtype: 1355 result = math_ops.cast(result, out_type) 1356 return result 1357 1358 splits_shape = array_ops.shape(self.row_splits, out_type=out_type) 1359 flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type) 1360 1361 ragged_dimensions = [splits_shape[0] - 1] + [ 1362 math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) 1363 for splits in nested_splits 1364 ] 1365 inner_dimensions = flat_values_shape[1:] 1366 1367 if out_type != self._row_partition.dtype: 1368 ragged_dimensions = [ 1369 math_ops.cast(d, 
out_type) for d in ragged_dimensions 1370 ] 1371 bbox = array_ops.concat( 1372 [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0) 1373 return bbox if axis is None else array_ops.gather(bbox, axis) 1374 1375 #============================================================================= 1376 # Transformation 1377 #============================================================================= 1378 1379 def with_values(self, new_values): 1380 """Returns a copy of `self` with `values` replaced by `new_values`. 1381 1382 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 1383 `self.cached_value_rowids` if they have values. 1384 1385 Args: 1386 new_values: Potentially ragged tensor to use as the `values` for the 1387 returned `RaggedTensor`. Must have `rank > 0`, and must have the same 1388 number of rows as `self.values`. 1389 1390 Returns: 1391 A `RaggedTensor`. `result.rank = 1 + new_values.rank`. 1392 `result.ragged_rank = 1 + new_values.ragged_rank` 1393 """ 1394 new_values = _convert_to_ragged_tensor_values(new_values) 1395 new_values.shape.with_rank_at_least(1) 1396 self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) 1397 if (isinstance(new_values, RaggedTensor) and 1398 self._row_partition.dtype != new_values.row_splits.dtype): 1399 if not ragged_config.auto_cast_partition_dtype(): 1400 raise ValueError("self and new_values have mismatched row_splits " 1401 "dtypes; use RaggedTensor.with_row_splits_dtype() to " 1402 "convert them to compatible dtypes.") 1403 new_values = new_values.with_row_splits_dtype(dtypes.int64) 1404 return self.with_row_splits_dtype(dtypes.int64).with_values(new_values) 1405 return RaggedTensor( 1406 values=new_values, row_partition=self._row_partition, internal=True) 1407 1408 def with_flat_values(self, new_values): 1409 """Returns a copy of `self` with `flat_values` replaced by `new_values`. 
1410 1411 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 1412 `self.cached_value_rowids` if they have values. 1413 1414 Args: 1415 new_values: Potentially ragged tensor that should replace 1416 `self.flat_values`. Must have `rank > 0`, and must have the same number 1417 of rows as `self.flat_values`. 1418 1419 Returns: 1420 A `RaggedTensor`. 1421 `result.rank = self.ragged_rank + new_values.rank`. 1422 `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`. 1423 """ 1424 if isinstance(self._values, RaggedTensor): 1425 return self.with_values(self.values.with_flat_values(new_values)) 1426 else: 1427 new_values = _convert_to_ragged_tensor_values(new_values) 1428 return self.with_values(new_values) 1429 1430 def with_row_splits_dtype(self, dtype): 1431 """Returns a copy of this RaggedTensor with the given `row_splits` dtype. 1432 1433 For RaggedTensors with multiple ragged dimensions, the `row_splits` for all 1434 nested `RaggedTensor` objects are cast to the given dtype. 1435 1436 Args: 1437 dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`. 1438 1439 Returns: 1440 A copy of this RaggedTensor, with the `row_splits` cast to the given 1441 type. 1442 """ 1443 dtype = dtypes.as_dtype(dtype) 1444 if dtype not in (dtypes.int32, dtypes.int64): 1445 raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. 
" 1446 f"Received {dtype}.") 1447 if self._row_partition.dtype == dtype: 1448 return self 1449 current_values = self._values 1450 if isinstance(current_values, RaggedTensor): 1451 return RaggedTensor( 1452 values=current_values.with_row_splits_dtype(dtype), 1453 row_partition=self._row_partition.with_row_splits_dtype(dtype), 1454 internal=True) 1455 else: 1456 return RaggedTensor( 1457 values=current_values, 1458 row_partition=self._row_partition.with_row_splits_dtype(dtype), 1459 internal=True) 1460 1461 def merge_dims(self, outer_axis, inner_axis): 1462 """Merges outer_axis...inner_axis into a single dimension. 1463 1464 Returns a copy of this RaggedTensor with the specified range of dimensions 1465 flattened into a single dimension, with elements in row-major order. 1466 1467 #### Examples: 1468 1469 >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) 1470 >>> print(rt.merge_dims(0, 1)) 1471 <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]> 1472 >>> print(rt.merge_dims(1, 2)) 1473 <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]> 1474 >>> print(rt.merge_dims(0, 2)) 1475 tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32) 1476 1477 To mimic the behavior of `np.flatten` (which flattens all dimensions), use 1478 `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which 1479 flattens all dimensions except the outermost batch dimension), use 1480 `rt.merge_dims(1, -1)`. 1481 1482 Args: 1483 outer_axis: `int`: The first dimension in the range of dimensions to 1484 merge. May be negative if `self.shape.rank` is statically known. 1485 inner_axis: `int`: The last dimension in the range of dimensions to merge. 1486 May be negative if `self.shape.rank` is statically known. 1487 1488 Returns: 1489 A copy of this tensor, with the specified dimensions merged into a 1490 single dimension. The shape of the returned tensor will be 1491 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 1492 is the total number of slices in the merged dimensions. 
1493 """ 1494 outer_axis = array_ops.get_positive_axis( 1495 outer_axis, 1496 self.shape.rank, 1497 axis_name="outer_axis", 1498 ndims_name="rank(self)") 1499 inner_axis = array_ops.get_positive_axis( 1500 inner_axis, 1501 self.shape.rank, 1502 axis_name="inner_axis", 1503 ndims_name="rank(self)") 1504 if not outer_axis <= inner_axis: 1505 raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or " 1506 f"equal to inner_axis ({inner_axis}).") 1507 return merge_dims(self, outer_axis, inner_axis) 1508 1509 def _set_shape(self, shape): 1510 """Updates the static shape of `self` to be `shape`. 1511 1512 * If a dimension of `shape` has known rank, and is encoded via 1513 partitioning, then this will update the corresponding partition to 1514 define `_uniform_row_length` and `nrows`. 1515 * If a dimension of `shape` has a known rank, and is encoded as one 1516 of the `flat_values` dimensions, then `flat_values.set_shape()` will 1517 be used to update its shape. 1518 1519 Warning: Using this method to assert an incorrect shape for a RaggedTensor 1520 (i.e., one that's not consistent with its actual shape) can cause 1521 segmentation faults and very difficult-to-diagnose behavior. Only use this 1522 method if you are certain that the shape is correct. 1523 1524 Args: 1525 shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`. 1526 """ 1527 # TODO(edloper): Refactor this to not directly access private members 1528 # of RowPartition. 1529 # pylint: disable=protected-access 1530 1531 shape = tensor_shape.as_shape(shape) 1532 if shape.rank is None: 1533 return # Nothing to do. 
1534 1535 shape = shape.as_list() 1536 1537 # Outermost dimension 1538 if shape[0] is not None: 1539 self._row_partition._row_splits.set_shape(shape[0] + 1) 1540 1541 # Partitioned dimensions 1542 dtype = self._row_partition.dtype 1543 for i, partition in enumerate(self._nested_row_partitions): 1544 size = shape[i + 1] 1545 if size is not None: 1546 if partition._uniform_row_length is not None: 1547 old_row_length = tensor_util.constant_value( 1548 partition._uniform_row_length) 1549 if old_row_length is not None: 1550 if size == old_row_length: 1551 continue # already have shape info for this axis. 1552 else: 1553 raise ValueError(f"Inconsistent size for axis {i + 1}: " 1554 f"{old_row_length} vs. {size}.") 1555 partition._uniform_row_length = ops.convert_to_tensor(size, dtype) 1556 if partition._nrows is None: 1557 partition._nrows = array_ops.size(partition._row_splits) - 1 1558 1559 # Inner dimensions 1560 flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:]) 1561 self.flat_values.set_shape(flat_shape) 1562 1563 #============================================================================= 1564 # Tensor Type Conversions 1565 #============================================================================= 1566 1567 @classmethod 1568 @dispatch.add_dispatch_support 1569 def from_tensor(cls, 1570 tensor, 1571 lengths=None, 1572 padding=None, 1573 ragged_rank=1, 1574 name=None, 1575 row_splits_dtype=dtypes.int64): 1576 """Converts a `tf.Tensor` into a `RaggedTensor`. 1577 1578 The set of absent/default values may be specified using a vector of lengths 1579 or a padding value (but not both). If `lengths` is specified, then the 1580 output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If 1581 'lengths' is a list of lists or tuple of lists, those lists will be used 1582 as nested row lengths. If `padding` is specified, then any row *suffix* 1583 consisting entirely of `padding` will be excluded from the returned 1584 `RaggedTensor`. 
If neither `lengths` nor `padding` is specified, then the 1585 returned `RaggedTensor` will have no absent/default values. 1586 1587 Examples: 1588 1589 >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]]) 1590 >>> tf.RaggedTensor.from_tensor(dt) 1591 <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]> 1592 >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]) 1593 <tf.RaggedTensor [[5], [], [6, 0, 0]]> 1594 1595 >>> tf.RaggedTensor.from_tensor(dt, padding=0) 1596 <tf.RaggedTensor [[5, 7], [0, 3], [6]]> 1597 1598 >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]], 1599 ... [[0, 0], [3, 0], [0, 0]], 1600 ... [[6, 0], [0, 0], [0, 0]]]) 1601 >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1])) 1602 <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]> 1603 1604 Args: 1605 tensor: The `Tensor` to convert. Must have rank `ragged_rank + 1` or 1606 higher. 1607 lengths: An optional set of row lengths, specified using a 1-D integer 1608 `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows 1609 in `tensor`). If specified, then `output[row]` will contain 1610 `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You 1611 may optionally pass a list or tuple of lengths to this argument, which 1612 will be used as nested row lengths to construct a ragged tensor with 1613 multiple ragged dimensions. 1614 padding: An optional padding value. If specified, then any row suffix 1615 consisting entirely of `padding` will be excluded from the returned 1616 RaggedTensor. `padding` is a `Tensor` with the same dtype as `tensor` 1617 and with `shape=tensor.shape[ragged_rank + 1:]`. 1618 ragged_rank: Integer specifying the ragged rank for the returned 1619 `RaggedTensor`. Must be greater than zero. 1620 name: A name prefix for the returned tensors (optional). 1621 row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits` 1622 tensor. One of `tf.int32` or `tf.int64`. 
1623 1624 Returns: 1625 A `RaggedTensor` with the specified `ragged_rank`. The shape of the 1626 returned ragged tensor is compatible with the shape of `tensor`. 1627 Raises: 1628 ValueError: If both `lengths` and `padding` are specified. 1629 """ 1630 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 1631 if lengths is not None and padding is not None: 1632 raise ValueError("Specify argument `lengths` or `padding`, but not both.") 1633 if not isinstance(ragged_rank, int): 1634 raise TypeError(f"Argument `ragged_rank` must be an int. " 1635 f"Received {ragged_rank}.") 1636 if ragged_rank <= 0: 1637 raise ValueError(f"Argument `ragged_rank` must be greater than 0. " 1638 f"Received {ragged_rank}.") 1639 1640 with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): 1641 tensor = ops.convert_to_tensor(tensor, name="tensor") 1642 tensor.shape.with_rank_at_least(ragged_rank + 1) 1643 input_shape = array_ops.shape(tensor, out_type=row_splits_dtype) 1644 ncols = input_shape[1] 1645 1646 # Handle nested row lengths. 1647 if (lengths is not None and isinstance(lengths, (list, tuple)) and 1648 len(lengths) and not isinstance(lengths[0], (int, float))): 1649 if ragged_rank not in (1, len(lengths)): 1650 # Note: we accept `ragged_rank=1` here because it's the default value; 1651 # i.e., if the user passes in a tuple of lengths, but doesn't specify 1652 # ragged_rank, then we should use that tuple to determine ragged_rank. 1653 # We only want to complain if they pass in an explicit ragged_rank 1654 # that doesn't match len(lengths). 1655 raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, " 1656 f"argument `ragged_rank` must be " 1657 f"len(lengths): {len(lengths)}. Received " 1658 f"ragged_rank: {ragged_rank}.") 1659 # Rather than reconstructing the tensor mask directly, we can 1660 # recreate it as a boolean RaggedTensor, then densify that and use 1661 # that as the mask to clear out the unused data in the passed tensor. 
1662 tensor.shape.with_rank_at_least(len(lengths) + 1) 1663 num_tokens = math_ops.reduce_sum(lengths[-1]) 1664 ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool) 1665 ragged_mask = cls.from_nested_row_lengths( 1666 ones_mask, lengths, validate=False) 1667 dense_ragged_mask = ragged_mask.to_tensor(default_value=False) 1668 masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask) 1669 return cls.from_nested_row_lengths(masked_data, lengths, validate=False) 1670 1671 # Handle ragged_rank>1 via recursion: 1672 # If the output should have multiple ragged dimensions, then first 1673 # flatten the tensor to eliminate all but the last ragged dimension, 1674 # and recursively convert that flattened tensor. Then add on the splits 1675 # for the dimensions that we flattened out. 1676 if ragged_rank > 1: 1677 if tensor.shape.is_fully_defined(): 1678 input_shape = tensor.shape.as_list() 1679 # The total number of elements in each dimension. E.g., if 1680 # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. 1681 dim_size = np.cumprod(input_shape) 1682 new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:] 1683 else: 1684 dim_size = math_ops.cumprod(input_shape) 1685 new_shape = array_ops.concat( 1686 [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0) 1687 flattened = array_ops.reshape(tensor, new_shape) 1688 result = cls.from_tensor( 1689 flattened, lengths, padding, row_splits_dtype=row_splits_dtype) 1690 1691 for axis in range(ragged_rank - 1, 0, -1): 1692 dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value 1693 if dim_len is None: 1694 dim_len = input_shape[axis] 1695 else: 1696 dim_len = constant_op.constant(dim_len, row_splits_dtype) 1697 result = RaggedTensor.from_uniform_row_length( 1698 values=result, 1699 uniform_row_length=dim_len, 1700 nrows=dim_size[axis - 1], 1701 validate=False) 1702 return result 1703 1704 # If padding was specified, then use it to find row lengths. 
1705 if padding is not None: 1706 padding = ops.convert_to_tensor( 1707 padding, name="padding", dtype=tensor.dtype) 1708 padding.shape.assert_is_compatible_with(tensor.shape[2:]) 1709 1710 # Find places where the padding is equal to the tensor. (This will 1711 # broadcast `padding` across the outermost 2 dimensions of `tensor`, 1712 # so `has_default_value.shape = tensor.shape`.) 1713 has_default_value = math_ops.equal(padding, tensor) 1714 1715 # If the padding isn't a scalar, then require that all values in the 1716 # padding match each item in the tensor. After this block of code, 1717 # `has_default.shape = tensor.shape[:2]`. (Unfortunately, we can't just 1718 # use reduce_all for both cases, becaue when you pass an empty `axis` 1719 # list to reduce_all, it reduces all axes; but we want it to reduce no 1720 # axes -- i.e., to be a no-op.) 1721 tensor_rank = array_ops.rank(tensor) 1722 reduce_axis = math_ops.range(2, tensor_rank) 1723 has_default = control_flow_ops.cond( 1724 tensor_rank > 2, 1725 lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis), 1726 lambda: has_default_value) 1727 has_default.set_shape(tensor_shape.TensorShape([None, None])) 1728 has_default.set_shape(tensor.shape[:2]) 1729 1730 # Use has_default to find the length of each row: for each 1731 # non-default item in a row, calculate the length that the row needs to 1732 # have to include that item; and then take the max of those values 1733 # (across each row). 
1734 has_nondefault = math_ops.logical_not(has_default) 1735 has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype) 1736 length_for_nondefault_value = ( 1737 has_nondefault * 1738 array_ops.expand_dims(math_ops.range(1, ncols + 1), 0)) 1739 lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1) 1740 1741 if lengths is not None: 1742 # If we have lengths (either directly supplied, or computed from 1743 # paddings), then use those to construct splits; and then use masking 1744 # to get the corresponding values. 1745 lengths = ragged_util.convert_to_int_tensor(lengths, "lengths", 1746 row_splits_dtype) 1747 lengths.shape.assert_has_rank(1) 1748 lengths = math_ops.minimum(lengths, ncols) 1749 lengths = math_ops.maximum(lengths, 0) 1750 limits = math_ops.cumsum(lengths) 1751 splits = array_ops.concat( 1752 [array_ops.zeros([1], row_splits_dtype), limits], axis=0) 1753 mask = array_ops.sequence_mask(lengths, maxlen=ncols) 1754 values = array_ops.boolean_mask(tensor, mask) 1755 return cls.from_row_splits(values, splits, validate=False) 1756 1757 # If neither padding nor lengths were specified, then create a splits 1758 # vector that contains no default values, and reshape the input tensor 1759 # to form the values for the RaggedTensor. 
      # Neither `lengths` nor `padding` was given: every row keeps all
      # `ncols` columns, so the result has a uniform inner row length.
      values_shape = array_ops.concat(
          [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0)
      values = array_ops.reshape(tensor, values_shape)
      # Prefer statically-known dimension sizes when available, so the
      # resulting partition tensors are constants.
      const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
      const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
      if const_nrows is not None:
        nrows = constant_op.constant(const_nrows, row_splits_dtype)
      else:
        nrows = input_shape[0]
      if const_ncols is not None:
        ncols = constant_op.constant(const_ncols, row_splits_dtype)
      else:
        ncols = input_shape[1]
      return RaggedTensor.from_uniform_row_length(
          values=values, uniform_row_length=ncols, nrows=nrows, validate=False)

  def to_tensor(self, default_value=None, name=None, shape=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    If `shape` is specified, then the result is padded and/or truncated to
    the specified shape.

    Examples:

    >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    tf.Tensor(
        [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32)
    >>> print(rt.to_tensor(shape=[5, 2]))
    tf.Tensor(
        [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32)

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero. `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).
      shape: The shape of the resulting dense tensor. In particular,
        `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or
        `self.bounding_shape(i)` (otherwise). `shape.rank` must be `None` or
        equal to `self.rank`.

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`. Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # A mixed list of ints and scalar Tensors can't be handled by
      # `_shape_as_tensor` directly; stack it into a single shape Tensor.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merge self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [
              s1 if s1 is not None else s2
              for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list())
          ]
        tensor.set_shape(output_shape)

      return tensor

  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`. `st_input` must be ragged-right. If it
    is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert. Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor. One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError(f"Argument `st_input` must be of type SparseTensor, but "
                      f"is of type {type(st_input).__name__}.")
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Determine the static rank from whichever of dense_shape / indices
      # has static shape information.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2.")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values. (Ignore sparse
        # column indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)

  def to_sparse(self, name=None):
    """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.

    Example:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
    >>> print(rt.to_sparse())
    SparseTensor(indices=tf.Tensor(
    [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
    shape=(6, 2), dtype=int64),
    values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
    dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))

    Args:
      name: A name prefix for the returned tensors (optional).

    Returns:
      A SparseTensor with the same values as `self`.
    """
    with ops.name_scope(name, "RaggedToSparse", [self]):
      result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
          self.nested_row_splits, self.flat_values, name=name)
      return sparse_tensor.SparseTensor(result.sparse_indices,
                                        result.sparse_values,
                                        result.sparse_dense_shape)

  @classmethod
  def _from_variant(cls,
                    variant,
                    dtype,
                    output_ragged_rank,
                    input_ragged_rank=None,
                    row_splits_dtype=dtypes.int64,
                    name=None):
    """Converts a `variant` Tensor into a `RaggedTensor`.

    The input `variant` could be a scalar, meaning it encodes a single
    `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it
    could have an arbitrary rank, in which case each element is decoded into a
    `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then
    stacked according to the input shape to output a single `RaggedTensor`
    with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not
    provided, it is inferred dynamically as `output_ragged_rank` -
    `rank(variant)`. If `input_ragged_rank` is provided, the following must be
    true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`.

    Example:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> et = rt._to_variant()
    >>> stacked_et = tf.stack([et, et])
    >>> tf.RaggedTensor._from_variant(  # scalar input.
    ...     et, dtype=tf.int32, output_ragged_rank=1).to_list()
    [[0], [1, 2]]
    >>> tf.RaggedTensor._from_variant(  # batched input.
    ...     stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list()
    [[[0], [1, 2]], [[0], [1, 2]]]

    Args:
      variant: A `variant` Tensor representing an encoded (possibly
        nested-batched) `RaggedTensor`.
      dtype: The dtype of the encoded `RaggedTensor`.
      output_ragged_rank: The expected ragged rank of the output
        `RaggedTensor`.
      input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This
        is optional and inferred dynamically if not provided.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
        One of `tf.int32` or `tf.int64`.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`.

    Raises:
      ValueError: If the input rank is known, `input_ragged_rank` is provided
        and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does
        not hold.
    """
    variant = ops.convert_to_tensor(
        variant, name="variant", dtype=dtypes.variant)
    if (variant.shape.ndims is not None and input_ragged_rank is not None and
        output_ragged_rank != input_ragged_rank + variant.shape.ndims):
      raise ValueError(
          f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal "
          f"to `input_ragged_rank` + `variant.shape.ndims` "
          f"({input_ragged_rank} + {variant.shape.ndims}).")
    # -1 tells the kernel to infer the input ragged rank dynamically.
    input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank
    with ops.name_scope(
        name, "RaggedFromVariant",
        [variant, dtype, input_ragged_rank, output_ragged_rank]):
      result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
          variant, input_ragged_rank, output_ragged_rank, dtype,
          row_splits_dtype, name)
      return cls.from_nested_row_splits(
          result.output_dense_values,
          result.output_nested_splits,
          validate=False)

  def _to_variant(self, batched_input=False, name=None):
    """Converts this `RaggedTensor` into a `variant` Tensor.

    If `batched_input` is `True`, then the `RaggedTensor` is unbatched along
    the zero-th dimension, each component `RaggedTensor` is encoded into a
    scalar `variant` Tensor, and these are stacked to return a 1-D `variant`
    Tensor.
    If `batched_input` is `False`, then the `RaggedTensor` is encoded as is
    and a scalar `variant` Tensor is returned.

    Example:
    >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]])
    >>> rt._to_variant().shape.as_list()
    []
    >>> rt._to_variant(batched_input=True).shape.as_list()
    [3]

    Args:
      batched_input: If `True`, the `RaggedTensor` is unbatched and converted
        to a `variant` vector. Set to `False` by default.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `variant` Tensor that encodes this `RaggedTensor`.
    """
    with ops.name_scope(name, "RaggedToVariant", [self, batched_input]):
      return gen_ragged_conversion_ops.ragged_tensor_to_variant(
          self.nested_row_splits, self.flat_values, batched_input, name)

  #=============================================================================
  # String Encoding
  #=============================================================================
  def __repr__(self):
    # In eager mode the values are concrete, so show them; in graph mode only
    # the symbolic components are available.
    if self._is_eager():
      return "<tf.RaggedTensor %s>" % self.to_list()
    else:
      return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values,
                                                            self.row_splits)

  #=============================================================================
  # Eager Execution Mode
  #=============================================================================

  def numpy(self):
    """Returns a numpy `array` with the values for this `RaggedTensor`.

    Requires that this `RaggedTensor` was constructed in eager execution mode.

    Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
    `rank=1`, where each element is a single row.

    #### Examples

    In the following example, the value returned by `RaggedTensor.numpy()`
    contains three numpy `array` objects: one for each row (with `rank=1` and
    `dtype=int64`), and one to combine them (with `rank=1` and
    `dtype=object`):

    >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
    array([array([1, 2, 3]), array([4, 5])], dtype=object)

    Uniform dimensions are encoded using multidimensional numpy `array`s. In
    the following example, the value returned by `RaggedTensor.numpy()`
    contains a single numpy `array` object, with `rank=2` and `dtype=int64`:

    >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
    array([[1, 2, 3], [4, 5, 6]])

    Returns:
      A numpy `array`.
    """
    if not self._is_eager():
      raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
    values = self.values.numpy()
    splits = self.row_splits.numpy()
    rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
    if not rows:
      # Preserve the trailing dense dimensions in the empty result.
      return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
    # Note: if `rows` have ragged lengths, then they will be stored in a
    # np.ndarray with dtype=object and rank=1. If they have uniform lengths,
    # they will be combined into a single np.ndarray with dtype=row.dtype and
    # rank=row.rank+1.
    return np.array(rows)

  def to_list(self):
    """Returns a nested Python `list` with the values for this `RaggedTensor`.

    Requires that `rt` was constructed in eager execution mode.

    Returns:
      A nested Python `list`.
    """
    if not isinstance(self.row_splits, ops.EagerTensor):
      raise ValueError("to_list can only be used in eager mode.")
    row_splits = self.row_splits.numpy().tolist()
    values = self.values

    if isinstance(values, RaggedTensor):
      # Recursively convert each nested row.
      return [
          values[row_splits[i]:row_splits[i + 1]].to_list()
          for i in range(len(row_splits) - 1)
      ]
    else:
      # Convert values to a Python list.
      if hasattr(values, "numpy"):
        values_as_list = values.numpy().tolist()
      elif hasattr(values, "to_list"):
        values_as_list = values.to_list()
      else:
        raise ValueError("values must be convertible to a list")

      return [
          values_as_list[row_splits[i]:row_splits[i + 1]]
          for i in range(len(row_splits) - 1)
      ]

  def _eager_value(self):
    """Returns a RaggedTensorValue for self. Requires self._is_eager()=true."""
    value = self.flat_values.numpy()
    for row_splits in reversed(self.nested_row_splits):
      value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
    return value

  def _is_eager(self):
    """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
    rt = self
    # Walk down through nested RaggedTensors, checking each partition level.
    while isinstance(rt, RaggedTensor):
      if not isinstance(rt.row_splits, ops.EagerTensor):
        return False
      rt = rt.values
    return isinstance(rt, ops.EagerTensor)

  #=============================================================================
  # Operators
  #=============================================================================
  # To avoid circular dependencies, we define stub methods for operators here,
  # and then override them when the ragged_operators module is imported.

  def _overloaded_operator(name):  # pylint: disable=no-self-argument
    # Returns a placeholder that raises until ragged_ops installs the real
    # operator implementations (done at import time to break a circular dep).

    def stub(*args, **kwargs):
      del args, kwargs
      raise ValueError(
          f"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
          f"before using RaggedTensor.{name}.")

    return stub

  __getitem__ = _overloaded_operator("__getitem__")
  __ge__ = _overloaded_operator("__ge__")
  __gt__ = _overloaded_operator("__gt__")
  __le__ = _overloaded_operator("__le__")
  __lt__ = _overloaded_operator("__lt__")
  __and__ = _overloaded_operator("__and__")
  __rand__ = _overloaded_operator("__rand__")
  __invert__ = _overloaded_operator("__invert__")
  __ror__ = _overloaded_operator("__ror__")
  __or__ = _overloaded_operator("__or__")
  __xor__ = _overloaded_operator("__xor__")
  __rxor__ = _overloaded_operator("__rxor__")
  __abs__ = _overloaded_operator("__abs__")
  __add__ = _overloaded_operator("__add__")
  __radd__ = _overloaded_operator("__radd__")
  __div__ = _overloaded_operator("__div__")
  __rdiv__ = _overloaded_operator("__rdiv__")
  __floordiv__ = _overloaded_operator("__floordiv__")
  __rfloordiv__ = _overloaded_operator("__rfloordiv__")
  __mod__ = _overloaded_operator("__mod__")
  __rmod__ = _overloaded_operator("__rmod__")
  __mul__ = _overloaded_operator("__mul__")
  __rmul__ = _overloaded_operator("__rmul__")
  __neg__ = _overloaded_operator("__neg__")
  __pow__ = _overloaded_operator("__pow__")
  __rpow__ = _overloaded_operator("__rpow__")
  __sub__ = _overloaded_operator("__sub__")
  __rsub__ = _overloaded_operator("__rsub__")
  __truediv__ = _overloaded_operator("__truediv__")
  __rtruediv__ = _overloaded_operator("__rtruediv__")
  del _overloaded_operator

  #=============================================================================
  # Name Scope
  #=============================================================================

  # This private function is used by ops.name_scope to ensure that all of the
  # input tensors for the scope belong to the same graph. Defining this means
  # that you may include `RaggedTensor` objects in the name_scope `values`
  # list.
  def _as_graph_element(self):
    """Convert `self` to a graph element."""
    values = self.values
    # Unwrap to the innermost (flat) values tensor.
    while isinstance(values, RaggedTensor):
      values = values.values
    return values

  #=============================================================================
  # Composite Tensor
  #=============================================================================

  @property
  def _type_spec(self):
    return RaggedTensorSpec.from_value(self)

  def _shape_invariant_to_type_spec(self, shape):
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)

  def consumers(self):
    return self._consumers()


def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  return isinstance(value,
                    (RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
  """
  return_dtype = kwargs.pop("return_dtype", False)
  if kwargs:
    raise ValueError(f"Unexpected keyword args {kwargs}.")

  # Scan for which row-splits dtypes are present among the RaggedTensors.
  has_int32 = False
  has_int64 = False
  for tensor in tensors:
    if isinstance(tensor, RaggedTensor):
      if tensor.row_splits.dtype == dtypes.int32:
        has_int32 = True
      else:
        has_int64 = True

  if has_int32 and has_int64:
    # Mixed dtypes: either auto-cast everything to int64, or raise.
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")
    dtype = dtypes.int64
    tensors = tuple(
        t.with_row_splits_dtype(dtypes.int64) if isinstance(t, RaggedTensor
                                                           ) else t
        for t in tensors)

  elif has_int32:
    dtype = dtypes.int32
  else:
    dtype = dtypes.int64

  if return_dtype:
    return (dtype, tensors)
  else:
    return tensors


#===============================================================================
# RaggedTensorSpec
#===============================================================================
@tf_export("RaggedTensorSpec")
@type_spec.register("tf.RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.RaggedTensor`."""

  __slots__ = [
      "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
      "_flat_values_spec"
  ]

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
    >>> tf.type_spec_from_value(rt).dtype
    tf.string

    Returns:
      A `tf.dtypes.DType` of the values in the RaggedTensor.
    """
    return self._dtype

  @property
  def shape(self):
    """The statically known shape of the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None])

    >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None, 2])

    Returns:
      A `tf.TensorShape` containing the statically known shape of the
      RaggedTensor. Ragged dimensions have a size of `None`.
    """
    return self._shape

  @property
  def ragged_rank(self):
    """The number of times the RaggedTensor's flat_values is partitioned.

    Defaults to `shape.ndims - 1`.

    Examples:

    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> tf.type_spec_from_value(values).ragged_rank
    1

    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
    >>> tf.type_spec_from_value(rt1).ragged_rank
    2

    Returns:
      A Python `int` indicating the number of times the underlying
      `flat_values` Tensor has been partitioned to add a new dimension.
      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
    """
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
    >>> tf.type_spec_from_value(rt).row_splits_dtype
    tf.int64

    Returns:
      A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
      of `tf.int32` or `tf.int64`.
    """
    return self._row_splits_dtype

  @property
  def flat_values_spec(self):
    """The `TypeSpec` of the flat_values of RaggedTensor.

    Returns:
      - The TypeSpec of flat_values.
      - None when the flat_values is a Tensor.
2367 """ 2368 return self._flat_values_spec 2369 2370 @property 2371 def value_type(self): 2372 return RaggedTensor if self._ragged_rank > 0 else ops.Tensor 2373 2374 def __init__(self, 2375 shape=None, 2376 dtype=dtypes.float32, 2377 ragged_rank=None, 2378 row_splits_dtype=dtypes.int64, 2379 flat_values_spec=None): 2380 """Constructs a type specification for a `tf.RaggedTensor`. 2381 2382 Args: 2383 shape: The shape of the RaggedTensor, or `None` to allow any shape. If a 2384 shape is specified, then all ragged dimensions must have size `None`. 2385 dtype: `tf.DType` of values in the RaggedTensor. 2386 ragged_rank: Python integer, the number of times the RaggedTensor's 2387 flat_values is partitioned. Defaults to `shape.ndims - 1`. 2388 row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One 2389 of `tf.int32` or `tf.int64`. 2390 flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be 2391 provided when the flat_values is a CompositeTensor rather then Tensor. 2392 If both `dtype` and `flat_values_spec` and are provided, `dtype` must 2393 be the same as `flat_values_spec.dtype`. 
(experimental) 2394 """ 2395 self._shape = tensor_shape.as_shape(shape) 2396 self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 2397 if flat_values_spec is not None: 2398 if dtype is None: 2399 dtype = flat_values_spec.dtype 2400 elif dtype != flat_values_spec.dtype: 2401 raise ValueError("dtype must be the same as flat_values_spec.dtype") 2402 elif dtype is None: 2403 raise ValueError( 2404 "At least one of dtype or flat_values_spec must be provided") 2405 self._dtype = dtypes.as_dtype(dtype) 2406 self._flat_values_spec = flat_values_spec 2407 2408 rank = self._shape.ndims 2409 if ragged_rank is None: 2410 if rank is None: 2411 raise ValueError("Must specify ragged_rank or " 2412 "a shape with a known rank.") 2413 ragged_rank = rank - 1 2414 self._ragged_rank = ragged_rank 2415 if not isinstance(self._ragged_rank, int): 2416 raise TypeError(f"Argument `ragged_rank` must be an int. " 2417 f"Recieved {ragged_rank}.") 2418 2419 if rank is not None: 2420 if ragged_rank >= rank: 2421 raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less " 2422 f"than rank ({rank}).") 2423 2424 def is_compatible_with(self, spec_or_value): 2425 # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values. 
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        # ragged_rank-0 spec with plain Tensor values: compare as TensorSpec.
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        # Composite flat values: delegate compatibility to their own spec.
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  def _serialize(self):
    if self._flat_values_spec is None:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype)
    else:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype, self._flat_values_spec)

  @property
  def _component_specs(self):
    if self._ragged_rank == 0:
      # No partitioning: the only component is the (flat) values themselves.
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      # The flat values have an unknown outer dimension plus whatever dense
      # dimensions follow the ragged ones.
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits has length nrows + 1 (when nrows is static).
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs

  def _to_components(self, value):
    if is_ragged(value):
      return [value.flat_values] + list(value.nested_row_splits)
    else:
      return [value]

  def _from_components(self, tensor_list):
    result = tensor_list[0]
    # In TF1 graph-free (numpy) mode, rebuild a RaggedTensorValue; otherwise
    # rebuild a symbolic/eager RaggedTensor from the innermost level out.
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      for row_splits in reversed(tensor_list[1:]):
        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
    else:
      if isinstance(tensor_list[0], np.ndarray):
        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
        result = tensor_list[0]
      for row_splits in reversed(tensor_list[1:]):
        result = RaggedTensor(
            result,
            RowPartition.from_row_splits(row_splits, validate=False),
            internal=True)
    return result

  # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
    # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
    # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
    # from variant to include all of the row-partitioning tensors.
2503 if self._flat_values_spec is not None: 2504 raise ValueError("Customized value_type is not supported.") 2505 ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0 2506 if ragged_rank != self._ragged_rank: 2507 raise ValueError(f"Ragged rank of value {ragged_rank} does not match " 2508 f"ragged rank of type {self._ragged_rank}.") 2509 if ragged_rank == 0: 2510 return [ 2511 gen_ragged_conversion_ops.ragged_tensor_to_variant( 2512 (), value, batched_input=False) 2513 ] 2514 # pylint: disable=protected-access 2515 return [value._to_variant(batched_input=False)] 2516 2517 def _to_batched_tensor_list(self, value): 2518 if self._flat_values_spec is not None: 2519 raise ValueError("Customized value_type is not supported.") 2520 ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0 2521 if ragged_rank != self._ragged_rank: 2522 raise ValueError(f"Ragged rank of value {ragged_rank} does not match " 2523 f"ragged rank of type {self._ragged_rank}.") 2524 if ragged_rank == 0: 2525 # TODO(b/141789000) Update this to handle ragged_rank=0. 2526 raise ValueError( 2527 "_to_batched_tensor_list doesn't support ragged_rank=0 yet") 2528 # pylint: disable=protected-access 2529 return [value._to_variant(batched_input=True)] 2530 2531 def _from_compatible_tensor_list(self, tensor_list): 2532 if self._flat_values_spec is not None: 2533 raise ValueError("Customized value_type is not supported.") 2534 if self._ragged_rank < 0: 2535 raise ValueError(f"Argument `ragged_rank` must be non-negative. 
" 2536 f"Received {self._ragged_rank}.") 2537 result = RaggedTensor._from_variant( # pylint: disable=protected-access 2538 tensor_list[0], 2539 dtype=self._dtype, 2540 row_splits_dtype=self._row_splits_dtype, 2541 output_ragged_rank=self._ragged_rank) 2542 if self._shape.ndims is not None: 2543 if isinstance(result, RaggedTensor): 2544 outer_dim = tensor_shape.dimension_value(self._shape[0]) 2545 if outer_dim is not None: 2546 result.row_splits.set_shape([outer_dim + 1]) 2547 result._set_shape(self._shape) # pylint: disable=protected-access 2548 else: 2549 result.set_shape(self._shape) 2550 return result 2551 2552 def _batch(self, batch_size): 2553 if self._flat_values_spec is not None: 2554 raise ValueError("Customized value_type is not supported.") 2555 return RaggedTensorSpec( 2556 tensor_shape.TensorShape([batch_size]).concatenate(self._shape), 2557 self._dtype, self._ragged_rank + 1, self._row_splits_dtype) 2558 2559 def _unbatch(self): 2560 if self._flat_values_spec is not None: 2561 raise ValueError("Customized value_type is not supported.") 2562 # Note: Negative ragged_rank is allowed here because the dataset could be 2563 # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is 2564 # consistent. 
Errors are handled in 2565 # RaggedTensorSpec._from_compatible_tensor_list() 2566 return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1, 2567 self._row_splits_dtype) 2568 2569 def _to_legacy_output_types(self): 2570 return self._dtype 2571 2572 def _to_legacy_output_shapes(self): 2573 return self._shape 2574 2575 def _to_legacy_output_classes(self): 2576 return self 2577 2578 @classmethod 2579 def from_value(cls, value): 2580 if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or 2581 isinstance(value.flat_values, ops.Tensor)): 2582 return cls( 2583 shape=value.shape, 2584 dtype=value.values.dtype, 2585 ragged_rank=value.ragged_rank, 2586 row_splits_dtype=value.row_splits.dtype) 2587 else: 2588 return cls( 2589 shape=value.shape, 2590 dtype=value.values.dtype, 2591 ragged_rank=value.ragged_rank, 2592 row_splits_dtype=value.row_splits.dtype, 2593 flat_values_spec=type_spec.type_spec_from_value(value.flat_values)) 2594 2595 2596type_spec.register_type_spec_from_value_converter( 2597 ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value) 2598 2599 2600#=============================================================================== 2601# Convert value -> tensor 2602#=============================================================================== 2603def convert_to_tensor_or_ragged_tensor(value, 2604 dtype=None, 2605 preferred_dtype=None, 2606 name=None): 2607 """Converts value to a `RaggedTensor` or `Tensor`. 2608 2609 * If `value` is a `RaggedTensor`, then return it as-is. 2610 * If `value` is a `RaggedTensorValue`, return a corresponding constant 2611 `RaggedTensor`. 2612 * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`. 2613 2614 Args: 2615 value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has 2616 a registered `Tensor` conversion function. 2617 dtype: Optional element type for the returned tensor. If missing the type 2618 is inferred from the type of `value`. 
2619 preferred_dtype: Optional element type for the returned tensor, used when 2620 dtype is None. This argument has no effect if `value` is already a 2621 tensor, or when conversion is not possible. 2622 name: Optional name to use if a new `Tensor` is created. 2623 2624 Returns: 2625 A `Tensor` or `RaggedTensor`. 2626 """ 2627 if isinstance(value, RaggedTensor): 2628 if dtype and not dtype.is_compatible_with(value.dtype): 2629 raise ValueError(f"Tensor conversion requested dtype {dtype.name} for " 2630 f"RaggedTensor with dtype {value.dtype.name}: {value}.") 2631 return value 2632 elif isinstance(value, ragged_tensor_value.RaggedTensorValue): 2633 with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []): 2634 flat_values = ops.convert_to_tensor( 2635 value=value.flat_values, 2636 dtype=dtype, 2637 preferred_dtype=preferred_dtype, 2638 name="flat_values") 2639 return RaggedTensor.from_nested_row_splits( 2640 flat_values, value.nested_row_splits, validate=False) 2641 else: 2642 return ops.convert_to_tensor_v2_with_dispatch( 2643 value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name) 2644 2645 2646def _convert_to_ragged_tensor_values(value): 2647 """Converts value to supported RaggedTensor value. 2648 2649 * If `value` is an object of supported value type, then return it as-is. 2650 * Otherwise convert it to Tensor or RaggedTensor. 2651 2652 Args: 2653 value: An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor 2654 value types, or an object whose type has a registered `Tensor` conversion 2655 function. 2656 2657 Returns: 2658 An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor 2659 value types 2660 """ 2661 if _is_supported_ragged_values_type(value): 2662 return value 2663 else: 2664 return convert_to_tensor_or_ragged_tensor(value, name="values") 2665 2666 2667#=============================================================================== 2668# Register RaggedTensor for use with session.run. 
2669#=============================================================================== 2670def _ragged_tensor_value_from_components(components): 2671 components = list(components) 2672 value = components.pop() 2673 while components: 2674 value = ragged_tensor_value.RaggedTensorValue(value, components.pop()) 2675 return value 2676 2677 2678def _ragged_tensor_session_fetch(rt): 2679 components = rt.nested_row_splits + (rt.flat_values,) 2680 return (components, _ragged_tensor_value_from_components) 2681 2682 2683def _ragged_tensor_session_feed(feed_key, feed_val): 2684 key_components = feed_key.nested_row_splits + (feed_key.flat_values,) 2685 val_components = feed_val.nested_row_splits + (feed_val.flat_values,) 2686 return zip(key_components, val_components) 2687 2688 2689def _ragged_tensor_session_feed_for_partial_run(feed_key): 2690 return feed_key.nested_row_splits + (feed_key.flat_values,) 2691 2692 2693session.register_session_run_conversion_functions( 2694 RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed, 2695 _ragged_tensor_session_feed_for_partial_run) 2696 2697 2698#=============================================================================== 2699# RaggedTensorType 2700#=============================================================================== 2701class RaggedTensorType(object): 2702 """Encoding of a static type for a `RaggedTensor`. 2703 2704 Use this type to express/declare that an output must have the type of 2705 `RaggedTensor`. 2706 """ 2707 2708 def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64): 2709 """Initializes a RaggedTensorType object. 2710 2711 Args: 2712 dtype: data type of the `RaggedTensor`'s inner values. 2713 ragged_rank: ragged_rank of the declared `RaggedTensor`. 2714 row_splits_dtype: data type for the `RaggedTensor`'s row splits. 2715 One of: `tf.int32` or `tf.int64`. 
2716 """ 2717 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 2718 self._dtype = dtype 2719 self._ragged_rank = ragged_rank 2720 self._row_splits_dtype = row_splits_dtype 2721 2722 dtype = property(lambda self: self._dtype) 2723 ragged_rank = property(lambda self: self._ragged_rank) 2724 row_splits_dtype = property(lambda self: self._row_splits_dtype) 2725 2726 def __repr__(self): 2727 return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank, 2728 self.row_splits_dtype) 2729 2730 2731#=============================================================================== 2732# Helper Functions 2733#=============================================================================== 2734def _assert_sparse_indices_are_ragged_right(indices): 2735 """Checks that the given SparseTensor.indices tensor is ragged-right. 2736 2737 Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right 2738 because the entry `[3, 1]` skips a cell. 2739 2740 Args: 2741 indices: The SparseTensor indices to check. 2742 2743 Returns: 2744 A list of control dependency op tensors. 2745 """ 2746 index_prefix = indices[:, :-1] 2747 index_suffix = indices[:, -1] 2748 2749 # Check whether each index is starting a new row in the innermost dimension 2750 # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]). 2751 # (Note: this skips the first index; we will check that separately below.) 2752 index_prefix_changed = math_ops.reduce_any( 2753 math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1) 2754 2755 # Check two cases: 2756 # * For indices that start a new row: index_suffix[i] must be zero. 2757 # * For indices that continue a row: index_suffix[i] must be equal to 2758 # index_suffix[i-1]+1. 2759 index_ok = array_ops.where( 2760 index_prefix_changed, math_ops.equal(index_suffix[1:], 0), 2761 math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1)) 2762 2763 # Also check that the very first index didn't skip any cells. 
The first 2764 # index starts a new row (by definition), so its suffix should be zero. 2765 sparse_indices_are_ragged_right = math_ops.logical_and( 2766 math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)), 2767 math_ops.reduce_all(index_ok)) 2768 2769 message = [ 2770 "SparseTensor is not right-ragged", "SparseTensor.indices =", indices 2771 ] 2772 return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)] 2773 2774 2775@ops.RegisterGradient("RaggedTensorToSparse") 2776def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad, 2777 sparse_values_grad, 2778 unused_sparse_shape_grad): 2779 """Gradient for RaggedTensorToSparse.""" 2780 op_inputs_nested_row_splits = op.inputs[:-1] 2781 op_inputs_flat_values = op.inputs[-1] 2782 2783 # No gradient for the RaggedTensor's nested_row_splits. 2784 nested_row_splits_gradient = [None] * len(op_inputs_nested_row_splits) 2785 2786 # Gradient for the RaggedTensor's flat_values is formed by reshaping 2787 # the gradient for the SparseTensor's values. 2788 flat_values_shape = array_ops.shape(op_inputs_flat_values) 2789 flat_values_gradient = array_ops.reshape(sparse_values_grad, 2790 flat_values_shape) 2791 2792 return nested_row_splits_gradient + [flat_values_gradient] 2793 2794 2795def _assert_monotonic_increasing(tensor, message=None): 2796 return check_ops.assert_non_negative( 2797 tensor[1:] - tensor[:-1], message=message) 2798 2799 2800def _assert_zero(tensor, message=None): 2801 return check_ops.assert_equal( 2802 tensor, constant_op.constant(0, dtype=tensor.dtype), message=message) 2803 2804 2805def _nrows(tensor, out_type=dtypes.int32): 2806 if isinstance(tensor, RaggedTensor): 2807 return tensor.nrows(out_type=out_type) 2808 else: 2809 return array_ops.shape(tensor, out_type=out_type)[0] 2810 2811 2812def merge_dims(value, outer_axis, inner_axis): 2813 """Merges value[outer_axis...inner_axis] into a single dimension. 2814 2815 See `RaggedTensor.merge_dims()` for more details. 
This helper differs from 2816 `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor. 2817 2818 Args: 2819 value: A `RaggedTensor` or `Tensor` 2820 outer_axis: `int` 2821 inner_axis: `int` 2822 2823 Returns: 2824 A flattened `RaggedTensor` or `Tensor`. 2825 """ 2826 if outer_axis == inner_axis: 2827 return value 2828 2829 # Flatten outer dimensions of a RaggedTensor by just taking its values. 2830 while outer_axis == 0 and isinstance(value, RaggedTensor): 2831 value = value.values 2832 inner_axis -= 1 2833 if inner_axis == 0: 2834 return value 2835 2836 # Flatten non-Ragged tensors using tf.reshape(). 2837 if not isinstance(value, RaggedTensor): 2838 if value.shape.is_fully_defined(): 2839 old_shape = value.shape.as_list() 2840 new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:] 2841 else: 2842 old_shape = array_ops.shape(value) 2843 new_shape = array_ops.concat( 2844 [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0) 2845 return array_ops.reshape(value, new_shape) 2846 2847 # Handle outer_axis>1 via recursion. 2848 if outer_axis > 1: 2849 return value.with_values( 2850 merge_dims(value.values, outer_axis - 1, inner_axis - 1)) 2851 2852 # At this point, we know outer_axis == 1, and value is a RaggedTensor. 2853 # So we need to flatten the values and build a corresponding splits tensor. 2854 new_values = value.values 2855 new_splits = value.row_splits 2856 for axis in range(outer_axis, inner_axis): 2857 if isinstance(new_values, RaggedTensor): 2858 # Flatten a single ragged dimension. 2859 new_splits = array_ops.gather(new_values.row_splits, new_splits) 2860 new_values = new_values.values 2861 else: 2862 # Flatten all remaining dense dimensions. 
2863 shape_split = inner_axis - axis + 1 2864 if new_values.shape.is_fully_defined(): 2865 old_shape = new_values.shape.as_list() 2866 new_shape = [-1] + old_shape[shape_split:] 2867 flat_size = _prod(old_shape[1:shape_split]) 2868 else: 2869 old_shape = array_ops.shape(new_values) 2870 new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0) 2871 flat_size = math_ops.cast( 2872 math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype) 2873 new_values = array_ops.reshape(new_values, new_shape) 2874 new_splits = new_splits * flat_size 2875 break 2876 return RaggedTensor.from_row_splits(new_values, new_splits) 2877 2878 2879def _prod(lst): 2880 """Returns the product of the numbers in a list.""" 2881 return functools.reduce(operator.mul, lst, 1) 2882 2883 2884def _get_row_partition_type_tensor_pairs_tail(partition): 2885 """Gets a row partition type tensor pair for the tail. 2886 2887 If value_rowid is defined, then it is used. Otherwise, row_splits 2888 are used. 2889 2890 Args: 2891 partition: a RowPartition. 2892 2893 Returns: 2894 A list of (row_partition_type, row_partition_tensor) pairs. 2895 """ 2896 if partition.has_precomputed_value_rowids(): 2897 return ("VALUE_ROWIDS", partition.value_rowids()) 2898 else: 2899 return ("ROW_SPLITS", partition.row_splits()) 2900 2901 2902def _get_row_partition_type_tensor_pairs(rt_input): 2903 """Gets a list of the row partitions for rt_input. 2904 2905 If value_rowids are defined, then they are used. Otherwise, row_splits 2906 are used. If the outermost level has value_rowids defind, then nrows is 2907 also added. 2908 2909 Args: 2910 rt_input: a ragged tensor. 2911 2912 Returns: 2913 A list of (row_partition_type, row_partition_tensor) pairs. 
2914 """ 2915 partitions = rt_input._nested_row_partitions # pylint: disable=protected-access 2916 tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]] 2917 2918 if partitions[0]._value_rowids is not None: # pylint: disable=protected-access 2919 return [("FIRST_DIM_SIZE", partitions[0].nrows()), 2920 ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail 2921 else: 2922 return [("ROW_SPLITS", partitions[0].row_splits())] + tail 2923 2924 2925def _shape_as_tensor(shape, dtype): 2926 """Takes shape and coerces it to a shape as a tensor. 2927 2928 If the object is already a tensor, simply passes it on (result is guaranteed 2929 to be int64 or int32, but not necessarily dtype). 2930 If not, creates a tensor of type dtype. 2931 2932 Result is either a scalar equal to -1 if the shape is unknown_rank. 2933 Otherwise, it is a vector, where unknown dimensions are represented with a 2934 value of -1. 2935 2936 In C++, see TensorShapeFromTensor for parsing shapes in kernels, and 2937 InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for 2938 use in the shape inference function. 2939 2940 Args: 2941 shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]], 2942 Tuple[Optional[Int]]. 2943 dtype: tf.int64 or tf.int32 2944 2945 Returns: 2946 a scalar or vector tensor of dtype tf.int32 or tf.int64. 2947 """ 2948 if dtype != dtypes.int64 and dtype != dtypes.int32: 2949 raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.") 2950 2951 if isinstance(shape, ops.Tensor): 2952 if shape.dtype != dtypes.int64 and shape.dtype != dtypes.int32: 2953 return math_ops.cast(shape, dtype) 2954 return shape 2955 shape = tensor_shape.as_shape(shape) 2956 if not shape: 2957 # Imply rank is unknown using a -1 scalar. 2958 return constant_op.constant(-1, dtype=dtype) 2959 shape = [(-1 if x is None else x) for x in shape.as_list()] 2960 # At this point, shape is List[Int]. 
2961 return constant_op.constant(shape, dtype=dtype) 2962 2963 2964def _nvals_uniform_row_length(values, uniform_row_length): 2965 """Get the number of values for uniform row length constructor.""" 2966 const_nvals = tensor_shape.dimension_at_index(values.shape, 0).value 2967 if const_nvals is not None: 2968 nvals = constant_op.constant(const_nvals, uniform_row_length.dtype) 2969 elif isinstance(values, RaggedTensor): 2970 nvals = values.nrows(out_type=uniform_row_length.dtype) 2971 else: 2972 nvals = array_ops.shape(values, out_type=uniform_row_length.dtype)[0] 2973 return nvals 2974 2975 2976def _get_optional_partition_dtype(values): 2977 """Returns the partition dtype, or None if None exists.""" 2978 if isinstance(values, RaggedTensor): 2979 # pylint: disable=protected-access 2980 return values._row_partition.dtype 2981 return None 2982 2983 2984_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor) 2985 2986 2987# TODO(edloper): Consider whether we should change the registry to be on 2988# TypeSpecs rather than ValueTypes. 2989def _add_supported_value_type(cls): 2990 """Register the `cls` as supported value type of RaggedTenosr. 2991 2992 The cls must be a subclass of CompositeTensor, and must support: 2993 - Properties: 2994 - x.shape 2995 - x.dtype 2996 - Methods: 2997 - x.__getitem__(idx) (method: returns a supported value type) 2998 - Ops: 2999 - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor. 3000 - tf.tile(x) 3001 - assert_rank_at_least(x) 3002 - tf.ones_like(x) 3003 - tf.gather(params=x, indices=Tensor) 3004 - tf.add(x, y) 3005 - tf.boolean_mask(x, ...) 
3006 - @TODO(edloper): Complete this list 3007 3008 Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not 3009 currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor: 3010 - rt.to_tensor() 3011 - rt.to_sparse_tensor() 3012 - rt._to_variant() 3013 - rt._from_variant() 3014 - tf.ragged.cross([rt]) 3015 - tf.gather(params=x, indices=rt) # rt used for indices 3016 - RaggedTensorSpec methods: 3017 - _batch 3018 - _unbatch 3019 - _to_tensor_list 3020 - _to_batched_tensor_list 3021 - _from_compatible_tensor_list 3022 3023 Args: 3024 cls: The type to be added to supported value types. 3025 """ 3026 if not issubclass(cls, composite_tensor.CompositeTensor): 3027 raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.") 3028 if not hasattr(cls, "shape"): 3029 raise ValueError("cls must support the `shape` property.") 3030 if not hasattr(cls, "dtype"): 3031 raise ValueError("cls must support the `dtype` property.") 3032 global _SUPPORTED_RAGGED_VALUE_TYPES 3033 _SUPPORTED_RAGGED_VALUE_TYPES += (cls,) 3034 3035 3036def _is_supported_ragged_values_type(value): 3037 return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES) 3038 3039 3040def _assert_is_supported_ragged_values_type(value): 3041 if not _is_supported_ragged_values_type(value): 3042 ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES) 3043 raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.") 3044