1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for storing ragged tensors and their values.""" 16 17import functools 18import operator 19 20import typing 21import numpy as np 22 23from tensorflow.python import tf2 24from tensorflow.python.client import session 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import composite_tensor_gradient 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_spec 33from tensorflow.python.framework import tensor_util 34from tensorflow.python.framework import type_spec 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import check_ops 37from tensorflow.python.ops import control_flow_ops 38from tensorflow.python.ops import gen_ragged_conversion_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops.ragged import ragged_config 41from tensorflow.python.ops.ragged import ragged_tensor_value 42from tensorflow.python.ops.ragged import ragged_util 43from tensorflow.python.ops.ragged.row_partition import RowPartition 44from 
tensorflow.python.types import core as core_types 45from tensorflow.python.types import internal as internal_types 46from tensorflow.python.util import dispatch 47from tensorflow.python.util.tf_export import tf_export 48from tensorflow.tools.docs import doc_controls 49 50# pylint: disable=protected-access 51_convert_row_partition = RowPartition._convert_row_partition 52# pylint: enable=protected-access 53 54#=============================================================================== 55# RaggedTensor 56#=============================================================================== 57 58 59@tf_export("RaggedTensor") 60class RaggedTensor(composite_tensor.CompositeTensor, 61 internal_types.NativeObject): 62 """Represents a ragged tensor. 63 64 A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are 65 dimensions whose slices may have different lengths. For example, the inner 66 (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, 67 since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. 68 Dimensions whose slices all have the same length are called *uniform 69 dimensions*. The outermost dimension of a `RaggedTensor` is always uniform, 70 since it consists of a single slice (and so there is no possibility for 71 differing slice lengths). 72 73 The total number of dimensions in a `RaggedTensor` is called its *rank*, 74 and the number of ragged dimensions in a `RaggedTensor` is called its 75 *ragged-rank*. A `RaggedTensor`'s ragged-rank is fixed at graph creation 76 time: it can't depend on the runtime values of `Tensor`s, and can't vary 77 dynamically for different session runs. 78 79 Note that the `__init__` constructor is private. 
Please use one of the 80 following methods to construct a `RaggedTensor`: 81 82 * `tf.RaggedTensor.from_row_lengths` 83 * `tf.RaggedTensor.from_value_rowids` 84 * `tf.RaggedTensor.from_row_splits` 85 * `tf.RaggedTensor.from_row_starts` 86 * `tf.RaggedTensor.from_row_limits` 87 * `tf.RaggedTensor.from_nested_row_splits` 88 * `tf.RaggedTensor.from_nested_row_lengths` 89 * `tf.RaggedTensor.from_nested_value_rowids` 90 91 ### Potentially Ragged Tensors 92 93 Many ops support both `Tensor`s and `RaggedTensor`s 94 (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a 95 full listing). The term "potentially ragged tensor" may be used to refer to a 96 tensor that might be either a `Tensor` or a `RaggedTensor`. The ragged-rank 97 of a `Tensor` is zero. 98 99 ### Documenting RaggedTensor Shapes 100 101 When documenting the shape of a RaggedTensor, ragged dimensions can be 102 indicated by enclosing them in parentheses. For example, the shape of 103 a 3-D `RaggedTensor` that stores the fixed-size word embedding for each 104 word in a sentence, for each sentence in a batch, could be written as 105 `[num_sentences, (num_words), embedding_size]`. The parentheses around 106 `(num_words)` indicate that dimension is ragged, and that the length 107 of each element list in that dimension may vary for each item. 108 109 ### Component Tensors 110 111 Internally, a `RaggedTensor` consists of a concatenated list of values that 112 are partitioned into variable-length rows. In particular, each `RaggedTensor` 113 consists of: 114 115 * A `values` tensor, which concatenates the variable-length rows into a 116 flattened list. For example, the `values` tensor for 117 `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`. 118 119 * A `row_splits` vector, which indicates how those flattened values are 120 divided into rows. In particular, the values for row `rt[i]` are stored 121 in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 
122 123 Example: 124 125 >>> print(tf.RaggedTensor.from_row_splits( 126 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 127 ... row_splits=[0, 4, 4, 7, 8, 8])) 128 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 129 130 ### Alternative Row-Partitioning Schemes 131 132 In addition to `row_splits`, ragged tensors provide support for five other 133 row-partitioning schemes: 134 135 * `row_lengths`: a vector with shape `[nrows]`, which specifies the length 136 of each row. 137 138 * `value_rowids` and `nrows`: `value_rowids` is a vector with shape 139 `[nvals]`, corresponding one-to-one with `values`, which specifies 140 each value's row index. In particular, the row `rt[row]` consists of the 141 values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an 142 integer scalar that specifies the number of rows in the 143 `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.) 144 145 * `row_starts`: a vector with shape `[nrows]`, which specifies the start 146 offset of each row. Equivalent to `row_splits[:-1]`. 147 148 * `row_limits`: a vector with shape `[nrows]`, which specifies the stop 149 offset of each row. Equivalent to `row_splits[1:]`. 150 151 * `uniform_row_length`: A scalar tensor, specifying the length of every 152 row. This row-partitioning scheme may only be used if all rows have 153 the same length. 154 155 Example: The following ragged tensors are equivalent, and all represent the 156 nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`. 157 158 >>> values = [3, 1, 4, 1, 5, 9, 2, 6] 159 >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8]) 160 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 161 >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0]) 162 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 163 >>> RaggedTensor.from_value_rowids( 164 ... 
values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) 165 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 166 >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8]) 167 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 168 >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8]) 169 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 170 >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2) 171 <tf.RaggedTensor [[3, 1], [4, 1], [5, 9], [2, 6]]> 172 173 ### Multiple Ragged Dimensions 174 175 `RaggedTensor`s with multiple ragged dimensions can be defined by using 176 a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor` 177 adds a single ragged dimension. 178 179 >>> inner_rt = RaggedTensor.from_row_splits( # =rt1 from above 180 ... values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) 181 >>> outer_rt = RaggedTensor.from_row_splits( 182 ... values=inner_rt, row_splits=[0, 3, 3, 5]) 183 >>> print(outer_rt.to_list()) 184 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 185 >>> print(outer_rt.ragged_rank) 186 2 187 188 The factory function `RaggedTensor.from_nested_row_splits` may be used to 189 construct a `RaggedTensor` with multiple ragged dimensions directly, by 190 providing a list of `row_splits` tensors: 191 192 >>> RaggedTensor.from_nested_row_splits( 193 ... flat_values=[3, 1, 4, 1, 5, 9, 2, 6], 194 ... nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list() 195 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 196 197 ### Uniform Inner Dimensions 198 199 `RaggedTensor`s with uniform inner dimensions can be defined 200 by using a multidimensional `Tensor` for `values`. 201 202 >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32), 203 ... 
row_splits=[0, 2, 5]) 204 >>> print(rt.to_list()) 205 [[[1, 1, 1], [1, 1, 1]], 206 [[1, 1, 1], [1, 1, 1], [1, 1, 1]]] 207 >>> print(rt.shape) 208 (2, None, 3) 209 210 ### Uniform Outer Dimensions 211 212 `RaggedTensor`s with uniform outer dimensions can be defined by using 213 one or more `RaggedTensor` with a `uniform_row_length` row-partitioning 214 tensor. For example, a `RaggedTensor` with shape `[2, 2, None]` can be 215 constructed with this method from a `RaggedTensor` values with shape 216 `[4, None]`: 217 218 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 219 >>> print(values.shape) 220 (4, None) 221 >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2) 222 >>> print(rt6) 223 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 224 >>> print(rt6.shape) 225 (2, 2, None) 226 227 Note that `rt6` only contains one ragged dimension (the innermost 228 dimension). In contrast, if `from_row_splits` is used to construct a similar 229 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: 230 231 >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) 232 >>> print(rt7.shape) 233 (2, None, None) 234 235 Uniform and ragged outer dimensions may be interleaved, meaning that a 236 tensor with any combination of ragged and uniform dimensions may be created. 
237 For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could 238 be constructed as follows: 239 240 ```python 241 t0 = tf.zeros([1000, 2]) # Shape: [1000, 2] 242 t1 = RaggedTensor.from_row_lengths(t0, [...]) # [160, None, 2] 243 t2 = RaggedTensor.from_uniform_row_length(t1, 8) # [20, 8, None, 2] 244 t3 = RaggedTensor.from_uniform_row_length(t2, 4) # [5, 4, 8, None, 2] 245 t4 = RaggedTensor.from_row_lengths(t3, [...]) # [3, None, 4, 8, None, 2] 246 ``` 247 248 """ 249 250 #============================================================================= 251 # Constructor (private) 252 #============================================================================= 253 @doc_controls.do_not_generate_docs 254 def __init__(self, values, row_partition, internal=False): 255 """Creates a `RaggedTensor` with a specified partitioning for `values`. 256 257 This constructor is private -- please use one of the following ops to 258 build `RaggedTensor`s: 259 260 * `tf.RaggedTensor.from_row_lengths` 261 * `tf.RaggedTensor.from_value_rowids` 262 * `tf.RaggedTensor.from_row_splits` 263 * `tf.RaggedTensor.from_row_starts` 264 * `tf.RaggedTensor.from_row_limits` 265 * `tf.RaggedTensor.from_nested_row_splits` 266 * `tf.RaggedTensor.from_nested_row_lengths` 267 * `tf.RaggedTensor.from_nested_value_rowids` 268 269 Args: 270 values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. 271 row_partition: A `RowPartition` object, representing the arrangement of 272 the lists at the top level. 273 internal: True if the constructor is being called by one of the factory 274 methods. If false, an exception will be raised. 275 276 Raises: 277 ValueError: If internal = False. Note that this method is intended only 278 for internal use. 279 TypeError: If values is not a `RaggedTensor` or `Tensor`, or 280 row_partition is not a `RowPartition`. 
281 """ 282 283 if not internal: 284 raise ValueError("RaggedTensor constructor is private; please use one " 285 "of the factory methods instead (e.g., " 286 "RaggedTensor.from_row_lengths())") 287 _assert_is_supported_ragged_values_type(values) 288 if not isinstance(row_partition, RowPartition): 289 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 290 f"Received {row_partition}.") 291 292 # Validate shapes. 293 values.shape.with_rank_at_least(1) 294 if isinstance(values, RaggedTensor): 295 # pylint: disable=protected-access 296 assert row_partition.dtype == values._row_partition.dtype 297 298 self._values = values 299 self._row_partition = row_partition 300 301 #============================================================================= 302 # Factory Methods 303 #============================================================================= 304 305 @classmethod 306 def _from_row_partition(cls, values, row_partition, validate=True): 307 """Creates a `RaggedTensor` with a row partition. 308 309 This is used as a way for RaggedTensors to share row partitions. 310 311 The outer dimension of values must be equal to `partition.nvals()`. 312 313 Args: 314 values: A potentially ragged tensor. 315 row_partition: a `RowPartition`: can be shared between tensors. 316 validate: If true, then use assertions to check that the arguments form a 317 valid `RaggedTensor`. 318 319 Returns: 320 A `RaggedTensor`. `result.rank = values.rank + 1`. 321 `result.ragged_rank = values.ragged_rank + 1`. 322 323 Raises: 324 ValueError: If partition.nvals() != _nrows(values) 325 """ 326 if not isinstance(row_partition, RowPartition): 327 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 328 f"Received {row_partition}.") 329 if not isinstance(validate, bool): 330 raise TypeError(f"Argument `validate` must have type bool. 
" 331 f"Received {validate}.") 332 values, row_partition = cls._convert_values_and_partition( 333 values, row_partition, "partition") 334 if row_partition._has_precomputed_value_rowids(): # pylint: disable=protected-access 335 value_rowids_shape = row_partition.value_rowids().shape 336 values.shape[:1].assert_is_compatible_with(value_rowids_shape) 337 if validate: 338 msg = "Arguments to _from_row_partition do not form a valid RaggedTensor" 339 nvals = _nrows(values, row_partition.dtype) 340 checks = [ 341 check_ops.assert_equal( 342 math_ops.cast(row_partition.nvals(), row_partition.dtype), 343 nvals, 344 message=msg), 345 ] 346 if not isinstance(values, RaggedTensor): 347 checks.append(check_ops.assert_rank_at_least(values, 1)) 348 row_partition = row_partition._with_dependencies(checks) # pylint: disable=protected-access 349 return cls(values=values, internal=True, row_partition=row_partition) 350 351 @classmethod 352 @dispatch.add_dispatch_support 353 def from_value_rowids(cls, 354 values, 355 value_rowids, 356 nrows=None, 357 name=None, 358 validate=True): 359 """Creates a `RaggedTensor` with rows partitioned by `value_rowids`. 360 361 The returned `RaggedTensor` corresponds with the python list defined by: 362 363 ```python 364 result = [[values[i] for i in range(len(values)) if value_rowids[i] == row] 365 for row in range(nrows)] 366 ``` 367 368 Args: 369 values: A potentially ragged tensor with shape `[nvals, ...]`. 370 value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds 371 one-to-one with `values`, and specifies each value's row index. Must be 372 nonnegative, and must be sorted in ascending order. 373 nrows: An integer scalar specifying the number of rows. This should be 374 specified if the `RaggedTensor` may containing empty training rows. Must 375 be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). 376 Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty). 
377 name: A name prefix for the RaggedTensor (optional). 378 validate: If true, then use assertions to check that the arguments form 379 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 380 since they must be checked for each tensor value. 381 382 Returns: 383 A `RaggedTensor`. `result.rank = values.rank + 1`. 384 `result.ragged_rank = values.ragged_rank + 1`. 385 386 Raises: 387 ValueError: If `nrows` is incompatible with `value_rowids`. 388 389 #### Example: 390 391 >>> print(tf.RaggedTensor.from_value_rowids( 392 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 393 ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], 394 ... nrows=5)) 395 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 396 397 """ 398 if not isinstance(validate, bool): 399 raise TypeError(f"Argument `validate` must have type bool. " 400 f"Received {validate}.") 401 402 with ops.name_scope(name, "RaggedFromValueRowIds", 403 [values, value_rowids, nrows]): 404 row_partition = RowPartition.from_value_rowids( 405 value_rowids=value_rowids, 406 nrows=nrows, 407 validate=validate, 408 dtype_hint=_get_optional_partition_dtype(values)) 409 return cls._from_row_partition(values, row_partition, validate=validate) 410 411 @classmethod 412 @dispatch.add_dispatch_support 413 def from_row_splits(cls, values, row_splits, name=None, validate=True): 414 """Creates a `RaggedTensor` with rows partitioned by `row_splits`. 415 416 The returned `RaggedTensor` corresponds with the python list defined by: 417 418 ```python 419 result = [values[row_splits[i]:row_splits[i + 1]] 420 for i in range(len(row_splits) - 1)] 421 ``` 422 423 Args: 424 values: A potentially ragged tensor with shape `[nvals, ...]`. 425 row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be 426 empty, and must be sorted in ascending order. `row_splits[0]` must be 427 zero and `row_splits[-1]` must be `nvals`. 428 name: A name prefix for the RaggedTensor (optional). 
429 validate: If true, then use assertions to check that the arguments form 430 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 431 since they must be checked for each tensor value. 432 433 Returns: 434 A `RaggedTensor`. `result.rank = values.rank + 1`. 435 `result.ragged_rank = values.ragged_rank + 1`. 436 437 Raises: 438 ValueError: If `row_splits` is an empty list. 439 440 #### Example: 441 442 >>> print(tf.RaggedTensor.from_row_splits( 443 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 444 ... row_splits=[0, 4, 4, 7, 8, 8])) 445 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 446 447 """ 448 if not isinstance(validate, bool): 449 raise TypeError(f"Argument `validate` must have type bool. " 450 f"Received {validate}.") 451 452 with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): 453 row_partition = RowPartition.from_row_splits( 454 row_splits=row_splits, 455 validate=validate, 456 dtype_hint=_get_optional_partition_dtype(values)) 457 return cls._from_row_partition(values, row_partition, validate=validate) 458 459 @classmethod 460 @dispatch.add_dispatch_support 461 def from_row_lengths(cls, values, row_lengths, name=None, validate=True): 462 """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. 463 464 The returned `RaggedTensor` corresponds with the python list defined by: 465 466 ```python 467 result = [[values.pop(0) for i in range(length)] 468 for length in row_lengths] 469 ``` 470 471 Args: 472 values: A potentially ragged tensor with shape `[nvals, ...]`. 473 row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be 474 nonnegative. `sum(row_lengths)` must be `nvals`. 475 name: A name prefix for the RaggedTensor (optional). 476 validate: If true, then use assertions to check that the arguments form 477 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 478 since they must be checked for each tensor value. 479 480 Returns: 481 A `RaggedTensor`. `result.rank = values.rank + 1`. 
482 `result.ragged_rank = values.ragged_rank + 1`. 483 484 #### Example: 485 486 >>> print(tf.RaggedTensor.from_row_lengths( 487 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 488 ... row_lengths=[4, 0, 3, 1, 0])) 489 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 490 491 """ 492 if not isinstance(validate, bool): 493 raise TypeError(f"Argument `validate` must have type bool. " 494 f"Received {validate}.") 495 496 with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): 497 row_partition = RowPartition.from_row_lengths( 498 row_lengths=row_lengths, 499 validate=validate, 500 dtype_hint=_get_optional_partition_dtype(values)) 501 return cls._from_row_partition(values, row_partition, validate=validate) 502 503 @classmethod 504 @dispatch.add_dispatch_support 505 def from_row_starts(cls, values, row_starts, name=None, validate=True): 506 """Creates a `RaggedTensor` with rows partitioned by `row_starts`. 507 508 Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`. 509 510 Args: 511 values: A potentially ragged tensor with shape `[nvals, ...]`. 512 row_starts: A 1-D integer tensor with shape `[nrows]`. Must be 513 nonnegative and sorted in ascending order. If `nrows>0`, then 514 `row_starts[0]` must be zero. 515 name: A name prefix for the RaggedTensor (optional). 516 validate: If true, then use assertions to check that the arguments form 517 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 518 since they must be checked for each tensor value. 519 520 Returns: 521 A `RaggedTensor`. `result.rank = values.rank + 1`. 522 `result.ragged_rank = values.ragged_rank + 1`. 523 524 #### Example: 525 526 >>> print(tf.RaggedTensor.from_row_starts( 527 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 528 ... row_starts=[0, 4, 4, 7, 8])) 529 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 530 531 """ 532 if not isinstance(validate, bool): 533 raise TypeError(f"Argument `validate` must have type bool. 
" 534 f"Received {validate}.") 535 with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): 536 values = _convert_to_ragged_tensor_values(values) 537 row_partition = RowPartition.from_row_starts( 538 row_starts=row_starts, 539 nvals=_nrows(values), 540 validate=validate, 541 dtype_hint=_get_optional_partition_dtype(values)) 542 return cls._from_row_partition(values, row_partition, validate=validate) 543 544 @classmethod 545 @dispatch.add_dispatch_support 546 def from_row_limits(cls, values, row_limits, name=None, validate=True): 547 """Creates a `RaggedTensor` with rows partitioned by `row_limits`. 548 549 Equivalent to: `from_row_splits(values, concat([0, row_limits]))`. 550 551 Args: 552 values: A potentially ragged tensor with shape `[nvals, ...]`. 553 row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in 554 ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`. 555 name: A name prefix for the RaggedTensor (optional). 556 validate: If true, then use assertions to check that the arguments form 557 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 558 since they must be checked for each tensor value. 559 560 Returns: 561 A `RaggedTensor`. `result.rank = values.rank + 1`. 562 `result.ragged_rank = values.ragged_rank + 1`. 563 564 #### Example: 565 566 >>> print(tf.RaggedTensor.from_row_limits( 567 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 568 ... row_limits=[4, 4, 7, 8, 8])) 569 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 570 571 """ 572 if not isinstance(validate, bool): 573 raise TypeError(f"Argument `validate` must have type bool. 
" 574 f"Received {validate}.") 575 with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): 576 values = _convert_to_ragged_tensor_values(values) 577 row_partition = RowPartition.from_row_limits( 578 row_limits=row_limits, 579 validate=validate, 580 dtype_hint=_get_optional_partition_dtype(values)) 581 return cls._from_row_partition(values, row_partition, validate=validate) 582 583 @classmethod 584 @dispatch.add_dispatch_support 585 def from_uniform_row_length(cls, 586 values, 587 uniform_row_length, 588 nrows=None, 589 validate=True, 590 name=None): 591 """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`. 592 593 This method can be used to create `RaggedTensor`s with multiple uniform 594 outer dimensions. For example, a `RaggedTensor` with shape `[2, 2, None]` 595 can be constructed with this method from a `RaggedTensor` values with shape 596 `[4, None]`: 597 598 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 599 >>> print(values.shape) 600 (4, None) 601 >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2) 602 >>> print(rt1) 603 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 604 >>> print(rt1.shape) 605 (2, 2, None) 606 607 Note that `rt1` only contains one ragged dimension (the innermost 608 dimension). In contrast, if `from_row_splits` is used to construct a similar 609 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: 610 611 >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) 612 >>> print(rt2.shape) 613 (2, None, None) 614 615 Args: 616 values: A potentially ragged tensor with shape `[nvals, ...]`. 617 uniform_row_length: A scalar integer tensor. Must be nonnegative. The 618 size of the outer axis of `values` must be evenly divisible by 619 `uniform_row_length`. 620 nrows: The number of rows in the constructed RaggedTensor. If not 621 specified, then it defaults to `nvals/uniform_row_length` (or `0` if 622 `uniform_row_length==0`). 
`nrows` only needs to be specified if 623 `uniform_row_length` might be zero. `uniform_row_length*nrows` must be 624 `nvals`. 625 validate: If true, then use assertions to check that the arguments form 626 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 627 since they must be checked for each tensor value. 628 name: A name prefix for the RaggedTensor (optional). 629 630 Returns: 631 A `RaggedTensor` that corresponds with the python list defined by: 632 633 ```python 634 result = [[values.pop(0) for i in range(uniform_row_length)] 635 for _ in range(nrows)] 636 ``` 637 638 `result.rank = values.rank + 1`. 639 `result.ragged_rank = values.ragged_rank + 1`. 640 """ 641 if not isinstance(validate, bool): 642 raise TypeError(f"Argument `validate` must have type bool. " 643 f"Received {validate}.") 644 with ops.name_scope(name, "RaggedFromUniformRowLength", 645 [values, uniform_row_length, nrows]): 646 values = _convert_to_ragged_tensor_values(values) 647 uniform_row_length = _convert_row_partition( 648 uniform_row_length, "UniformRowLength", 649 _get_optional_partition_dtype(values)) 650 nvals = _nvals_uniform_row_length(values, uniform_row_length) 651 row_partition = RowPartition.from_uniform_row_length( 652 uniform_row_length=uniform_row_length, 653 nvals=nvals, 654 nrows=nrows, 655 validate=validate, 656 dtype_hint=_get_optional_partition_dtype(values)) 657 return cls._from_row_partition(values, row_partition, validate=validate) 658 659 @classmethod 660 @dispatch.add_dispatch_support 661 def from_nested_value_rowids(cls, 662 flat_values, 663 nested_value_rowids, 664 nested_nrows=None, 665 name=None, 666 validate=True): 667 """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors. 
668 669 Equivalent to: 670 671 ```python 672 result = flat_values 673 for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)): 674 result = from_value_rowids(result, rowids, nrows) 675 ``` 676 677 Args: 678 flat_values: A potentially ragged tensor. 679 nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is 680 used as the `value_rowids` for the `i`th ragged dimension. 681 nested_nrows: A list of integer scalars. The `i`th scalar is used as the 682 `nrows` for the `i`th ragged dimension. 683 name: A name prefix for the RaggedTensor (optional). 684 validate: If true, then use assertions to check that the arguments form 685 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 686 since they must be checked for each tensor value. 687 688 Returns: 689 A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty). 690 691 Raises: 692 ValueError: If `len(nested_values_rowids) != len(nested_nrows)`. 693 """ 694 if not isinstance(validate, bool): 695 raise TypeError(f"Argument `validate` must have type bool. " 696 f"Received {validate}.") 697 if isinstance(nested_value_rowids, ops.Tensor): 698 raise TypeError(f"Argument `nested_value_rowids` must be a list of " 699 f"Tensors. Received {nested_value_rowids}.") 700 if nested_nrows is None: 701 nested_nrows = [None] * len(nested_value_rowids) 702 else: 703 if isinstance(nested_nrows, ops.Tensor): 704 raise TypeError(f"Argument `nested_nrows` must be a list of " 705 f"Tensors. Received {nested_nrows}.") 706 if len(nested_nrows) != len(nested_value_rowids): 707 raise ValueError( 708 f"Argument `nested_nrows` must have the same length as " 709 f"argument `nested_value_rowids`. len(nested_nrows) = " 710 f"{len(nested_nrows)} vs. 
len(nested_values_rowids) = " 711 f"{len(nested_value_rowids)}.") 712 713 with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] + 714 list(nested_value_rowids) + list(nested_nrows)): 715 result = flat_values 716 for value_rowids, nrows in reversed( 717 list(zip(nested_value_rowids, nested_nrows))): 718 result = cls.from_value_rowids( 719 result, value_rowids, nrows, validate=validate) 720 return result 721 722 @classmethod 723 @dispatch.add_dispatch_support 724 def from_nested_row_splits(cls, 725 flat_values, 726 nested_row_splits, 727 name=None, 728 validate=True): 729 """Creates a `RaggedTensor` from a nested list of `row_splits` tensors. 730 731 Equivalent to: 732 733 ```python 734 result = flat_values 735 for row_splits in reversed(nested_row_splits): 736 result = from_row_splits(result, row_splits) 737 ``` 738 739 Args: 740 flat_values: A potentially ragged tensor. 741 nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is 742 used as the `row_splits` for the `i`th ragged dimension. 743 name: A name prefix for the RaggedTensor (optional). 744 validate: If true, then use assertions to check that the arguments form 745 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 746 since they must be checked for each tensor value. 747 748 Returns: 749 A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty). 750 """ 751 if not isinstance(validate, bool): 752 raise TypeError(f"Argument `validate` must have type bool. " 753 f"Received {validate}.") 754 if isinstance(nested_row_splits, ops.Tensor): 755 raise TypeError(f"Argument `nested_row_splits` must be a list of " 756 f"Tensors. 
Received {nested_row_splits}.") 757 with ops.name_scope(name, "RaggedFromNestedRowSplits", 758 [flat_values] + list(nested_row_splits)): 759 result = flat_values 760 for splits in reversed(nested_row_splits): 761 result = cls.from_row_splits(result, splits, validate=validate) 762 return result 763 764 @classmethod 765 @dispatch.add_dispatch_support 766 def from_nested_row_lengths(cls, 767 flat_values, 768 nested_row_lengths, 769 name=None, 770 validate=True): 771 """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors. 772 773 Equivalent to: 774 775 ```python 776 result = flat_values 777 for row_lengths in reversed(nested_row_lengths): 778 result = from_row_lengths(result, row_lengths) 779 ``` 780 781 Args: 782 flat_values: A potentially ragged tensor. 783 nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is 784 used as the `row_lengths` for the `i`th ragged dimension. 785 name: A name prefix for the RaggedTensor (optional). 786 validate: If true, then use assertions to check that the arguments form 787 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 788 since they must be checked for each tensor value. 789 790 Returns: 791 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). 792 """ 793 if not isinstance(validate, bool): 794 raise TypeError(f"Argument `validate` must have type bool. " 795 f"Received {validate}.") 796 if isinstance(nested_row_lengths, ops.Tensor): 797 raise TypeError(f"Argument `nested_row_lengths` must be a list of " 798 f"Tensors. 
Received {nested_row_lengths}.") 799 with ops.name_scope(name, "RaggedFromNestedRowlengths", 800 [flat_values] + list(nested_row_lengths)): 801 result = flat_values 802 for lengths in reversed(nested_row_lengths): 803 result = cls.from_row_lengths(result, lengths, validate=validate) 804 return result 805 806 @classmethod 807 def _from_nested_row_partitions(cls, 808 flat_values, 809 nested_row_partitions, 810 name=None, 811 validate=True): 812 """Creates a `RaggedTensor` from a nested list of row partitions. 813 814 Equivalent to: 815 816 ```python 817 result = flat_values 818 for row_partition in reversed(nested_row_partitions): 819 result = _from_row_partition(result, row_partition) 820 ``` 821 822 Args: 823 flat_values: A potentially ragged tensor. 824 nested_row_partitions: A list of row partitions. The `i`th element is 825 used as the row partition for the `i`th ragged dimension. 826 name: A name prefix for the RaggedTensor (optional). 827 validate: If true, then use assertions to check that the arguments form 828 a valid `RaggedTensor`. Note: these assertions incur a runtime cost, 829 since they must be checked for each tensor value. 830 831 Returns: 832 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). 833 """ 834 if not isinstance(validate, bool): 835 raise TypeError(f"Argument `validate` must have type bool. " 836 f"Received {validate}.") 837 if isinstance(nested_row_partitions, RowPartition): 838 raise TypeError(f"Argument `nested_row_partitions` must be a list of " 839 f"RowPartitions. Received {nested_row_partitions}.") 840 if isinstance(nested_row_partitions, ops.Tensor): 841 raise TypeError(f"Argument `nested_row_partitions` must be a list of " 842 f"RowPartitions. 
Received {nested_row_partitions}.") 843 with ops.name_scope(name, "RaggedFromNestedRowPartitions", 844 [flat_values] + list(nested_row_partitions)): 845 result = flat_values 846 for partition in reversed(nested_row_partitions): 847 result = cls._from_row_partition(result, partition, validate=validate) 848 return result 849 850 @classmethod 851 def _convert_values_and_partition(cls, values, row_partition, name): 852 """Converts `values` and `partition` to Tensors. 853 854 If `values` is a `RaggedTensor`, then converts `values` and `partition` 855 to have compatible row-partitioning dtypes. In particular, if any of the 856 row partitioning tensors are `int64`, then all of the other row 857 partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype() 858 is true) or an error will be raised (if auto_cast_partition_dtype() is 859 false). 860 861 Args: 862 values: The `values` for the `RaggedTensor` being constructed. 863 row_partition: A RowPartition object for the `RaggedTensor` being 864 constructed. 865 name: The name of the RowPartition object. 866 867 Returns: 868 A tuple (values, partition). 869 """ 870 if not isinstance(row_partition, RowPartition): 871 raise TypeError(f"Argument `row_partition` must be a RowPartition. " 872 f"Received {row_partition}.") 873 if isinstance(values, RaggedTensor): 874 # pylint: disable=protected-access 875 if values._row_partition.dtype != row_partition.dtype: 876 if not ragged_config.auto_cast_partition_dtype(): 877 # pylint: disable=protected-access 878 # TODO(edloper): get rid of the `name` parameter. 879 raise ValueError( 880 f"Argument `row_partition` of RaggedTensor with name: {name} " 881 f"must have same dtype as Argument `values`. " 882 f"({row_partition.dtype} vs. 
{values._row_partition.dtype}).") 883 values = values.with_row_splits_dtype(row_partition.dtype) 884 else: 885 values = _convert_to_ragged_tensor_values(values) 886 887 return (values, row_partition) 888 889 #============================================================================= 890 # Accessors 891 #============================================================================= 892 893 @property 894 def dtype(self): 895 """The `DType` of values in this tensor.""" 896 return self._values.dtype 897 898 @property 899 def shape(self): 900 """The statically known shape of this ragged tensor. 901 902 Returns: 903 A `TensorShape` containing the statically known shape of this ragged 904 tensor. Ragged dimensions have a size of `None`. 905 906 Examples: 907 908 >>> tf.ragged.constant([[0], [1, 2]]).shape 909 TensorShape([2, None]) 910 911 >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape 912 TensorShape([2, None, 2]) 913 914 """ 915 nrows = self._row_partition.static_nrows 916 ncols = self._row_partition.static_uniform_row_length 917 value_shape = self._values.shape[1:] 918 return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape) 919 920 def get_shape(self): 921 """The statically known shape of this ragged tensor. 922 923 Returns: 924 A `TensorShape` containing the statically known shape of this ragged 925 tensor. Ragged dimensions have a size of `None`. 926 927 Alias for `shape` property. 928 929 Examples: 930 931 >>> tf.ragged.constant([[0], [1, 2]]).get_shape() 932 TensorShape([2, None]) 933 934 >>> tf.ragged.constant( 935 ... [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape() 936 TensorShape([2, None, 2]) 937 938 """ 939 return self.shape 940 941 @property 942 def ragged_rank(self): 943 """The number of times the RaggedTensor's flat_values is partitioned. 
944 945 Examples: 946 947 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 948 >>> values.ragged_rank 949 1 950 951 >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2) 952 >>> rt.ragged_rank 953 2 954 955 Returns: 956 A Python `int` indicating the number of times the underlying `flat_values` 957 Tensor has been partitioned to add a new dimension. 958 I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`. 959 """ 960 values_is_ragged = isinstance(self._values, RaggedTensor) 961 return self._values.ragged_rank + 1 if values_is_ragged else 1 962 963 @property 964 def values(self): 965 """The concatenated rows for this ragged tensor. 966 967 `rt.values` is a potentially ragged tensor formed by flattening the two 968 outermost dimensions of `rt` into a single dimension. 969 970 `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the 971 number of items in the outer two dimensions of `rt`). 972 973 `rt.ragged_rank = self.ragged_rank - 1` 974 975 Returns: 976 A potentially ragged tensor. 977 978 #### Example: 979 980 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 981 >>> print(rt.values) 982 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 983 984 """ 985 return self._values 986 987 @property 988 def _nested_row_partitions(self): 989 """Returns the row partitions for this `RaggedTensor`.""" 990 partitions = [self._row_partition] 991 rt_values = self.values 992 while isinstance(rt_values, RaggedTensor): 993 # pylint: disable=protected-access 994 partitions.append(rt_values._row_partition) 995 rt_values = rt_values.values 996 return tuple(partitions) 997 998 @property 999 def row_splits(self): 1000 """The row-split indices for this ragged tensor's `values`. 1001 1002 `rt.row_splits` specifies where the values for each row begin and end in 1003 `rt.values`. In particular, the values for row `rt[i]` are stored in 1004 the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 
1005 1006 Returns: 1007 A 1-D integer `Tensor` with shape `[self.nrows+1]`. 1008 The returned tensor is non-empty, and is sorted in ascending order. 1009 `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to 1010 `self.values.shape[0]`. 1011 1012 #### Example: 1013 1014 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1015 >>> print(rt.row_splits) # indices of row splits in rt.values 1016 tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64) 1017 1018 """ 1019 return self._row_partition.row_splits() 1020 1021 @property 1022 def uniform_row_length(self): 1023 """The length of each row in this ragged tensor, or None if rows are ragged. 1024 1025 >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 1026 >>> print(rt1.uniform_row_length) # rows are ragged. 1027 None 1028 1029 >>> rt2 = tf.RaggedTensor.from_uniform_row_length( 1030 ... values=rt1, uniform_row_length=2) 1031 >>> print(rt2) 1032 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]> 1033 >>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2). 1034 tf.Tensor(2, shape=(), dtype=int64) 1035 1036 A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged) 1037 if it can be determined statically (at graph construction time) that the 1038 rows all have the same length. 1039 1040 Returns: 1041 A scalar integer `Tensor`, specifying the length of every row in this 1042 ragged tensor (for ragged tensors whose rows are uniform); or `None` 1043 (for ragged tensors whose rows are ragged). 1044 """ 1045 return self._row_partition.uniform_row_length() 1046 1047 @property 1048 def flat_values(self): 1049 """The innermost `values` tensor for this ragged tensor. 1050 1051 Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is 1052 `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`. 
1053 1054 Conceptually, `flat_values` is the tensor formed by flattening the 1055 outermost dimension and all of the ragged dimensions into a single 1056 dimension. 1057 1058 `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]` 1059 (where `nvals` is the number of items in the flattened dimensions). 1060 1061 Returns: 1062 A `Tensor`. 1063 1064 #### Example: 1065 1066 >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) 1067 >>> print(rt.flat_values) 1068 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1069 1070 """ 1071 rt_values = self.values 1072 while isinstance(rt_values, RaggedTensor): 1073 rt_values = rt_values.values 1074 return rt_values 1075 1076 @property 1077 def nested_row_splits(self): 1078 """A tuple containing the row_splits for all ragged dimensions. 1079 1080 `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for 1081 all ragged dimensions in `rt`, ordered from outermost to innermost. In 1082 particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where: 1083 1084 * `value_splits = ()` if `rt.values` is a `Tensor`. 1085 * `value_splits = rt.values.nested_row_splits` otherwise. 1086 1087 Returns: 1088 A `tuple` of 1-D integer `Tensor`s. 1089 1090 #### Example: 1091 1092 >>> rt = tf.ragged.constant( 1093 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) 1094 >>> for i, splits in enumerate(rt.nested_row_splits): 1095 ... print('Splits for dimension %d: %s' % (i+1, splits.numpy())) 1096 Splits for dimension 1: [0 3] 1097 Splits for dimension 2: [0 3 3 5] 1098 Splits for dimension 3: [0 4 4 7 8 8] 1099 1100 """ 1101 rt_nested_splits = [self.row_splits] 1102 rt_values = self.values 1103 while isinstance(rt_values, RaggedTensor): 1104 rt_nested_splits.append(rt_values.row_splits) 1105 rt_values = rt_values.values 1106 return tuple(rt_nested_splits) 1107 1108 def value_rowids(self, name=None): 1109 """Returns the row indices for the `values` in this ragged tensor. 
1110 1111 `rt.value_rowids()` corresponds one-to-one with the outermost dimension of 1112 `rt.values`, and specifies the row containing each value. In particular, 1113 the row `rt[row]` consists of the values `rt.values[j]` where 1114 `rt.value_rowids()[j] == row`. 1115 1116 Args: 1117 name: A name prefix for the returned tensor (optional). 1118 1119 Returns: 1120 A 1-D integer `Tensor` with shape `self.values.shape[:1]`. 1121 The returned tensor is nonnegative, and is sorted in ascending order. 1122 1123 #### Example: 1124 1125 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1126 >>> print(rt.values) 1127 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1128 >>> print(rt.value_rowids()) # corresponds 1:1 with rt.values 1129 tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64) 1130 1131 """ 1132 with ops.name_scope(name, "RaggedValueRowIds", [self]): 1133 return self._row_partition.value_rowids() 1134 1135 def nested_value_rowids(self, name=None): 1136 """Returns a tuple containing the value_rowids for all ragged dimensions. 1137 1138 `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors 1139 for 1140 all ragged dimensions in `rt`, ordered from outermost to innermost. In 1141 particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids` 1142 where: 1143 1144 * `value_ids = ()` if `rt.values` is a `Tensor`. 1145 * `value_ids = rt.values.nested_value_rowids` otherwise. 1146 1147 Args: 1148 name: A name prefix for the returned tensors (optional). 1149 1150 Returns: 1151 A `tuple` of 1-D integer `Tensor`s. 1152 1153 #### Example: 1154 1155 >>> rt = tf.ragged.constant( 1156 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) 1157 >>> for i, ids in enumerate(rt.nested_value_rowids()): 1158 ... 
print('row ids for dimension %d: %s' % (i+1, ids.numpy())) 1159 row ids for dimension 1: [0 0 0] 1160 row ids for dimension 2: [0 0 0 2 2] 1161 row ids for dimension 3: [0 0 0 0 2 2 2 3] 1162 1163 """ 1164 with ops.name_scope(name, "RaggedNestedValueRowIds", [self]): 1165 rt_nested_ids = [self.value_rowids()] 1166 rt_values = self.values 1167 while isinstance(rt_values, RaggedTensor): 1168 rt_nested_ids.append(rt_values.value_rowids()) 1169 rt_values = rt_values.values 1170 return tuple(rt_nested_ids) 1171 1172 def nrows(self, out_type=None, name=None): 1173 """Returns the number of rows in this ragged tensor. 1174 1175 I.e., the size of the outermost dimension of the tensor. 1176 1177 Args: 1178 out_type: `dtype` for the returned tensor. Defaults to 1179 `self.row_splits.dtype`. 1180 name: A name prefix for the returned tensor (optional). 1181 1182 Returns: 1183 A scalar `Tensor` with dtype `out_type`. 1184 1185 #### Example: 1186 1187 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1188 >>> print(rt.nrows()) # rt has 5 rows. 1189 tf.Tensor(5, shape=(), dtype=int64) 1190 1191 """ 1192 with ops.name_scope(name, "RaggedNRows", [self]): 1193 if out_type is None: 1194 return self._row_partition.nrows() 1195 else: 1196 return math_ops.cast(self._row_partition.nrows(), dtype=out_type) 1197 1198 def row_starts(self, name=None): 1199 """Returns the start indices for rows in this ragged tensor. 1200 1201 These indices specify where the values for each row begin in 1202 `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`. 1203 1204 Args: 1205 name: A name prefix for the returned tensor (optional). 1206 1207 Returns: 1208 A 1-D integer Tensor with shape `[nrows]`. 1209 The returned tensor is nonnegative, and is sorted in ascending order. 
1210 1211 #### Example: 1212 1213 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1214 >>> print(rt.values) 1215 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1216 >>> print(rt.row_starts()) # indices of row starts in rt.values 1217 tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64) 1218 1219 """ 1220 with ops.name_scope(name, "RaggedRowStarts", [self]): 1221 return self._row_partition.row_starts() 1222 1223 def row_limits(self, name=None): 1224 """Returns the limit indices for rows in this ragged tensor. 1225 1226 These indices specify where the values for each row end in 1227 `self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`. 1228 1229 Args: 1230 name: A name prefix for the returned tensor (optional). 1231 1232 Returns: 1233 A 1-D integer Tensor with shape `[nrows]`. 1234 The returned tensor is nonnegative, and is sorted in ascending order. 1235 1236 #### Example: 1237 1238 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 1239 >>> print(rt.values) 1240 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) 1241 >>> print(rt.row_limits()) # indices of row limits in rt.values 1242 tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64) 1243 1244 """ 1245 with ops.name_scope(name, "RaggedRowLimits", [self]): 1246 return self._row_partition.row_limits() 1247 1248 def row_lengths(self, axis=1, name=None): 1249 """Returns the lengths of the rows in this ragged tensor. 1250 1251 `rt.row_lengths()[i]` indicates the number of values in the 1252 `i`th row of `rt`. 1253 1254 Args: 1255 axis: An integer constant indicating the axis whose row lengths should be 1256 returned. 1257 name: A name prefix for the returned tensor (optional). 1258 1259 Returns: 1260 A potentially ragged integer Tensor with shape `self.shape[:axis]`. 1261 1262 Raises: 1263 ValueError: If `axis` is out of bounds. 1264 1265 #### Example: 1266 1267 >>> rt = tf.ragged.constant( 1268 ... 
[[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []]) 1269 >>> print(rt.row_lengths()) # lengths of rows in rt 1270 tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64) 1271 >>> print(rt.row_lengths(axis=2)) # lengths of axis=2 rows. 1272 <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]> 1273 1274 """ 1275 if axis == 0: 1276 return self._row_partition.nrows() 1277 1278 if axis == 1: 1279 return self._row_partition.row_lengths() 1280 1281 with ops.name_scope(name, "RaggedRowLengths", [self]): 1282 axis = array_ops.get_positive_axis( 1283 axis, self.shape.rank, ndims_name="rank(self)") 1284 if axis == 0: 1285 return self.nrows() 1286 elif axis == 1: 1287 splits = self.row_splits 1288 return splits[1:] - splits[:-1] 1289 elif isinstance(self.values, RaggedTensor): 1290 return self.with_values(self.values.row_lengths(axis - 1)) 1291 else: 1292 shape = array_ops.shape(self.values, out_type=self._row_partition.dtype) 1293 return self.with_values( 1294 array_ops.ones(shape[:axis - 1], self._row_partition.dtype) * 1295 shape[axis - 1]) 1296 1297 def nested_row_lengths(self, name=None): 1298 """Returns a tuple containing the row_lengths for all ragged dimensions. 1299 1300 `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors 1301 for all ragged dimensions in `rt`, ordered from outermost to innermost. 1302 1303 Args: 1304 name: A name prefix for the returned tensors (optional). 1305 1306 Returns: 1307 A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to 1308 `self.ragged_rank`. 1309 """ 1310 with ops.name_scope(name, "RaggedNestedRowLengths", [self]): 1311 rt_nested_row_lengths = [] 1312 rt = self 1313 while isinstance(rt, RaggedTensor): 1314 rt_nested_row_lengths.append(rt.row_lengths()) 1315 rt = rt.values 1316 return tuple(rt_nested_row_lengths) 1317 1318 def bounding_shape(self, axis=None, name=None, out_type=None): 1319 """Returns the tight bounding box shape for this `RaggedTensor`. 
1320 1321 Args: 1322 axis: An integer scalar or vector indicating which axes to return the 1323 bounding box for. If not specified, then the full bounding box is 1324 returned. 1325 name: A name prefix for the returned tensor (optional). 1326 out_type: `dtype` for the returned tensor. Defaults to 1327 `self.row_splits.dtype`. 1328 1329 Returns: 1330 An integer `Tensor` (`dtype=self.row_splits.dtype`). If `axis` is not 1331 specified, then `output` is a vector with 1332 `output.shape=[self.shape.ndims]`. If `axis` is a scalar, then the 1333 `output` is a scalar. If `axis` is a vector, then `output` is a vector, 1334 where `output[i]` is the bounding size for dimension `axis[i]`. 1335 1336 #### Example: 1337 1338 >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]]) 1339 >>> rt.bounding_shape().numpy() 1340 array([5, 4]) 1341 1342 """ 1343 if out_type is None: 1344 out_type = self._row_partition.dtype 1345 else: 1346 out_type = dtypes.as_dtype(out_type) 1347 with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): 1348 nested_splits = self.nested_row_splits 1349 rt_flat_values = self.flat_values 1350 1351 # Optimized special cases for when axis=0 or axis=1: 1352 if isinstance(axis, int): 1353 if axis == 0: 1354 return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1 1355 elif axis == 1: 1356 result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0) 1357 if out_type != self._row_partition.dtype: 1358 result = math_ops.cast(result, out_type) 1359 return result 1360 1361 splits_shape = array_ops.shape(self.row_splits, out_type=out_type) 1362 flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type) 1363 1364 ragged_dimensions = [splits_shape[0] - 1] + [ 1365 math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) 1366 for splits in nested_splits 1367 ] 1368 inner_dimensions = flat_values_shape[1:] 1369 1370 if out_type != self._row_partition.dtype: 1371 ragged_dimensions = [ 1372 math_ops.cast(d, 
out_type) for d in ragged_dimensions 1373 ] 1374 bbox = array_ops.concat( 1375 [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0) 1376 return bbox if axis is None else array_ops.gather(bbox, axis) 1377 1378 #============================================================================= 1379 # Transformation 1380 #============================================================================= 1381 1382 def with_values(self, new_values): 1383 """Returns a copy of `self` with `values` replaced by `new_value`. 1384 1385 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 1386 `self.cached_value_rowids` if they have values. 1387 1388 Args: 1389 new_values: Potentially ragged tensor to use as the `values` for the 1390 returned `RaggedTensor`. Must have `rank > 0`, and must have the same 1391 number of rows as `self.values`. 1392 1393 Returns: 1394 A `RaggedTensor`. `result.rank = 1 + new_values.rank`. 1395 `result.ragged_rank = 1 + new_values.ragged_rank` 1396 """ 1397 new_values = _convert_to_ragged_tensor_values(new_values) 1398 new_values.shape.with_rank_at_least(1) 1399 self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) 1400 if (isinstance(new_values, RaggedTensor) and 1401 self._row_partition.dtype != new_values.row_splits.dtype): 1402 if not ragged_config.auto_cast_partition_dtype(): 1403 raise ValueError("self and new_values have mismatched row_splits " 1404 "dtypes; use RaggedTensor.with_row_splits_dtype() to " 1405 "convert them to compatible dtypes.") 1406 new_values = new_values.with_row_splits_dtype(dtypes.int64) 1407 return self.with_row_splits_dtype(dtypes.int64).with_values(new_values) 1408 return RaggedTensor( 1409 values=new_values, row_partition=self._row_partition, internal=True) 1410 1411 def with_flat_values(self, new_values): 1412 """Returns a copy of `self` with `flat_values` replaced by `new_value`. 
1413 1414 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 1415 `self.cached_value_rowids` if they have values. 1416 1417 Args: 1418 new_values: Potentially ragged tensor that should replace 1419 `self.flat_values`. Must have `rank > 0`, and must have the same number 1420 of rows as `self.flat_values`. 1421 1422 Returns: 1423 A `RaggedTensor`. 1424 `result.rank = self.ragged_rank + new_values.rank`. 1425 `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`. 1426 """ 1427 if isinstance(self._values, RaggedTensor): 1428 return self.with_values(self.values.with_flat_values(new_values)) 1429 else: 1430 new_values = _convert_to_ragged_tensor_values(new_values) 1431 return self.with_values(new_values) 1432 1433 def with_row_splits_dtype(self, dtype): 1434 """Returns a copy of this RaggedTensor with the given `row_splits` dtype. 1435 1436 For RaggedTensors with multiple ragged dimensions, the `row_splits` for all 1437 nested `RaggedTensor` objects are cast to the given dtype. 1438 1439 Args: 1440 dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`. 1441 1442 Returns: 1443 A copy of this RaggedTensor, with the `row_splits` cast to the given 1444 type. 1445 """ 1446 dtype = dtypes.as_dtype(dtype) 1447 if dtype not in (dtypes.int32, dtypes.int64): 1448 raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. " 1449 f"Received {dtype}.") 1450 if self._row_partition.dtype == dtype: 1451 return self 1452 current_values = self._values 1453 if isinstance(current_values, RaggedTensor): 1454 return RaggedTensor( 1455 values=current_values.with_row_splits_dtype(dtype), 1456 row_partition=self._row_partition.with_dtype(dtype), 1457 internal=True) 1458 else: 1459 return RaggedTensor( 1460 values=current_values, 1461 row_partition=self._row_partition.with_dtype(dtype), 1462 internal=True) 1463 1464 def merge_dims(self, outer_axis, inner_axis): 1465 """Merges outer_axis...inner_axis into a single dimension. 
1466 1467 Returns a copy of this RaggedTensor with the specified range of dimensions 1468 flattened into a single dimension, with elements in row-major order. 1469 1470 #### Examples: 1471 1472 >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) 1473 >>> print(rt.merge_dims(0, 1)) 1474 <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]> 1475 >>> print(rt.merge_dims(1, 2)) 1476 <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]> 1477 >>> print(rt.merge_dims(0, 2)) 1478 tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32) 1479 1480 To mimic the behavior of `np.flatten` (which flattens all dimensions), use 1481 `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which 1482 flattens all dimensions except the outermost batch dimension), use 1483 `rt.merge_dims(1, -1)`. 1484 1485 Args: 1486 outer_axis: `int`: The first dimension in the range of dimensions to 1487 merge. May be negative if `self.shape.rank` is statically known. 1488 inner_axis: `int`: The last dimension in the range of dimensions to merge. 1489 May be negative if `self.shape.rank` is statically known. 1490 1491 Returns: 1492 A copy of this tensor, with the specified dimensions merged into a 1493 single dimension. The shape of the returned tensor will be 1494 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 1495 is the total number of slices in the merged dimensions. 1496 """ 1497 outer_axis = array_ops.get_positive_axis( 1498 outer_axis, 1499 self.shape.rank, 1500 axis_name="outer_axis", 1501 ndims_name="rank(self)") 1502 inner_axis = array_ops.get_positive_axis( 1503 inner_axis, 1504 self.shape.rank, 1505 axis_name="inner_axis", 1506 ndims_name="rank(self)") 1507 if not outer_axis <= inner_axis: 1508 raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or " 1509 f"equal to inner_axis ({inner_axis}).") 1510 return merge_dims(self, outer_axis, inner_axis) 1511 1512 def _set_shape(self, shape): 1513 """Updates the static shape of `self` to be `shape`. 
1514 1515 * If a dimension of `shape` has known rank, and is encoded via 1516 partitioning, then this will update the corresponding partition to 1517 define `_uniform_row_length` and `nrows`. 1518 * If a dimension of `shape` has a known rank, and is encoded as one 1519 of the `flat_values` dimensions, then `flat_values.set_shape()` will 1520 be used to update its shape. 1521 1522 Warning: Using this method to assert an incorrect shape for a RaggedTensor 1523 (i.e., one that's not consistent with its actual shape) can cause 1524 segmentation faults and very difficult-to-diagnose behavior. Only use this 1525 method if you are certain that the shape is correct. 1526 1527 Args: 1528 shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`. 1529 """ 1530 # TODO(edloper): Refactor this to not directly access private members 1531 # of RowPartition. 1532 # pylint: disable=protected-access 1533 1534 shape = tensor_shape.as_shape(shape) 1535 if shape.rank is None: 1536 return # Nothing to do. 1537 1538 shape = shape.as_list() 1539 1540 # Outermost dimension 1541 if shape[0] is not None: 1542 self._row_partition._row_splits.set_shape(shape[0] + 1) 1543 1544 # Partitioned dimensions 1545 dtype = self._row_partition.dtype 1546 for i, partition in enumerate(self._nested_row_partitions): 1547 size = shape[i + 1] 1548 if size is not None: 1549 if partition._uniform_row_length is not None: 1550 old_row_length = tensor_util.constant_value( 1551 partition._uniform_row_length) 1552 if old_row_length is not None: 1553 if size == old_row_length: 1554 continue # already have shape info for this axis. 1555 else: 1556 raise ValueError(f"Inconsistent size for axis {i + 1}: " 1557 f"{old_row_length} vs. 
{size}.") 1558 partition._uniform_row_length = ops.convert_to_tensor(size, dtype) 1559 if partition._nrows is None: 1560 partition._nrows = array_ops.size( 1561 partition._row_splits, out_type=dtype) - 1 1562 1563 # self.flat_values could be a CompositeTensor and doesn't have set_shape. 1564 if hasattr(self.flat_values, "set_shape"): 1565 # Inner dimensions 1566 flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:]) 1567 self.flat_values.set_shape(flat_shape) 1568 1569 #============================================================================= 1570 # Tensor Type Conversions 1571 #============================================================================= 1572 1573 @classmethod 1574 @dispatch.add_dispatch_support 1575 def from_tensor(cls, 1576 tensor, 1577 lengths=None, 1578 padding=None, 1579 ragged_rank=1, 1580 name=None, 1581 row_splits_dtype=dtypes.int64): 1582 """Converts a `tf.Tensor` into a `RaggedTensor`. 1583 1584 The set of absent/default values may be specified using a vector of lengths 1585 or a padding value (but not both). If `lengths` is specified, then the 1586 output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If 1587 'lengths' is a list of lists or tuple of lists, those lists will be used 1588 as nested row lengths. If `padding` is specified, then any row *suffix* 1589 consisting entirely of `padding` will be excluded from the returned 1590 `RaggedTensor`. If neither `lengths` nor `padding` is specified, then the 1591 returned `RaggedTensor` will have no absent/default values. 
1592 1593 Examples: 1594 1595 >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]]) 1596 >>> tf.RaggedTensor.from_tensor(dt) 1597 <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]> 1598 >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]) 1599 <tf.RaggedTensor [[5], [], [6, 0, 0]]> 1600 1601 >>> tf.RaggedTensor.from_tensor(dt, padding=0) 1602 <tf.RaggedTensor [[5, 7], [0, 3], [6]]> 1603 1604 >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]], 1605 ... [[0, 0], [3, 0], [0, 0]], 1606 ... [[6, 0], [0, 0], [0, 0]]]) 1607 >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1])) 1608 <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]> 1609 1610 Args: 1611 tensor: The `Tensor` to convert. Must have rank `ragged_rank + 1` or 1612 higher. 1613 lengths: An optional set of row lengths, specified using a 1-D integer 1614 `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows 1615 in `tensor`). If specified, then `output[row]` will contain 1616 `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You 1617 may optionally pass a list or tuple of lengths to this argument, which 1618 will be used as nested row lengths to construct a ragged tensor with 1619 multiple ragged dimensions. 1620 padding: An optional padding value. If specified, then any row suffix 1621 consisting entirely of `padding` will be excluded from the returned 1622 RaggedTensor. `padding` is a `Tensor` with the same dtype as `tensor` 1623 and with `shape=tensor.shape[ragged_rank + 1:]`. 1624 ragged_rank: Integer specifying the ragged rank for the returned 1625 `RaggedTensor`. Must be greater than zero. 1626 name: A name prefix for the returned tensors (optional). 1627 row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits` 1628 tensor. One of `tf.int32` or `tf.int64`. 1629 1630 Returns: 1631 A `RaggedTensor` with the specified `ragged_rank`. The shape of the 1632 returned ragged tensor is compatible with the shape of `tensor`. 
1633 1634 Raises: 1635 ValueError: If both `lengths` and `padding` are specified. 1636 ValueError: If the rank of `tensor` is 0 or 1. 1637 """ 1638 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 1639 if lengths is not None and padding is not None: 1640 raise ValueError("Specify argument `lengths` or `padding`, but not both.") 1641 if not isinstance(ragged_rank, int): 1642 raise TypeError(f"Argument `ragged_rank` must be an int. " 1643 f"Received {ragged_rank}.") 1644 if ragged_rank <= 0: 1645 raise ValueError(f"Argument `ragged_rank` must be greater than 0. " 1646 f"Received {ragged_rank}.") 1647 1648 with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): 1649 tensor = ops.convert_to_tensor(tensor, name="tensor") 1650 if tensor.shape.rank is not None and tensor.shape.rank < 2: 1651 raise ValueError(f"The rank of a RaggedTensor must be greater than 1, " 1652 f"i.e., a list of scalars won't have ragged " 1653 f"dimensions. Received argument `tensor` with rank " 1654 f"{tensor.shape.rank}.") 1655 tensor.shape.with_rank_at_least(ragged_rank + 1) 1656 input_shape = array_ops.shape(tensor, out_type=row_splits_dtype) 1657 ncols = input_shape[1] 1658 1659 # Handle nested row lengths. 1660 if (lengths is not None and isinstance(lengths, (list, tuple)) and 1661 len(lengths) and not isinstance(lengths[0], (int, float))): 1662 if ragged_rank not in (1, len(lengths)): 1663 # Note: we accept `ragged_rank=1` here because it's the default value; 1664 # i.e., if the user passes in a tuple of lengths, but doesn't specify 1665 # ragged_rank, then we should use that tuple to determine ragged_rank. 1666 # We only want to complain if they pass in an explicit ragged_rank 1667 # that doesn't match len(lengths). 1668 raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, " 1669 f"argument `ragged_rank` must be " 1670 f"len(lengths): {len(lengths)}. 
Received " 1671 f"ragged_rank: {ragged_rank}.") 1672 # Rather than reconstructing the tensor mask directly, we can 1673 # recreate it as a boolean RaggedTensor, then densify that and use 1674 # that as the mask to clear out the unused data in the passed tensor. 1675 tensor.shape.with_rank_at_least(len(lengths) + 1) 1676 num_tokens = math_ops.reduce_sum(lengths[-1]) 1677 ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool) 1678 ragged_mask = cls.from_nested_row_lengths( 1679 ones_mask, lengths, validate=False) 1680 dense_ragged_mask = ragged_mask.to_tensor(default_value=False) 1681 masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask) 1682 return cls.from_nested_row_lengths(masked_data, lengths, validate=False) 1683 1684 # Handle ragged_rank>1 via recursion: 1685 # If the output should have multiple ragged dimensions, then first 1686 # flatten the tensor to eliminate all but the last ragged dimension, 1687 # and recursively convert that flattened tensor. Then add on the splits 1688 # for the dimensions that we flattened out. 1689 if ragged_rank > 1: 1690 if tensor.shape.is_fully_defined(): 1691 input_shape = tensor.shape.as_list() 1692 # The total number of elements in each dimension. E.g., if 1693 # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. 
          dim_size = np.cumprod(input_shape)
          new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
        else:
          dim_size = math_ops.cumprod(input_shape)
          new_shape = array_ops.concat(
              [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0)
        flattened = array_ops.reshape(tensor, new_shape)
        result = cls.from_tensor(
            flattened, lengths, padding, row_splits_dtype=row_splits_dtype)

        # Add the flattened-out uniform dimensions back, innermost first.
        for axis in range(ragged_rank - 1, 0, -1):
          dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value
          if dim_len is None:
            dim_len = input_shape[axis]
          else:
            dim_len = constant_op.constant(dim_len, row_splits_dtype)
          result = RaggedTensor.from_uniform_row_length(
              values=result,
              uniform_row_length=dim_len,
              nrows=dim_size[axis - 1],
              validate=False)
        return result

      # If padding was specified, then use it to find row lengths.
      if padding is not None:
        padding = ops.convert_to_tensor(
            padding, name="padding", dtype=tensor.dtype)
        padding.shape.assert_is_compatible_with(tensor.shape[2:])

        # Find places where the padding is equal to the tensor.  (This will
        # broadcast `padding` across the outermost 2 dimensions of `tensor`,
        # so `has_default_value.shape = tensor.shape`.)
        has_default_value = math_ops.equal(padding, tensor)

        # If the padding isn't a scalar, then require that all values in the
        # padding match each item in the tensor.  After this block of code,
        # `has_default.shape = tensor.shape[:2]`.  (Unfortunately, we can't just
        # use reduce_all for both cases, because when you pass an empty `axis`
        # list to reduce_all, it reduces all axes; but we want it to reduce no
        # axes -- i.e., to be a no-op.)
1734 tensor_rank = array_ops.rank(tensor) 1735 reduce_axis = math_ops.range(2, tensor_rank) 1736 has_default = control_flow_ops.cond( 1737 tensor_rank > 2, 1738 lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis), 1739 lambda: has_default_value) 1740 has_default.set_shape(tensor_shape.TensorShape([None, None])) 1741 has_default.set_shape(tensor.shape[:2]) 1742 1743 # Use has_default to find the length of each row: for each 1744 # non-default item in a row, calculate the length that the row needs to 1745 # have to include that item; and then take the max of those values 1746 # (across each row). 1747 has_nondefault = math_ops.logical_not(has_default) 1748 has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype) 1749 length_for_nondefault_value = ( 1750 has_nondefault * 1751 array_ops.expand_dims(math_ops.range(1, ncols + 1), 0)) 1752 lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1) 1753 1754 if lengths is not None: 1755 # If we have lengths (either directly supplied, or computed from 1756 # paddings), then use those to construct splits; and then use masking 1757 # to get the corresponding values. 1758 lengths = ragged_util.convert_to_int_tensor(lengths, "lengths", 1759 row_splits_dtype) 1760 lengths.shape.assert_has_rank(1) 1761 lengths = math_ops.minimum(lengths, ncols) 1762 lengths = math_ops.maximum(lengths, 0) 1763 limits = math_ops.cumsum(lengths) 1764 splits = array_ops.concat( 1765 [array_ops.zeros([1], row_splits_dtype), limits], axis=0) 1766 mask = array_ops.sequence_mask(lengths, maxlen=ncols) 1767 values = array_ops.boolean_mask(tensor, mask) 1768 return cls.from_row_splits(values, splits, validate=False) 1769 1770 # If neither padding nor lengths were specified, then create a splits 1771 # vector that contains no default values, and reshape the input tensor 1772 # to form the values for the RaggedTensor. 
1773 values_shape = array_ops.concat( 1774 [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0) 1775 values = array_ops.reshape(tensor, values_shape) 1776 const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value 1777 const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value 1778 if const_nrows is not None: 1779 nrows = constant_op.constant(const_nrows, row_splits_dtype) 1780 else: 1781 nrows = input_shape[0] 1782 if const_ncols is not None: 1783 ncols = constant_op.constant(const_ncols, row_splits_dtype) 1784 else: 1785 ncols = input_shape[1] 1786 return RaggedTensor.from_uniform_row_length( 1787 values=values, uniform_row_length=ncols, nrows=nrows, validate=False) 1788 1789 def to_tensor(self, default_value=None, name=None, shape=None): 1790 """Converts this `RaggedTensor` into a `tf.Tensor`. 1791 1792 If `shape` is specified, then the result is padded and/or truncated to 1793 the specified shape. 1794 1795 Examples: 1796 1797 >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]]) 1798 >>> print(rt.to_tensor()) 1799 tf.Tensor( 1800 [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32) 1801 >>> print(rt.to_tensor(shape=[5, 2])) 1802 tf.Tensor( 1803 [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32) 1804 1805 Args: 1806 default_value: Value to set for indices not specified in `self`. Defaults 1807 to zero. `default_value` must be broadcastable to 1808 `self.shape[self.ragged_rank + 1:]`. 1809 name: A name prefix for the returned tensors (optional). 1810 shape: The shape of the resulting dense tensor. In particular, 1811 `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or 1812 `self.bounding_shape(i)` (otherwise).`shape.rank` must be `None` or 1813 equal to `self.rank`. 1814 1815 Returns: 1816 A `Tensor` with shape `ragged.bounding_shape(self)` and the 1817 values specified by the non-empty values in `self`. Empty values are 1818 assigned `default_value`. 
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # If `shape` is a list/tuple that mixes Python ints with Tensors, pack
      # it into a single shape Tensor so the op below can consume it.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merge self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [
              s1 if s1 is not None else s2
              for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list())
          ]
        tensor.set_shape(output_shape)

      return tensor

  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`.  `st_input` must be ragged-right.  If
    it is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert.  Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError(f"Argument `st_input` must be of type SparseTensor, but "
                      f"is of type {type(st_input).__name__}.")
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Determine the static rank from whichever component has it recorded:
      # dense_shape's length, or the width of the indices matrix.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2.")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values.  (Ignore sparse
        # column indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)

  def to_sparse(self, name=None):
    """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.

    Example:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
    >>> print(rt.to_sparse())
    SparseTensor(indices=tf.Tensor(
        [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
        shape=(6, 2), dtype=int64),
        values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
        dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))

    Args:
      name: A name prefix for the returned tensors (optional).
1942 1943 Returns: 1944 A SparseTensor with the same values as `self`. 1945 """ 1946 with ops.name_scope(name, "RaggedToSparse", [self]): 1947 result = gen_ragged_conversion_ops.ragged_tensor_to_sparse( 1948 self.nested_row_splits, self.flat_values, name=name) 1949 return sparse_tensor.SparseTensor(result.sparse_indices, 1950 result.sparse_values, 1951 result.sparse_dense_shape) 1952 1953 @classmethod 1954 def _from_variant(cls, 1955 variant, 1956 dtype, 1957 output_ragged_rank, 1958 input_ragged_rank=None, 1959 row_splits_dtype=dtypes.int64, 1960 name=None): 1961 """Converts a `variant` Tensor into a `RaggedTensor`. 1962 1963 The input `variant` could be a scalar, meaning it encodes a single 1964 `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could 1965 have an arbitrary rank, in which case each element is decoded into a 1966 `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then 1967 stacked according to the input shape to output a single `RaggedTensor` 1968 with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not 1969 provided, it is inferred dynamically as `output_ragged_rank` - 1970 `rank(variant)`. If `input_ragged_rank` is provided, the following must be 1971 true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`. 1972 1973 Example: 1974 1975 >>> rt = tf.ragged.constant([[0], [1, 2]]) 1976 >>> et = rt._to_variant() 1977 >>> stacked_et = tf.stack([et, et]) 1978 >>> tf.RaggedTensor._from_variant( # scalar input. 1979 ... et, dtype=tf.int32, output_ragged_rank=1).to_list() 1980 [[0], [1, 2]] 1981 >>> tf.RaggedTensor._from_variant( # batched input. 1982 ... stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list() 1983 [[[0], [1, 2]], [[0], [1, 2]]] 1984 1985 Args: 1986 variant: A `variant` Tensor representing an encoded (possibly 1987 nested-batched) `RaggedTensor`. 1988 dtype: The dtype of the encoded `RaggedTensor`. 1989 output_ragged_rank: The expected ragged rank of the output `RaggedTensor`. 
1990 input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is 1991 optional and inferred dynamically if not provided. 1992 row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One 1993 of `tf.int32` or `tf.int64`. 1994 name: A name prefix for the returned tensors (optional). 1995 1996 Returns: 1997 A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`. 1998 1999 Raises: 2000 ValueError: If the input rank is known, `input_ragged_rank` is provided 2001 and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does 2002 not hold. 2003 """ 2004 variant = ops.convert_to_tensor( 2005 variant, name="variant", dtype=dtypes.variant) 2006 if (variant.shape.ndims is not None and input_ragged_rank is not None and 2007 output_ragged_rank != input_ragged_rank + variant.shape.ndims): 2008 raise ValueError( 2009 f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal " 2010 f"to `input_ragged_rank` + `variant.shape.ndims` " 2011 f"({input_ragged_rank} + {variant.shape.ndims}).") 2012 input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank 2013 with ops.name_scope( 2014 name, "RaggedFromVariant", 2015 [variant, dtype, input_ragged_rank, output_ragged_rank]): 2016 result = gen_ragged_conversion_ops.ragged_tensor_from_variant( 2017 variant, input_ragged_rank, max(output_ragged_rank, 0), dtype, 2018 row_splits_dtype, name) 2019 return cls.from_nested_row_splits( 2020 result.output_dense_values, 2021 result.output_nested_splits, 2022 validate=False) 2023 2024 def _to_variant(self, batched_input=False, name=None): 2025 """Converts this `RaggedTensor` into a `variant` Tensor. 2026 2027 If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the 2028 zero-th dimension, each component `RaggedTensor` is encoded into a scalar 2029 `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor. 
2030 If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and 2031 a scalar `variant` Tensor is returned. 2032 2033 Example: 2034 >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]]) 2035 >>> rt._to_variant().shape.as_list() 2036 [] 2037 >>> rt._to_variant(batched_input=True).shape.as_list() 2038 [3] 2039 2040 Args: 2041 batched_input: If `True`, the `RaggedTensor` is unbatched and converted to 2042 a `variant` vector. Set to `False` by default. 2043 name: A name prefix for the returned tensors (optional). 2044 2045 Returns: 2046 A `variant` Tensor that encodes this `RaggedTensor`. 2047 """ 2048 with ops.name_scope(name, "RaggedToVariant", [self, batched_input]): 2049 return gen_ragged_conversion_ops.ragged_tensor_to_variant( 2050 self.nested_row_splits, self.flat_values, batched_input, name) 2051 2052 #============================================================================= 2053 # String Encoding 2054 #============================================================================= 2055 def __repr__(self): 2056 if self._is_eager(): 2057 # The np.array2string in _formatter provides a separator argument, but 2058 # doesn't handle recursive calls correctly. The np.printoptions handles 2059 # recursive calls correctly, but doesn't provide a separator argument. 2060 # Combines them together to print elements separated by comma, while 2061 # avoiding the redundant array prefixes and dtypes. 
For example, 2062 # the value of tf.ragged.constant([[1, 2], [3, 4]]) will look like 2063 # 2064 # [[1, 2], 2065 # [3, 4]] 2066 with np.printoptions(formatter={"all": _formatter}): 2067 value_text = _formatter(self.numpy()) 2068 return f"<tf.RaggedTensor {value_text}>" 2069 else: 2070 return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values, 2071 self.row_splits) 2072 2073 #============================================================================= 2074 # Eager Execution Mode 2075 #============================================================================= 2076 2077 def numpy(self): 2078 """Returns a numpy `array` with the values for this `RaggedTensor`. 2079 2080 Requires that this `RaggedTensor` was constructed in eager execution mode. 2081 2082 Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and 2083 `rank=1`, where each element is a single row. 2084 2085 #### Examples 2086 2087 In the following example, the value returned by `RaggedTensor.numpy()` 2088 contains three numpy `array` objects: one for each row (with `rank=1` and 2089 `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`): 2090 2091 >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy() 2092 array([array([1, 2, 3]), array([4, 5])], dtype=object) 2093 2094 Uniform dimensions are encoded using multidimensional numpy `array`s. In 2095 the following example, the value returned by `RaggedTensor.numpy()` contains 2096 a single numpy `array` object, with `rank=2` and `dtype=int64`: 2097 2098 >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy() 2099 array([[1, 2, 3], [4, 5, 6]]) 2100 2101 Returns: 2102 A numpy `array`. 
2103 """ 2104 if not self._is_eager(): 2105 raise ValueError("RaggedTensor.numpy() is only supported in eager mode.") 2106 values = self.values.numpy() 2107 splits = self.row_splits.numpy() 2108 rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)] 2109 if not rows: 2110 return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype) 2111 # Note: if `rows` have ragged lengths, then they will be stored in a 2112 # np.ndarray with dtype=object and rank=1. If they have uniform lengths, 2113 # they will be combined into a single np.ndarray with dtype=row.dtype and 2114 # rank=row.rank+1. 2115 # 2116 # Manually set dtype as numpy now complains when given ragged rows. 2117 has_variable_length_rows = any(len(row) != len(rows[0]) for row in rows) 2118 dtype = np.object_ if has_variable_length_rows else None 2119 return np.array(rows, dtype=dtype) 2120 2121 def to_list(self): 2122 """Returns a nested Python `list` with the values for this `RaggedTensor`. 2123 2124 Requires that `rt` was constructed in eager execution mode. 2125 2126 Returns: 2127 A nested Python `list`. 2128 """ 2129 if not isinstance(self.row_splits, ops.EagerTensor): 2130 raise ValueError("to_list can only be used in eager mode.") 2131 row_splits = self.row_splits.numpy().tolist() 2132 values = self.values 2133 2134 if isinstance(values, RaggedTensor): 2135 return [ 2136 values[row_splits[i]:row_splits[i + 1]].to_list() 2137 for i in range(len(row_splits) - 1) 2138 ] 2139 else: 2140 # Convert values to a Python list. 2141 if hasattr(values, "numpy"): 2142 values_as_list = values.numpy().tolist() 2143 elif hasattr(values, "to_list"): 2144 values_as_list = values.to_list() 2145 else: 2146 raise ValueError("values must be convertible to a list") 2147 2148 return [ 2149 values_as_list[row_splits[i]:row_splits[i + 1]] 2150 for i in range(len(row_splits) - 1) 2151 ] 2152 2153 def _eager_value(self): 2154 """Returns a RaggedTensorValue for self. 
Requires self._is_eager()=true.""" 2155 value = self.flat_values.numpy() 2156 for row_splits in reversed(self.nested_row_splits): 2157 value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy()) 2158 return value 2159 2160 def _is_eager(self): 2161 """Returns True if values & row_splits Tensors are all `EagerTensor`s.""" 2162 rt = self 2163 while isinstance(rt, RaggedTensor): 2164 if not isinstance(rt.row_splits, ops.EagerTensor): 2165 return False 2166 rt = rt.values 2167 return isinstance(rt, ops.EagerTensor) 2168 2169 #============================================================================= 2170 # Operators 2171 #============================================================================= 2172 # To avoid circular dependencies, we define stub methods for operators here, 2173 # and then override them when the ragged_operators module is imported. 2174 2175 def _overloaded_operator(name): # pylint: disable=no-self-argument 2176 2177 def stub(*args, **kwargs): 2178 del args, kwargs 2179 raise ValueError( 2180 f"You must import 'tensorflow.python.ops.ragged.ragged_ops' " 2181 f"before using RaggedTensor.{name}.") 2182 2183 return stub 2184 2185 __getitem__ = _overloaded_operator("__getitem__") 2186 __ge__ = _overloaded_operator("__ge__") 2187 __gt__ = _overloaded_operator("__gt__") 2188 __le__ = _overloaded_operator("__le__") 2189 __lt__ = _overloaded_operator("__lt__") 2190 __and__ = _overloaded_operator("__and__") 2191 __rand__ = _overloaded_operator("__rand__") 2192 __invert__ = _overloaded_operator("__invert__") 2193 __ror__ = _overloaded_operator("__ror__") 2194 __or__ = _overloaded_operator("__or__") 2195 __xor__ = _overloaded_operator("__xor__") 2196 __rxor__ = _overloaded_operator("__rxor__") 2197 __abs__ = _overloaded_operator("__abs__") 2198 __add__ = _overloaded_operator("__add__") 2199 __radd__ = _overloaded_operator("__radd__") 2200 __div__ = _overloaded_operator("__div__") 2201 __rdiv__ = _overloaded_operator("__rdiv__") 2202 
__floordiv__ = _overloaded_operator("__floordiv__") 2203 __rfloordiv__ = _overloaded_operator("__rfloordiv__") 2204 __mod__ = _overloaded_operator("__mod__") 2205 __rmod__ = _overloaded_operator("__rmod__") 2206 __mul__ = _overloaded_operator("__mul__") 2207 __rmul__ = _overloaded_operator("__rmul__") 2208 __neg__ = _overloaded_operator("__neg__") 2209 __pow__ = _overloaded_operator("__pow__") 2210 __rpow__ = _overloaded_operator("__rpow__") 2211 __sub__ = _overloaded_operator("__sub__") 2212 __rsub__ = _overloaded_operator("__rsub__") 2213 __truediv__ = _overloaded_operator("__truediv__") 2214 __rtruediv__ = _overloaded_operator("__rtruediv__") 2215 del _overloaded_operator 2216 2217 #============================================================================= 2218 # Name Scope 2219 #============================================================================= 2220 2221 # This private function is used by ops.name_scope to ensure that all of the 2222 # input tensors for the scope belong to the same graph. Defining this means 2223 # that you may include `RaggedTensor` objects in the name_scope `values` 2224 # list. 
  def _as_graph_element(self):
    """Convert `self` to a graph element."""
    values = self.values
    while isinstance(values, RaggedTensor):
      values = values.values
    return values

  #=============================================================================
  # Composite Tensor
  #=============================================================================

  @property
  def _type_spec(self):
    """A `RaggedTensorSpec` describing this `RaggedTensor`."""
    return RaggedTensorSpec.from_value(self)

  def _shape_invariant_to_type_spec(self, shape):
    """Returns a `RaggedTensorSpec` for `shape` with this tensor's dtypes."""
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)

  def consumers(self):
    """Returns the consumers of this `RaggedTensor`'s components."""
    return self._consumers()

  __composite_gradient__ = (
      composite_tensor_gradient.WithValuesCompositeTensorGradient())


def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  return isinstance(value,
                    (RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
2268 """ 2269 return_dtype = kwargs.pop("return_dtype", False) 2270 if kwargs: 2271 raise ValueError(f"Unexpected keyword args {kwargs}.") 2272 2273 has_int32 = False 2274 has_int64 = False 2275 for tensor in tensors: 2276 if isinstance(tensor, RaggedTensor): 2277 if tensor.row_splits.dtype == dtypes.int32: 2278 has_int32 = True 2279 else: 2280 has_int64 = True 2281 2282 if has_int32 and has_int64: 2283 if not ragged_config.auto_cast_partition_dtype(): 2284 raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; " 2285 "use RaggedTensor.with_row_splits_dtype() to convert " 2286 "them to compatible dtypes.") 2287 dtype = dtypes.int64 2288 tensors = tuple( 2289 t.with_row_splits_dtype(dtypes.int64) if isinstance(t, RaggedTensor 2290 ) else t 2291 for t in tensors) 2292 2293 elif has_int32: 2294 dtype = dtypes.int32 2295 else: 2296 dtype = dtypes.int64 2297 2298 if return_dtype: 2299 return (dtype, tensors) 2300 else: 2301 return tensors 2302 2303 2304#=============================================================================== 2305# RaggedTensorSpec 2306#=============================================================================== 2307@tf_export("RaggedTensorSpec") 2308@type_spec.register("tf.RaggedTensorSpec") 2309class RaggedTensorSpec(type_spec.BatchableTypeSpec): 2310 """Type specification for a `tf.RaggedTensor`.""" 2311 2312 __slots__ = [ 2313 "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype", 2314 "_flat_values_spec" 2315 ] 2316 2317 @property 2318 def dtype(self): 2319 """The `tf.dtypes.DType` specified by this type for the RaggedTensor. 2320 2321 Examples: 2322 2323 >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string) 2324 >>> tf.type_spec_from_value(rt).dtype 2325 tf.string 2326 2327 Returns: 2328 A `tf.dtypes.DType` of the values in the RaggedTensor. 2329 """ 2330 return self._dtype 2331 2332 @property 2333 def shape(self): 2334 """The statically known shape of the RaggedTensor. 
2335 2336 Examples: 2337 2338 >>> rt = tf.ragged.constant([[0], [1, 2]]) 2339 >>> tf.type_spec_from_value(rt).shape 2340 TensorShape([2, None]) 2341 2342 >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1) 2343 >>> tf.type_spec_from_value(rt).shape 2344 TensorShape([2, None, 2]) 2345 2346 Returns: 2347 A `tf.TensorShape` containing the statically known shape of the 2348 RaggedTensor. Ragged dimensions have a size of `None`. 2349 """ 2350 return self._shape 2351 2352 @property 2353 def ragged_rank(self): 2354 """The number of times the RaggedTensor's flat_values is partitioned. 2355 2356 Defaults to `shape.ndims - 1`. 2357 2358 Examples: 2359 2360 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) 2361 >>> tf.type_spec_from_value(values).ragged_rank 2362 1 2363 2364 >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2) 2365 >>> tf.type_spec_from_value(rt1).ragged_rank 2366 2 2367 2368 Returns: 2369 A Python `int` indicating the number of times the underlying `flat_values` 2370 Tensor has been partitioned to add a new dimension. 2371 I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`. 2372 """ 2373 return self._ragged_rank 2374 2375 @property 2376 def row_splits_dtype(self): 2377 """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`. 2378 2379 Examples: 2380 2381 >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64) 2382 >>> tf.type_spec_from_value(rt).row_splits_dtype 2383 tf.int64 2384 2385 Returns: 2386 A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One 2387 of `tf.int32` or `tf.int64`. 2388 """ 2389 return self._row_splits_dtype 2390 2391 @property 2392 def flat_values_spec(self): 2393 """The `TypeSpec` of the flat_values of RaggedTensor. 2394 2395 Returns: 2396 - The TypeSpec of flat_values. 2397 - None when the flat_values is a Tensor. 
2398 """ 2399 return self._flat_values_spec 2400 2401 @property 2402 def value_type(self): 2403 return RaggedTensor if self._ragged_rank > 0 else ops.Tensor 2404 2405 def __init__(self, 2406 shape=None, 2407 dtype=dtypes.float32, 2408 ragged_rank=None, 2409 row_splits_dtype=dtypes.int64, 2410 flat_values_spec=None): 2411 """Constructs a type specification for a `tf.RaggedTensor`. 2412 2413 Args: 2414 shape: The shape of the RaggedTensor, or `None` to allow any shape. If a 2415 shape is specified, then all ragged dimensions must have size `None`. 2416 dtype: `tf.DType` of values in the RaggedTensor. 2417 ragged_rank: Python integer, the number of times the RaggedTensor's 2418 flat_values is partitioned. Defaults to `shape.ndims - 1`. 2419 row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One 2420 of `tf.int32` or `tf.int64`. 2421 flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be 2422 provided when the flat_values is a CompositeTensor rather then Tensor. 2423 If both `dtype` and `flat_values_spec` and are provided, `dtype` must 2424 be the same as `flat_values_spec.dtype`. 
(experimental) 2425 """ 2426 self._shape = tensor_shape.as_shape(shape) 2427 self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 2428 if flat_values_spec is not None: 2429 if dtype is None: 2430 dtype = flat_values_spec.dtype 2431 elif dtype != flat_values_spec.dtype: 2432 raise ValueError("dtype must be the same as flat_values_spec.dtype") 2433 elif dtype is None: 2434 raise ValueError( 2435 "At least one of dtype or flat_values_spec must be provided") 2436 self._dtype = dtypes.as_dtype(dtype) 2437 self._flat_values_spec = flat_values_spec 2438 2439 rank = self._shape.ndims 2440 if ragged_rank is None: 2441 if rank is None: 2442 raise ValueError("Must specify ragged_rank or " 2443 "a shape with a known rank.") 2444 ragged_rank = rank - 1 2445 self._ragged_rank = ragged_rank 2446 if not isinstance(self._ragged_rank, int): 2447 raise TypeError(f"Argument `ragged_rank` must be an int. " 2448 f"Received {ragged_rank}.") 2449 2450 if rank is not None: 2451 if ragged_rank >= rank: 2452 raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less " 2453 f"than rank ({rank}).") 2454 2455 def is_compatible_with(self, spec_or_value): 2456 # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values. 
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  def _serialize(self):
    """Returns a tuple of the values used to construct this spec."""
    # `flat_values_spec` is only included when customized, to keep the
    # serialized form backward-compatible with the 4-tuple encoding.
    if self._flat_values_spec is None:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype)
    else:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype, self._flat_values_spec)

  @property
  def _component_specs(self):
    """Specs for the components: flat_values, then one spec per splits level."""
    if self._ragged_rank <= 0:
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits has a statically-known length (nrows + 1) when
    # the outer dimension is known; the inner ones do not.
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs

  def _to_components(self, value):
    """Decomposes `value` into its flat_values and nested row_splits."""
    if is_ragged(value):
      return [value.flat_values] + list(value.nested_row_splits)
    else:
      return [value]

  def _from_components(self, tensor_list):
    """Rebuilds a RaggedTensor (or value) from `_to_components` output."""
    result = tensor_list[0]
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
not tf2.enabled()): 2507 for row_splits in reversed(tensor_list[1:]): 2508 result = ragged_tensor_value.RaggedTensorValue(result, row_splits) 2509 else: 2510 if isinstance(tensor_list[0], np.ndarray): 2511 tensor_list = [ops.convert_to_tensor(t) for t in tensor_list] 2512 result = tensor_list[0] 2513 for row_splits in reversed(tensor_list[1:]): 2514 result = RaggedTensor( 2515 result, 2516 RowPartition.from_row_splits(row_splits, validate=False), 2517 internal=True) 2518 if self._shape.ndims is not None: 2519 if isinstance(result, RaggedTensor): 2520 result._set_shape(self._shape) # pylint: disable=protected-access 2521 # TODO(xjun): MaskedTensor doesn't implement set_shape. 2522 if self.flat_values_spec is not None and hasattr(result.flat_values, 2523 "set_shape"): 2524 result.flat_values.set_shape(self.flat_values_spec.shape) 2525 elif isinstance(result, ops.Tensor): 2526 result.set_shape(self._shape) 2527 return result 2528 2529 # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops 2530 # to (un)box the component tensors in a way that allows for batching & 2531 # unbatching. 2532 @property 2533 def _flat_tensor_specs(self): 2534 # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is 2535 # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of 2536 # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches, 2537 # etc.), so the flat shape must be unknown. 2538 return [tensor_spec.TensorSpec(None, dtypes.variant)] 2539 2540 def _to_tensor_list(self, value): 2541 # TODO(edloper): Update gen_ragged_conversion_ops that convert to and 2542 # from variant to include all of the row-partitioning tensors. 
2543 if self._flat_values_spec is not None: 2544 raise ValueError("Customized value_type is not supported.") 2545 if isinstance(value, RaggedTensor): 2546 if value.ragged_rank != self._ragged_rank: 2547 raise ValueError( 2548 f"Ragged rank of value {value.ragged_rank} does not match " 2549 f"ragged rank of type {self._ragged_rank}.") 2550 # pylint: disable=protected-access 2551 return [value._to_variant(batched_input=False)] 2552 else: 2553 if self._ragged_rank > 0: 2554 raise ValueError( 2555 f"Expected a RaggedTensor if ragged rank={self._ragged_rank}" 2556 f" but got {type(value).__name__}." 2557 ) 2558 return [ 2559 gen_ragged_conversion_ops.ragged_tensor_to_variant( 2560 (), value, batched_input=False) 2561 ] 2562 2563 def _to_batched_tensor_list(self, value): 2564 if self._flat_values_spec is not None: 2565 raise ValueError("Customized value_type is not supported.") 2566 if isinstance(value, RaggedTensor): 2567 if value.ragged_rank != self._ragged_rank: 2568 raise ValueError( 2569 f"Ragged rank of value {value.ragged_rank} does not match " 2570 f"ragged rank of type {self._ragged_rank}.") 2571 # pylint: disable=protected-access 2572 return [value._to_variant(batched_input=True)] 2573 else: 2574 if self._ragged_rank > 0: 2575 raise ValueError( 2576 f"Expected a RaggedTensor if ragged rank={self._ragged_rank}" 2577 f" but got {type(value).__name__}." 
2578 ) 2579 return [ 2580 gen_ragged_conversion_ops.ragged_tensor_to_variant( 2581 rt_nested_splits=(), rt_dense_values=value, batched_input=True) 2582 ] 2583 2584 def _from_compatible_tensor_list(self, tensor_list): 2585 if self._flat_values_spec is not None: 2586 raise ValueError("Customized value_type is not supported.") 2587 result = RaggedTensor._from_variant( # pylint: disable=protected-access 2588 tensor_list[0], 2589 dtype=self._dtype, 2590 row_splits_dtype=self._row_splits_dtype, 2591 output_ragged_rank=self._ragged_rank) 2592 if self._shape.ndims is not None: 2593 if isinstance(result, RaggedTensor): 2594 result._set_shape(self._shape) # pylint: disable=protected-access 2595 # TODO(xjun): MaskedTensor doesn't implement set_shape. 2596 if self.flat_values_spec is not None and hasattr(self.flat_values, 2597 "set_shape"): 2598 result.flat_values.set_shape(self.flat_values_spec.shape) 2599 else: 2600 result.set_shape(self._shape) 2601 return result 2602 2603 def _batch(self, batch_size): 2604 if self._flat_values_spec is not None: 2605 raise ValueError("Customized value_type is not supported.") 2606 return RaggedTensorSpec( 2607 tensor_shape.TensorShape([batch_size]).concatenate(self._shape), 2608 self._dtype, self._ragged_rank + 1, self._row_splits_dtype) 2609 2610 def _unbatch(self): 2611 if self._flat_values_spec is not None: 2612 raise ValueError("Customized value_type is not supported.") 2613 # Note: Negative ragged_rank is allowed here because the dataset could be 2614 # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is 2615 # consistent. 
Errors are handled in 2616 # RaggedTensorSpec._from_compatible_tensor_list() 2617 return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1, 2618 self._row_splits_dtype) 2619 2620 def _to_legacy_output_types(self): 2621 return self._dtype 2622 2623 def _to_legacy_output_shapes(self): 2624 return self._shape 2625 2626 def _to_legacy_output_classes(self): 2627 return self 2628 2629 @classmethod 2630 def from_value(cls, value): 2631 if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or 2632 isinstance(value.flat_values, ops.Tensor)): 2633 return cls( 2634 shape=value.shape, 2635 dtype=value.values.dtype, 2636 ragged_rank=value.ragged_rank, 2637 row_splits_dtype=value.row_splits.dtype) 2638 else: 2639 flat_values_spec = type_spec.type_spec_from_value(value.flat_values) 2640 # Relax shape[0] to None, as it is connected to dynamic ragged shapes. 2641 flat_values_spec = flat_values_spec._unbatch()._batch(None) # pylint: disable=protected-access 2642 return cls( 2643 shape=value.shape, 2644 dtype=value.values.dtype, 2645 ragged_rank=value.ragged_rank, 2646 row_splits_dtype=value.row_splits.dtype, 2647 flat_values_spec=flat_values_spec) 2648 2649 2650type_spec.register_type_spec_from_value_converter( 2651 ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value) 2652 2653 2654#=============================================================================== 2655# Convert value -> tensor 2656#=============================================================================== 2657def convert_to_tensor_or_ragged_tensor(value, 2658 dtype=None, 2659 preferred_dtype=None, 2660 name=None): 2661 """Converts value to a `RaggedTensor` or `Tensor`. 2662 2663 * If `value` is a `RaggedTensor`, then return it as-is. 2664 * If `value` is a `RaggedTensorValue`, return a corresponding constant 2665 `RaggedTensor`. 2666 * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`. 
2667 2668 Args: 2669 value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has 2670 a registered `Tensor` conversion function. 2671 dtype: Optional element type for the returned tensor. If missing the type 2672 is inferred from the type of `value`. 2673 preferred_dtype: Optional element type for the returned tensor, used when 2674 dtype is None. This argument has no effect if `value` is already a 2675 tensor, or when conversion is not possible. 2676 name: Optional name to use if a new `Tensor` is created. 2677 2678 Returns: 2679 A `Tensor` or `RaggedTensor`. 2680 """ 2681 if isinstance(value, RaggedTensor): 2682 if dtype and not dtype.is_compatible_with(value.dtype): 2683 raise ValueError(f"Tensor conversion requested dtype {dtype.name} for " 2684 f"RaggedTensor with dtype {value.dtype.name}: {value}.") 2685 return value 2686 elif isinstance(value, ragged_tensor_value.RaggedTensorValue): 2687 with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []): 2688 flat_values = ops.convert_to_tensor( 2689 value=value.flat_values, 2690 dtype=dtype, 2691 dtype_hint=preferred_dtype, 2692 name="flat_values") 2693 return RaggedTensor.from_nested_row_splits( 2694 flat_values, value.nested_row_splits, validate=False) 2695 else: 2696 return ops.convert_to_tensor_v2_with_dispatch( 2697 value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name) 2698 2699 2700def _convert_to_ragged_tensor_values(value): 2701 """Converts value to supported RaggedTensor value. 2702 2703 * If `value` is an object of supported value type, then return it as-is. 2704 * Otherwise convert it to Tensor or RaggedTensor. 2705 2706 Args: 2707 value: An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor 2708 value types, or an object whose type has a registered `Tensor` conversion 2709 function. 
2710 2711 Returns: 2712 An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor 2713 value types 2714 """ 2715 if _is_supported_ragged_values_type(value): 2716 return value 2717 else: 2718 return convert_to_tensor_or_ragged_tensor(value, name="values") 2719 2720 2721#=============================================================================== 2722# Register RaggedTensor for use with session.run. 2723#=============================================================================== 2724def _ragged_tensor_value_from_components(components): 2725 components = list(components) 2726 value = components.pop() 2727 while components: 2728 value = ragged_tensor_value.RaggedTensorValue(value, components.pop()) 2729 return value 2730 2731 2732def _ragged_tensor_session_fetch(rt): 2733 components = rt.nested_row_splits + (rt.flat_values,) 2734 return (components, _ragged_tensor_value_from_components) 2735 2736 2737def _ragged_tensor_session_feed(feed_key, feed_val): 2738 key_components = feed_key.nested_row_splits + (feed_key.flat_values,) 2739 val_components = feed_val.nested_row_splits + (feed_val.flat_values,) 2740 return zip(key_components, val_components) 2741 2742 2743def _ragged_tensor_session_feed_for_partial_run(feed_key): 2744 return feed_key.nested_row_splits + (feed_key.flat_values,) 2745 2746 2747session.register_session_run_conversion_functions( 2748 RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed, 2749 _ragged_tensor_session_feed_for_partial_run) 2750 2751 2752#=============================================================================== 2753# RaggedTensorType 2754#=============================================================================== 2755class RaggedTensorType: 2756 """Encoding of a static type for a `RaggedTensor`. 2757 2758 Use this type to express/declare that an output must have the type of 2759 `RaggedTensor`. 
2760 """ 2761 2762 def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64): 2763 """Initializes a RaggedTensorType object. 2764 2765 Args: 2766 dtype: data type of the `RaggedTensor`'s inner values. 2767 ragged_rank: ragged_rank of the declared `RaggedTensor`. 2768 row_splits_dtype: data type for the `RaggedTensor`'s row splits. 2769 One of: `tf.int32` or `tf.int64`. 2770 """ 2771 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 2772 self._dtype = dtype 2773 self._ragged_rank = ragged_rank 2774 self._row_splits_dtype = row_splits_dtype 2775 2776 dtype = property(lambda self: self._dtype) 2777 ragged_rank = property(lambda self: self._ragged_rank) 2778 row_splits_dtype = property(lambda self: self._row_splits_dtype) 2779 2780 def __repr__(self): 2781 return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank, 2782 self.row_splits_dtype) 2783 2784 2785#=============================================================================== 2786# Helper Functions 2787#=============================================================================== 2788def _assert_sparse_indices_are_ragged_right(indices): 2789 """Checks that the given SparseTensor.indices tensor is ragged-right. 2790 2791 Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right 2792 because the entry `[3, 1]` skips a cell. 2793 2794 Args: 2795 indices: The SparseTensor indices to check. 2796 2797 Returns: 2798 A list of control dependency op tensors. 2799 """ 2800 index_prefix = indices[:, :-1] 2801 index_suffix = indices[:, -1] 2802 2803 # Check whether each index is starting a new row in the innermost dimension 2804 # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]). 2805 # (Note: this skips the first index; we will check that separately below.) 
2806 index_prefix_changed = math_ops.reduce_any( 2807 math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1) 2808 2809 # Check two cases: 2810 # * For indices that start a new row: index_suffix[i] must be zero. 2811 # * For indices that continue a row: index_suffix[i] must be equal to 2812 # index_suffix[i-1]+1. 2813 index_ok = array_ops.where( 2814 index_prefix_changed, math_ops.equal(index_suffix[1:], 0), 2815 math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1)) 2816 2817 # Also check that the very first index didn't skip any cells. The first 2818 # index starts a new row (by definition), so its suffix should be zero. 2819 sparse_indices_are_ragged_right = math_ops.logical_and( 2820 math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)), 2821 math_ops.reduce_all(index_ok)) 2822 2823 message = [ 2824 "SparseTensor is not right-ragged", "SparseTensor.indices =", indices 2825 ] 2826 return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)] 2827 2828 2829@ops.RegisterGradient("RaggedTensorToSparse") 2830def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad, 2831 sparse_values_grad, 2832 unused_sparse_shape_grad): 2833 """Gradient for RaggedTensorToSparse.""" 2834 op_inputs_nested_row_splits = op.inputs[:-1] 2835 op_inputs_flat_values = op.inputs[-1] 2836 2837 # No gradient for the RaggedTensor's nested_row_splits. 2838 nested_row_splits_gradient = [None] * len(op_inputs_nested_row_splits) 2839 2840 # Gradient for the RaggedTensor's flat_values is formed by reshaping 2841 # the gradient for the SparseTensor's values. 
2842 flat_values_shape = array_ops.shape(op_inputs_flat_values) 2843 flat_values_gradient = array_ops.reshape(sparse_values_grad, 2844 flat_values_shape) 2845 2846 return nested_row_splits_gradient + [flat_values_gradient] 2847 2848 2849def _assert_monotonic_increasing(tensor, message=None): 2850 return check_ops.assert_non_negative( 2851 tensor[1:] - tensor[:-1], message=message) 2852 2853 2854def _assert_zero(tensor, message=None): 2855 return check_ops.assert_equal( 2856 tensor, constant_op.constant(0, dtype=tensor.dtype), message=message) 2857 2858 2859def _nrows(tensor, out_type=dtypes.int32): 2860 if isinstance(tensor, RaggedTensor): 2861 return tensor.nrows(out_type=out_type) 2862 else: 2863 return array_ops.shape(tensor, out_type=out_type)[0] 2864 2865 2866def merge_dims(value, outer_axis, inner_axis): 2867 """Merges value[outer_axis...inner_axis] into a single dimension. 2868 2869 See `RaggedTensor.merge_dims()` for more details. This helper differs from 2870 `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor. 2871 2872 Args: 2873 value: A `RaggedTensor` or `Tensor` 2874 outer_axis: `int` 2875 inner_axis: `int` 2876 2877 Returns: 2878 A flattened `RaggedTensor` or `Tensor`. 2879 """ 2880 if outer_axis == inner_axis: 2881 return value 2882 2883 # Flatten outer dimensions of a RaggedTensor by just taking its values. 2884 while outer_axis == 0 and isinstance(value, RaggedTensor): 2885 value = value.values 2886 inner_axis -= 1 2887 if inner_axis == 0: 2888 return value 2889 2890 # Flatten non-Ragged tensors using tf.reshape(). 
2891 if not isinstance(value, RaggedTensor): 2892 if value.shape.is_fully_defined(): 2893 old_shape = value.shape.as_list() 2894 new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:] 2895 else: 2896 old_shape = array_ops.shape(value) 2897 new_shape = array_ops.concat( 2898 [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0) 2899 return array_ops.reshape(value, new_shape) 2900 2901 # Handle outer_axis>1 via recursion. 2902 if outer_axis > 1: 2903 return value.with_values( 2904 merge_dims(value.values, outer_axis - 1, inner_axis - 1)) 2905 2906 # At this point, we know outer_axis == 1, and value is a RaggedTensor. 2907 # So we need to flatten the values and build a corresponding splits tensor. 2908 new_values = value.values 2909 new_splits = value.row_splits 2910 for axis in range(outer_axis, inner_axis): 2911 if isinstance(new_values, RaggedTensor): 2912 # Flatten a single ragged dimension. 2913 new_splits = array_ops.gather(new_values.row_splits, new_splits) 2914 new_values = new_values.values 2915 else: 2916 # Flatten all remaining dense dimensions. 
2917 shape_split = inner_axis - axis + 1 2918 if new_values.shape.is_fully_defined(): 2919 old_shape = new_values.shape.as_list() 2920 new_shape = [-1] + old_shape[shape_split:] 2921 flat_size = _prod(old_shape[1:shape_split]) 2922 else: 2923 old_shape = array_ops.shape(new_values) 2924 new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0) 2925 flat_size = math_ops.cast( 2926 math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype) 2927 new_values = array_ops.reshape(new_values, new_shape) 2928 new_splits = new_splits * flat_size 2929 break 2930 return RaggedTensor.from_row_splits(new_values, new_splits) 2931 2932 2933def _prod(lst): 2934 """Returns the product of the numbers in a list.""" 2935 return functools.reduce(operator.mul, lst, 1) 2936 2937 2938def _get_row_partition_type_tensor_pairs_tail(partition): 2939 """Gets a row partition type tensor pair for the tail. 2940 2941 If value_rowid is defined, then it is used. Otherwise, row_splits 2942 are used. 2943 2944 Args: 2945 partition: a RowPartition. 2946 2947 Returns: 2948 A list of (row_partition_type, row_partition_tensor) pairs. 2949 """ 2950 if partition._has_precomputed_value_rowids(): # pylint: disable=protected-access 2951 return ("VALUE_ROWIDS", partition.value_rowids()) 2952 else: 2953 return ("ROW_SPLITS", partition.row_splits()) 2954 2955 2956def _get_row_partition_type_tensor_pairs(rt_input): 2957 """Gets a list of the row partitions for rt_input. 2958 2959 If value_rowids are defined, then they are used. Otherwise, row_splits 2960 are used. If the outermost level has value_rowids defind, then nrows is 2961 also added. 2962 2963 Args: 2964 rt_input: a ragged tensor. 2965 2966 Returns: 2967 A list of (row_partition_type, row_partition_tensor) pairs. 
2968 """ 2969 partitions = rt_input._nested_row_partitions # pylint: disable=protected-access 2970 tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]] 2971 2972 if partitions[0]._value_rowids is not None: # pylint: disable=protected-access 2973 return [("FIRST_DIM_SIZE", partitions[0].nrows()), 2974 ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail 2975 else: 2976 return [("ROW_SPLITS", partitions[0].row_splits())] + tail 2977 2978 2979def _shape_as_tensor(shape, dtype): 2980 """Takes shape and coerces it to a shape as a tensor. 2981 2982 If the object is already a tensor, simply passes it on (result is guaranteed 2983 to be int64 or int32, but not necessarily dtype). 2984 If not, creates a tensor of type dtype. 2985 2986 Result is either a scalar equal to -1 if the shape is unknown_rank. 2987 Otherwise, it is a vector, where unknown dimensions are represented with a 2988 value of -1. 2989 2990 In C++, see TensorShapeFromTensor for parsing shapes in kernels, and 2991 InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for 2992 use in the shape inference function. 2993 2994 Args: 2995 shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]], 2996 Tuple[Optional[Int]]. 2997 dtype: tf.int64 or tf.int32 2998 2999 Returns: 3000 a scalar or vector tensor of dtype tf.int32 or tf.int64. 3001 """ 3002 if dtype != dtypes.int64 and dtype != dtypes.int32: 3003 raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.") 3004 3005 if isinstance(shape, ops.Tensor): 3006 if shape.dtype != dtypes.int64 and shape.dtype != dtypes.int32: 3007 return math_ops.cast(shape, dtype) 3008 return shape 3009 shape = tensor_shape.as_shape(shape) 3010 if not shape: 3011 # Imply rank is unknown using a -1 scalar. 3012 return constant_op.constant(-1, dtype=dtype) 3013 shape = [(-1 if x is None else x) for x in shape.as_list()] 3014 # At this point, shape is List[Int]. 
3015 return constant_op.constant(shape, dtype=dtype) 3016 3017 3018def _nvals_uniform_row_length(values, uniform_row_length): 3019 """Get the number of values for uniform row length constructor.""" 3020 const_nvals = tensor_shape.dimension_at_index(values.shape, 0).value 3021 if const_nvals is not None: 3022 nvals = constant_op.constant(const_nvals, uniform_row_length.dtype) 3023 elif isinstance(values, RaggedTensor): 3024 nvals = values.nrows(out_type=uniform_row_length.dtype) 3025 else: 3026 nvals = array_ops.shape(values, out_type=uniform_row_length.dtype)[0] 3027 return nvals 3028 3029 3030def _get_optional_partition_dtype(values): 3031 """Returns the partition dtype, or None if None exists.""" 3032 if isinstance(values, RaggedTensor): 3033 # pylint: disable=protected-access 3034 return values._row_partition.dtype 3035 return None 3036 3037 3038_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor) 3039 3040 3041# TODO(edloper): Consider whether we should change the registry to be on 3042# TypeSpecs rather than ValueTypes. 3043def _add_supported_value_type(cls): 3044 """Register the `cls` as supported value type of RaggedTenosr. 3045 3046 The cls must be a subclass of CompositeTensor, and must support: 3047 - Spec: 3048 The Spec must be a `BatchableTypeSpec` 3049 - Properties: 3050 - x.shape 3051 - x.dtype 3052 - Methods: 3053 - x.__getitem__(idx) (method: returns a supported value type) 3054 - x.set_shape(shape) 3055 - Ops: 3056 - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor. 3057 - tf.tile(x) 3058 - assert_rank_at_least(x) 3059 - tf.ones_like(x) 3060 - tf.gather(params=x, indices=Tensor) 3061 - tf.add(x, y) 3062 - tf.boolean_mask(x, ...) 
3063 - @TODO(edloper): Complete this list 3064 3065 Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not 3066 currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor: 3067 - rt.to_tensor() 3068 - rt.to_sparse_tensor() 3069 - rt._to_variant() 3070 - rt._from_variant() 3071 - tf.ragged.cross([rt]) 3072 - tf.gather(params=x, indices=rt) # rt used for indices 3073 - RaggedTensorSpec methods: 3074 - _batch 3075 - _unbatch 3076 - _to_tensor_list 3077 - _to_batched_tensor_list 3078 - _from_compatible_tensor_list 3079 3080 Args: 3081 cls: The type to be added to supported value types. 3082 """ 3083 if not issubclass(cls, composite_tensor.CompositeTensor): 3084 raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.") 3085 if not hasattr(cls, "shape"): 3086 raise ValueError("cls must support the `shape` property.") 3087 if not hasattr(cls, "dtype"): 3088 raise ValueError("cls must support the `dtype` property.") 3089 global _SUPPORTED_RAGGED_VALUE_TYPES 3090 _SUPPORTED_RAGGED_VALUE_TYPES += (cls,) 3091 3092 3093def _is_supported_ragged_values_type(value): 3094 return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES) 3095 3096 3097def _assert_is_supported_ragged_values_type(value): 3098 if not _is_supported_ragged_values_type(value): 3099 ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES) 3100 raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.") 3101 3102 3103def _formatter(x): 3104 """Separate Numpy array elements with comma.""" 3105 if isinstance(x, np.ndarray): 3106 if x.size != 0: 3107 return np.array2string(x, separator=", ") 3108 else: 3109 # When x.size==0, np.array2string always returns `[]`. This isn't always 3110 # what we want. E.g., if `x.shape=[0, 3]`, then we want `[[], [], []]`. 3111 return repr(x.tolist()) 3112 else: 3113 return str(x) 3114 3115# Type annotation indicating that a value is ragged. 
Includes RaggedTensor 3116# as well as the (deprecated) RaggedTensorValue class from TF 1.x. 3117Ragged = typing.Union[RaggedTensor, ragged_tensor_value.RaggedTensorValue] 3118 3119# Type annotation indicating that a value is a ragged tensor, a dense tensor, 3120# or a value that can be converted to a tensor (e.g. np.array). 3121# TODO(edloper): Add Variable to TensorLike, and remove it from here. 3122RaggedOrDense = typing.Union[Ragged, core_types.TensorLike] 3123