1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for storing ragged tensors and their values.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.client import session 22from tensorflow.python.framework import composite_tensor 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_util 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_ragged_conversion_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops.ragged import ragged_tensor_value 34from tensorflow.python.ops.ragged import ragged_util 35from tensorflow.python.ops.ragged import segment_id_ops 36from tensorflow.python.util.tf_export import tf_export 37 38# pylint: disable=protected-access 39_eval_using_default_session = ops._eval_using_default_session 40 41# pylint: enable=protected-access 42 43#=============================================================================== 44# RaggedTensor 45#=============================================================================== 46 47 48@tf_export("RaggedTensor") 49class RaggedTensor(composite_tensor.CompositeTensor): 50 """Represents a ragged tensor. 51 52 A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are 53 dimensions whose slices may have different lengths. For example, the inner 54 (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, 55 since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. 56 Dimensions whose slices all have the same length are called *uniform 57 dimensions*. The outermost dimension of a `RaggedTensor` is always uniform, 58 since it consists of a single slice (and so there is no possibility for 59 differing slice lengths). 60 61 The total number of dimensions in a `RaggedTensor` is called its *rank*, 62 and the number of ragged dimensions in a `RaggedTensor` is called its 63 *ragged-rank*. A `RaggedTensor`'s ragged-rank is fixed at graph creation 64 time: it can't depend on the runtime values of `Tensor`s, and can't vary 65 dynamically for different session runs. 66 67 ### Potentially Ragged Tensors 68 69 Many ops support both `Tensor`s and `RaggedTensor`s. The term "potentially 70 ragged tensor" may be used to refer to a tensor that might be either a 71 `Tensor` or a `RaggedTensor`. The ragged-rank of a `Tensor` is zero. 72 73 ### Documenting RaggedTensor Shapes 74 75 When documenting the shape of a RaggedTensor, ragged dimensions can be 76 indicated by enclosing them in parentheses. For example, the shape of 77 a 3-D `RaggedTensor` that stores the fixed-size word embedding for each 78 word in a sentence, for each sentence in a batch, could be written as 79 `[num_sentences, (num_words), embedding_size]`. The parentheses around 80 `(num_words)` indicate that dimension is ragged, and that the length 81 of each element list in that dimension may vary for each item. 82 83 ### Component Tensors 84 85 Internally, a `RaggedTensor` consists of a concatenated list of values that 86 are partitioned into variable-length rows. In particular, each `RaggedTensor` 87 consists of: 88 89 * A `values` tensor, which concatenates the variable-length rows into a 90 flattened list. For example, the `values` tensor for 91 `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`. 92 93 * A `row_splits` vector, which indicates how those flattened values are 94 divided into rows. In particular, the values for row `rt[i]` are stored 95 in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 96 97 Example: 98 99 ```python 100 >>> print(tf.RaggedTensor.from_row_splits( 101 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 102 ... row_splits=[0, 4, 4, 7, 8, 8])) 103 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 104 ``` 105 106 ### Alternative Row-Partitioning Schemes 107 108 In addition to `row_splits`, ragged tensors provide support for four other 109 row-partitioning schemes: 110 111 * `row_lengths`: a vector with shape `[nrows]`, which specifies the length 112 of each row. 113 114 * `value_rowids` and `nrows`: `value_rowids` is a vector with shape 115 `[nvals]`, corresponding one-to-one with `values`, which specifies 116 each value's row index. In particular, the row `rt[row]` consists of the 117 values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an 118 int64 scalar that specifies the number of rows in the `RaggedTensor`. 119 (`nrows` is used to indicate trailing empty rows.) 120 121 * `row_starts`: a vector with shape `[nrows]`, which specifies the start 122 offset of each row. Equivalent to `row_splits[:-1]`. 123 124 * `row_limits`: a vector with shape `[nrows]`, which specifies the stop 125 offset of each row. Equivalent to `row_splits[1:]`. 126 127 Example: The following ragged tensors are equivalent, and all represent the 128 nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`. 129 130 ```python 131 >>> values = [3, 1, 4, 1, 5, 9, 2, 6] 132 >>> rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8]) 133 >>> rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0]) 134 >>> rt3 = RaggedTensor.from_value_rowids( 135 ... values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) 136 >>> rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8]) 137 >>> rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8]) 138 ``` 139 140 ### Multiple Ragged Dimensions 141 142 `RaggedTensor`s with multiple ragged dimensions can be defined by using 143 a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor` 144 adds a single ragged dimension. 145 146 ```python 147 >>> inner_rt = RaggedTensor.from_row_splits( # =rt1 from above 148 ... values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) 149 >>> outer_rt = RaggedTensor.from_row_splits( 150 ... values=inner_rt, row_splits=[0, 3, 3, 5]) 151 >>> print outer_rt.to_list() 152 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 153 >>> print outer_rt.ragged_rank 154 2 155 ``` 156 157 The factory function `RaggedTensor.from_nested_row_splits` may be used to 158 construct a `RaggedTensor` with multiple ragged dimensions directly, by 159 providing a list of `row_splits` tensors: 160 161 ```python 162 >>> RaggedTensor.from_nested_row_splits( 163 ... flat_values=[3, 1, 4, 1, 5, 9, 2, 6], 164 ... nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list() 165 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] 166 ``` 167 168 ### Uniform Inner Dimensions 169 170 `RaggedTensor`s with uniform inner dimensions can be defined 171 by using a multidimensional `Tensor` for `values`. 172 173 ```python 174 >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3]), 175 .. row_splits=[0, 2, 5]) 176 >>> print rt.to_list() 177 [[[1, 1, 1], [1, 1, 1]], 178 [[1, 1, 1], [1, 1, 1], [1, 1, 1]]] 179 >>> print rt.shape 180 (2, ?, 3) 181 ``` 182 183 ### RaggedTensor Shape Restrictions 184 185 The shape of a RaggedTensor is currently restricted to have the following 186 form: 187 188 * A single uniform dimension 189 * Followed by one or more ragged dimensions 190 * Followed by zero or more uniform dimensions. 191 192 This restriction follows from the fact that each nested `RaggedTensor` 193 replaces the uniform outermost dimension of its `values` with a uniform 194 dimension followed by a ragged dimension. 195 """ 196 197 #============================================================================= 198 # Constructor (private) 199 #============================================================================= 200 def __init__(self, 201 values, 202 row_splits, 203 cached_row_lengths=None, 204 cached_value_rowids=None, 205 cached_nrows=None, 206 internal=False): 207 """Creates a `RaggedTensor` with a specified partitioning for `values`. 208 209 This constructor is private -- please use one of the following ops to 210 build `RaggedTensor`s: 211 212 * `tf.RaggedTensor.from_row_lengths` 213 * `tf.RaggedTensor.from_value_rowids` 214 * `tf.RaggedTensor.from_row_splits` 215 * `tf.RaggedTensor.from_row_starts` 216 * `tf.RaggedTensor.from_row_limits` 217 * `tf.RaggedTensor.from_nested_row_splits` 218 * `tf.RaggedTensor.from_nested_row_lengths` 219 * `tf.RaggedTensor.from_nested_value_rowids` 220 221 Args: 222 values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. 223 row_splits: A 1-D int64 tensor with shape `[nrows+1]`. 224 cached_row_lengths: A 1-D int64 tensor with shape `[nrows]` 225 cached_value_rowids: A 1-D int64 tensor with shape `[nvals]`. 226 cached_nrows: A 1-D int64 scalar tensor. 227 internal: True if the constructor is being called by one of the factory 228 methods. If false, an exception will be raised. 229 230 Raises: 231 TypeError: If a row partitioning tensor has an inappropriate dtype. 232 TypeError: If exactly one row partitioning argument was not specified. 233 ValueError: If a row partitioning tensor has an inappropriate shape. 234 ValueError: If multiple partitioning arguments are specified. 235 ValueError: If nrows is specified but value_rowids is not None. 236 """ 237 if not internal: 238 raise ValueError("RaggedTensor constructor is private; please use one " 239 "of the factory methods instead (e.g., " 240 "RaggedTensor.from_row_lengths())") 241 242 # Validate the arguments. 243 if not isinstance(values, (RaggedTensor, ops.Tensor)): 244 raise TypeError("values must be a Tensor or RaggedTensor.") 245 if not isinstance(row_splits, ops.Tensor): 246 raise TypeError("Row-partitioning argument must be a Tensor.") 247 values.shape.with_rank_at_least(1) 248 row_splits.shape.assert_has_rank(1) 249 row_splits.set_shape([None]) 250 251 self._values = values 252 self._row_splits = row_splits 253 254 # Store any cached tensors. These are used to avoid unnecessary 255 # round-trip conversions when a RaggedTensor is constructed from 256 # lengths or rowids, and we later want those lengths/rowids back. 257 for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]: 258 if tensor is not None and not isinstance(tensor, ops.Tensor): 259 raise TypeError("Cached value must be a Tensor or None.") 260 self._cached_row_lengths = cached_row_lengths 261 self._cached_value_rowids = cached_value_rowids 262 self._cached_nrows = cached_nrows 263 264 #============================================================================= 265 # Factory Methods 266 #============================================================================= 267 268 @classmethod 269 def from_value_rowids(cls, values, value_rowids, nrows=None, name=None): 270 """Creates a `RaggedTensor` with rows partitioned by `value_rowids`. 271 272 The returned `RaggedTensor` corresponds with the python list defined by: 273 274 ```python 275 result = [[values[i] for i in range(len(values)) if value_rowids[i] == row] 276 for row in range(nrows)] 277 ``` 278 279 Warning: currently, this needs to cast value_rowids to int64 before 280 converting, since `tf.bincount` only supports `int32`. 281 282 Args: 283 values: A potentially ragged tensor with shape `[nvals, ...]`. 284 value_rowids: A 1-D int64 tensor with shape `[nvals]`, which corresponds 285 one-to-one with `values`, and specifies each value's row index. Must be 286 nonnegative, and must be sorted in ascending order. 287 nrows: An int64 scalar specifying the number of rows. This should be 288 specified if the `RaggedTensor` may containing empty training rows. Must 289 be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). 290 Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty). 291 name: A name prefix for the RaggedTensor (optional). 292 293 Returns: 294 A `RaggedTensor`. `result.rank = values.rank + 1`. 295 `result.ragged_rank = values.ragged_rank + 1`. 296 297 Raises: 298 ValueError: If `nrows` is incompatible with `value_rowids`. 299 300 #### Example: 301 ```python 302 >>> print(tf.RaggedTensor.from_value_rowids( 303 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 304 ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], 305 ... nrows=5)) 306 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 307 ``` 308 """ 309 with ops.name_scope(name, "RaggedFromValueRowIds", 310 [values, value_rowids, nrows]): 311 values = convert_to_tensor_or_ragged_tensor(values, name="values") 312 value_rowids = ops.convert_to_tensor( 313 value_rowids, dtypes.int64, name="value_rowids") 314 if nrows is None: 315 const_rowids = tensor_util.constant_value(value_rowids) 316 if const_rowids is None: 317 nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1 318 const_nrows = None 319 else: 320 const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0 321 nrows = ops.convert_to_tensor(const_nrows, dtypes.int64, name="nrows") 322 else: 323 nrows = ops.convert_to_tensor(nrows, dtypes.int64, "nrows") 324 const_nrows = tensor_util.constant_value(nrows) 325 if const_nrows is not None: 326 if const_nrows < 0: 327 raise ValueError("Expected nrows >= 0; got %d" % const_nrows) 328 const_rowids = tensor_util.constant_value(value_rowids) 329 if const_rowids is not None and const_rowids.size > 0: 330 if not const_nrows >= const_rowids[-1] + 1: 331 raise ValueError( 332 "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, " 333 "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1])) 334 335 value_rowids.shape.assert_has_rank(1) 336 nrows.shape.assert_has_rank(0) 337 values.shape[:1].assert_is_compatible_with(value_rowids.shape) 338 339 # Convert value_rowids & nrows to row_splits. 340 # Note: we don't use segment_ids_to_row_splits() here because we want 341 # to save the intermediate value `row_lengths`, so we can cache it. 342 # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the 343 # cast (Remove the warning in the docstring when we do.) 344 value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32) 345 nrows_int32 = math_ops.cast(nrows, dtypes.int32) 346 row_lengths = math_ops.bincount( 347 value_rowids_int32, 348 minlength=nrows_int32, 349 maxlength=nrows_int32, 350 dtype=dtypes.int64) 351 row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0) 352 if const_nrows is not None: 353 row_lengths.set_shape([const_nrows]) 354 row_splits.set_shape([const_nrows + 1]) 355 356 return cls( 357 values, 358 row_splits, 359 cached_row_lengths=row_lengths, 360 cached_value_rowids=value_rowids, 361 cached_nrows=nrows, 362 internal=True) 363 364 @classmethod 365 def from_row_splits(cls, values, row_splits, name=None): 366 """Creates a `RaggedTensor` with rows partitioned by `row_splits`. 367 368 The returned `RaggedTensor` corresponds with the python list defined by: 369 370 ```python 371 result = [values[row_splits[i]:row_splits[i + 1]] 372 for i in range(len(row_splits) - 1)] 373 ``` 374 375 Args: 376 values: A potentially ragged tensor with shape `[nvals, ...]`. 377 row_splits: A 1-D int64 tensor with shape `[nrows+1]`. Must not be empty, 378 and must be sorted in ascending order. `row_splits[0]` must be zero and 379 `row_splits[-1]` must be `nvals`. 380 name: A name prefix for the RaggedTensor (optional). 381 382 Returns: 383 A `RaggedTensor`. `result.rank = values.rank + 1`. 384 `result.ragged_rank = values.ragged_rank + 1`. 385 386 Raises: 387 ValueError: If `row_splits` is an empty list. 388 389 #### Example: 390 ```python 391 >>> print(tf.RaggedTensor.from_row_splits( 392 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 393 ... row_splits=[0, 4, 4, 7, 8, 8])) 394 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 395 ``` 396 """ 397 if isinstance(row_splits, (list, tuple)) and not row_splits: 398 raise ValueError("row_splits tensor may not be empty.") 399 with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): 400 values = convert_to_tensor_or_ragged_tensor(values, name="values") 401 row_splits = ops.convert_to_tensor(row_splits, dtypes.int64, "row_splits") 402 row_splits.shape.assert_has_rank(1) 403 return cls(values=values, row_splits=row_splits, internal=True) 404 405 @classmethod 406 def from_row_lengths(cls, values, row_lengths, name=None): 407 """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. 408 409 The returned `RaggedTensor` corresponds with the python list defined by: 410 411 ```python 412 result = [[values.pop(0) for i in range(length)] 413 for length in row_lengths] 414 ``` 415 416 Args: 417 values: A potentially ragged tensor with shape `[nvals, ...]`. 418 row_lengths: A 1-D int64 tensor with shape `[nrows]`. Must be 419 nonnegative. `sum(row_lengths)` must be `nvals`. 420 name: A name prefix for the RaggedTensor (optional). 421 422 Returns: 423 A `RaggedTensor`. `result.rank = values.rank + 1`. 424 `result.ragged_rank = values.ragged_rank + 1`. 425 426 #### Example: 427 ```python 428 >>> print(tf.RaggedTensor.from_row_lengths( 429 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 430 ... row_lengths=[4, 0, 3, 1, 0])) 431 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []])> 432 ``` 433 """ 434 with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): 435 values = convert_to_tensor_or_ragged_tensor(values, name="values") 436 row_lengths = ops.convert_to_tensor(row_lengths, dtypes.int64, 437 "row_lengths") 438 row_lengths.shape.assert_has_rank(1) 439 row_limits = math_ops.cumsum(row_lengths) 440 row_splits = array_ops.concat([[0], row_limits], axis=0) 441 return cls( 442 values=values, 443 row_splits=row_splits, 444 cached_row_lengths=row_lengths, 445 internal=True) 446 447 @classmethod 448 def from_row_starts(cls, values, row_starts, name=None): 449 """Creates a `RaggedTensor` with rows partitioned by `row_starts`. 450 451 Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`. 452 453 Args: 454 values: A potentially ragged tensor with shape `[nvals, ...]`. 455 row_starts: A 1-D int64 tensor with shape `[nrows]`. Must be nonnegative 456 and sorted in ascending order. If `nrows>0`, then `row_starts[0]` must 457 be zero. 458 name: A name prefix for the RaggedTensor (optional). 459 460 Returns: 461 A `RaggedTensor`. `result.rank = values.rank + 1`. 462 `result.ragged_rank = values.ragged_rank + 1`. 463 464 #### Example: 465 ```python 466 >>> print(tf.RaggedTensor.from_row_starts( 467 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 468 ... row_starts=[0, 4, 4, 7, 8])) 469 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 470 ``` 471 """ 472 with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): 473 values = convert_to_tensor_or_ragged_tensor(values, name="values") 474 row_starts = ops.convert_to_tensor(row_starts, dtypes.int64, "row_starts") 475 row_starts.shape.assert_has_rank(1) 476 nvals = array_ops.shape(values, out_type=dtypes.int64)[:1] 477 row_splits = array_ops.concat([row_starts, nvals], axis=0) 478 return cls(values=values, row_splits=row_splits, internal=True) 479 480 @classmethod 481 def from_row_limits(cls, values, row_limits, name=None): 482 """Creates a `RaggedTensor` with rows partitioned by `row_limits`. 483 484 Equivalent to: `from_row_splits(values, concat([0, row_limits]))`. 485 486 Args: 487 values: A potentially ragged tensor with shape `[nvals, ...]`. 488 row_limits: A 1-D int64 tensor with shape `[nrows]`. Must be sorted in 489 ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`. 490 name: A name prefix for the RaggedTensor (optional). 491 492 Returns: 493 A `RaggedTensor`. `result.rank = values.rank + 1`. 494 `result.ragged_rank = values.ragged_rank + 1`. 495 496 #### Example: 497 ```python 498 >>> print(tf.RaggedTensor.from_row_limits( 499 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 500 ... row_limits=[4, 4, 7, 8, 8])) 501 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]> 502 ``` 503 """ 504 with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): 505 values = convert_to_tensor_or_ragged_tensor(values, name="values") 506 row_limits = ops.convert_to_tensor(row_limits, dtypes.int64, "row_limits") 507 row_limits.shape.assert_has_rank(1) 508 zero = array_ops.zeros([1], dtypes.int64) 509 row_splits = array_ops.concat([zero, row_limits], axis=0) 510 return cls(values=values, row_splits=row_splits, internal=True) 511 512 @classmethod 513 def from_nested_value_rowids(cls, 514 flat_values, 515 nested_value_rowids, 516 nested_nrows=None, 517 name=None): 518 """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors. 519 520 Equivalent to: 521 522 ```python 523 result = flat_values 524 for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)): 525 result = from_value_rowids(result, rowids, nrows) 526 ``` 527 528 Args: 529 flat_values: A potentially ragged tensor. 530 nested_value_rowids: A list of 1-D int64 tensors. The `i`th tensor is 531 used as the `value_rowids` for the `i`th ragged dimension. 532 nested_nrows: A list of int64 scalars. The `i`th scalar is used as the 533 `nrows` for the `i`th ragged dimension. 534 name: A name prefix for the RaggedTensor (optional). 535 536 Returns: 537 A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty). 538 539 Raises: 540 ValueError: If `len(nested_values_rowids) != len(nested_nrows)`. 541 """ 542 if isinstance(nested_value_rowids, ops.Tensor): 543 raise TypeError("nested_value_rowids must be a list of Tensors") 544 if nested_nrows is None: 545 nested_nrows = [None] * len(nested_value_rowids) 546 else: 547 if isinstance(nested_nrows, ops.Tensor): 548 raise TypeError("nested_nrows must be a list of Tensors") 549 if len(nested_nrows) != len(nested_value_rowids): 550 raise ValueError("nested_nrows must have the same length as " 551 "nested_value_rowids") 552 553 with ops.name_scope( 554 name, "RaggedFromNestedValueRowIds", 555 [flat_values] + list(nested_value_rowids) + list(nested_nrows)): 556 result = flat_values 557 for value_rowids, nrows in reversed( 558 list(zip(nested_value_rowids, nested_nrows))): 559 result = cls.from_value_rowids(result, value_rowids, nrows) 560 return result 561 562 @classmethod 563 def from_nested_row_splits(cls, flat_values, nested_row_splits, name=None): 564 """Creates a `RaggedTensor` from a nested list of `row_splits` tensors. 565 566 Equivalent to: 567 568 ```python 569 result = flat_values 570 for row_splits in reversed(nested_row_splits): 571 result = from_row_splits(result, row_splits) 572 ``` 573 574 Args: 575 flat_values: A potentially ragged tensor. 576 nested_row_splits: A list of 1-D int64 tensors. The `i`th tensor is used 577 as the `row_splits` for the `i`th ragged dimension. 578 name: A name prefix for the RaggedTensor (optional). 579 580 Returns: 581 A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty). 582 """ 583 if isinstance(nested_row_splits, ops.Tensor): 584 raise TypeError("nested_row_splits must be a list of Tensors") 585 with ops.name_scope(name, "RaggedFromNestedRowSplits", 586 [flat_values] + list(nested_row_splits)): 587 result = flat_values 588 for splits in reversed(nested_row_splits): 589 result = cls.from_row_splits(result, splits) 590 return result 591 592 @classmethod 593 def from_nested_row_lengths(cls, flat_values, nested_row_lengths, name=None): 594 """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors. 595 596 Equivalent to: 597 598 ```python 599 result = flat_values 600 for row_lengths in reversed(nested_row_lengths): 601 result = from_row_lengths(result, row_lengths) 602 ``` 603 604 Args: 605 flat_values: A potentially ragged tensor. 606 nested_row_lengths: A list of 1-D int64 tensors. The `i`th tensor is used 607 as the `row_lengths` for the `i`th ragged dimension. 608 name: A name prefix for the RaggedTensor (optional). 609 610 Returns: 611 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). 612 """ 613 if isinstance(nested_row_lengths, ops.Tensor): 614 raise TypeError("nested_row_lengths must be a list of Tensors") 615 with ops.name_scope(name, "RaggedFromNestedRowlengths", 616 [flat_values] + list(nested_row_lengths)): 617 result = flat_values 618 for lengths in reversed(nested_row_lengths): 619 result = cls.from_row_lengths(result, lengths) 620 return result 621 622 #============================================================================= 623 # Accessors 624 #============================================================================= 625 626 @property 627 def dtype(self): 628 """The `DType` of values in this tensor.""" 629 return self._values.dtype 630 631 @property 632 def shape(self): 633 """The statically known shape of this ragged tensor. 634 635 Returns: 636 A `TensorShape` containing the statically known shape of this ragged 637 tensor. Ragged dimensions have a size of `None`. 638 639 Examples: 640 641 ```python 642 >>> ragged.constant([[0], [1, 2]]).shape 643 TensorShape([Dimension(2), Dimension(None)]) 644 645 >>> ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape 646 TensorShape([Dimension(2), Dimension(None), Dimension(2) 647 ``` 648 """ 649 nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1 650 651 values_shape = self._values.shape 652 value_shape = values_shape[1:] 653 return tensor_shape.TensorShape([nrows, None]).concatenate(value_shape) 654 655 @property 656 def ragged_rank(self): 657 """The number of ragged dimensions in this ragged tensor. 658 659 Returns: 660 A Python `int` indicating the number of ragged dimensions in this ragged 661 tensor. The outermost dimension is not considered ragged. 662 """ 663 values_is_ragged = isinstance(self._values, RaggedTensor) 664 return self._values.ragged_rank + 1 if values_is_ragged else 1 665 666 @property 667 def values(self): 668 """The concatenated rows for this ragged tensor. 669 670 `rt.values` is a potentially ragged tensor formed by flattening the two 671 outermost dimensions of `rt` into a single dimension. 672 673 `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the 674 number of items in the outer two dimensions of `rt`). 675 676 `rt.ragged_rank = self.ragged_rank - 1` 677 678 Returns: 679 A potentially ragged tensor. 680 681 #### Example: 682 ```python 683 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 684 >>> print rt.values 685 tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) 686 ``` 687 """ 688 return self._values 689 690 @property 691 def row_splits(self): 692 """The row-split indices for this ragged tensor's `values`. 693 694 `rt.row_splits` specifies where the values for each row begin and end in 695 `rt.values`. In particular, the values for row `rt[i]` are stored in 696 the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. 697 698 Returns: 699 A 1-D `int64` `Tensor` with shape `[self.nrows+1]`. 700 The returned tensor is non-empty, and is sorted in ascending order. 701 `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to 702 `self.values.shape[0]`. 703 704 #### Example: 705 ```python 706 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 707 >>> print rt.row_splits # indices of row splits in rt.values 708 tf.Tensor([0, 4, 4, 7, 8, 8]) 709 ``` 710 """ 711 return self._row_splits 712 713 @property 714 def flat_values(self): 715 """The innermost `values` tensor for this ragged tensor. 716 717 Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is 718 `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`. 719 720 Conceptually, `flat_values` is the tensor formed by flattening the 721 outermost dimension and all of the ragged dimensions into a single 722 dimension. 723 724 `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]` 725 (where `nvals` is the number of items in the flattened dimensions). 726 727 Returns: 728 A `Tensor`. 729 730 #### Example: 731 732 ```python 733 >>> rt = ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) 734 >>> print rt.flat_values() 735 tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) 736 ``` 737 """ 738 rt_values = self.values 739 while isinstance(rt_values, RaggedTensor): 740 rt_values = rt_values.values 741 return rt_values 742 743 @property 744 def nested_row_splits(self): 745 """A tuple containing the row_splits for all ragged dimensions. 746 747 `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for 748 all ragged dimensions in `rt`, ordered from outermost to innermost. In 749 particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where: 750 751 * `value_splits = ()` if `rt.values` is a `Tensor`. 752 * `value_splits = rt.values.nested_row_splits` otherwise. 753 754 Returns: 755 A `tuple` of 1-D `int64` `Tensor`s. 756 757 #### Example: 758 759 ```python 760 >>> rt = ragged.constant([[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) 761 >>> for i, splits in enumerate(rt.nested_row_splits()): 762 ... print('Splits for dimension %d: %s' % (i+1, splits)) 763 Splits for dimension 1: [0, 1] 764 Splits for dimension 2: [0, 3, 3, 5] 765 Splits for dimension 3: [0, 4, 4, 7, 8, 8] 766 ``` 767 768 """ 769 rt_nested_splits = [self.row_splits] 770 rt_values = self.values 771 while isinstance(rt_values, RaggedTensor): 772 rt_nested_splits.append(rt_values.row_splits) 773 rt_values = rt_values.values 774 return tuple(rt_nested_splits) 775 776 def value_rowids(self, name=None): 777 """Returns the row indices for the `values` in this ragged tensor. 778 779 `rt.value_rowids()` corresponds one-to-one with the outermost dimension of 780 `rt.values`, and specifies the row containing each value. In particular, 781 the row `rt[row]` consists of the values `rt.values[j]` where 782 `rt.value_rowids()[j] == row`. 783 784 Args: 785 name: A name prefix for the returned tensor (optional). 786 787 Returns: 788 A 1-D `int64` `Tensor` with shape `self.values.shape[:1]`. 789 The returned tensor is nonnegative, and is sorted in ascending order. 790 791 #### Example: 792 ```python 793 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 794 >>> rt.values 795 tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) 796 >>> rt.value_rowids() 797 tf.Tensor([0, 0, 0, 0, 2, 2, 2, 3]) # corresponds 1:1 with rt.values 798 ``` 799 """ 800 if self._cached_value_rowids is not None: 801 return self._cached_value_rowids 802 803 with ops.name_scope(name, "RaggedValueRowIds", [self]): 804 return segment_id_ops.row_splits_to_segment_ids(self.row_splits) 805 806 def nrows(self, out_type=dtypes.int64, name=None): 807 """Returns the number of rows in this ragged tensor. 808 809 I.e., the size of the outermost dimension of the tensor. 810 811 Args: 812 out_type: `dtype` for the returned tensor. 813 name: A name prefix for the returned tensor (optional). 814 815 Returns: 816 A scalar `Tensor` with dtype `out_type`. 817 818 #### Example: 819 ```python 820 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 821 >>> rt.nrows() # rt has 5 rows. 822 5 823 ``` 824 """ 825 if self._cached_nrows is not None: 826 return self._cached_nrows 827 828 with ops.name_scope(name, "RaggedNRows", [self]): 829 return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1 830 831 def row_starts(self, name=None): 832 """Returns the start indices for rows in this ragged tensor. 833 834 These indices specify where the values for each row begin in 835 `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`. 836 837 Args: 838 name: A name prefix for the returned tensor (optional). 839 840 Returns: 841 A 1-D Tensor of int64 with shape `[nrows]`. 842 The returned tensor is nonnegative, and is sorted in ascending order. 843 844 #### Example: 845 ```python 846 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 847 >>> rt.values 848 tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) 849 >>> rt.row_starts() # indices of row starts in rt.values 850 tf.Tensor([0, 4, 4, 7, 8]) 851 ``` 852 """ 853 with ops.name_scope(name, "RaggedRowStarts", [self]): 854 return self.row_splits[:-1] 855 856 def row_limits(self, name=None): 857 """Returns the limit indices for rows in this ragged tensor. 858 859 These indices specify where the values for each row end in 860 `self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`. 861 862 Args: 863 name: A name prefix for the returned tensor (optional). 864 865 Returns: 866 A 1-D Tensor of int64 with shape `[nrows]`. 867 The returned tensor is nonnegative, and is sorted in ascending order. 868 869 #### Example: 870 ```python 871 >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 872 >>> rt.values 873 tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) 874 >>> rt.row_limits() # indices of row limits in rt.values 875 tf.Tensor([4, 4, 7, 8, 8]) 876 ``` 877 """ 878 with ops.name_scope(name, "RaggedRowLimits", [self]): 879 return self.row_splits[1:] 880 881 def row_lengths(self, axis=1, name=None): 882 """Returns the lengths of the rows in this ragged tensor. 883 884 `rt.row_lengths()[i]` indicates the number of values in the 885 `i`th row of `rt`. 886 887 Args: 888 axis: An integer constant indicating the axis whose row lengths should be 889 returned. 890 name: A name prefix for the returned tensor (optional). 891 892 Returns: 893 A potentially ragged Tensor of int64 with shape `self.shape[:axis]`. 894 895 Raises: 896 ValueError: If `axis` is out of bounds. 897 898 #### Example: 899 ```python 900 >>> rt = ragged.constant([[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []]) 901 >>> rt.row_lengths(rt) # lengths of rows in rt 902 tf.Tensor([2, 0, 2, 1, 0]) 903 >>> rt.row_lengths(axis=2) # lengths of axis=2 rows. 904 <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]> 905 ``` 906 """ 907 if self._cached_row_lengths is not None: 908 return self._cached_row_lengths 909 910 with ops.name_scope(name, "RaggedRowLengths", [self]): 911 axis = ragged_util.get_positive_axis(axis, self.shape.ndims) 912 if axis == 0: 913 return self.nrows() 914 elif axis == 1: 915 splits = self.row_splits 916 return splits[1:] - splits[:-1] 917 elif isinstance(self.values, RaggedTensor): 918 return self.with_values(self.values.row_lengths(axis - 1)) 919 else: 920 shape = array_ops.shape(self.values, out_type=dtypes.int64) 921 return self.with_values( 922 array_ops.ones(shape[:axis - 1], dtypes.int64) * shape[axis - 1]) 923 924 def nested_row_lengths(self, name=None): 925 """Returns a tuple containing the row_lengths for all ragged dimensions. 926 927 `rtnested_row_lengths()` is a tuple containing the `row_lengths` tensors for 928 all ragged dimensions in `rt`, ordered from outermost to innermost. 929 930 Args: 931 name: A name prefix for the returned tensors (optional). 932 933 Returns: 934 A `tuple` of 1-D `int64` `Tensors`. The length of the tuple is equal to 935 `self.ragged_rank`. 936 """ 937 with ops.name_scope(name, "RaggedNestedRowLengths", [self]): 938 rt_nested_row_lengths = [] 939 rt = self 940 while isinstance(rt, RaggedTensor): 941 rt_nested_row_lengths.append(rt.row_lengths()) 942 rt = rt.values 943 return tuple(rt_nested_row_lengths) 944 945 def bounding_shape(self, axis=None, name=None): 946 """Returns the tight bounding box shape for this `RaggedTensor`. 947 948 Args: 949 axis: An integer scalar or vector indicating which axes to return the 950 bounding box for. If not specified, then the full bounding box is 951 returned. 952 name: A name prefix for the returned tensor (optional). 953 954 Returns: 955 An int64 `Tensor`. If `axis` is not specified, then `output` 956 is a vector with `output.shape=[self.shape.ndims]`. If `axis` is a 957 scalar, then the `output` is a scalar. If `axis` is a vector, then 958 `output` is a vector, where `output[i]` is the bounding size for 959 dimension `axis[i]`. 960 961 #### Example: 962 ```python 963 >>> rt = ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]]) 964 >>> rt.bounding_shape() 965 [5, 4] 966 ``` 967 """ 968 with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): 969 nested_splits = self.nested_row_splits 970 rt_flat_values = self.flat_values 971 972 # Optimized special cases for when axis=0 or axis=1: 973 if isinstance(axis, int): 974 if axis == 0: 975 return array_ops.shape(nested_splits[0], out_type=dtypes.int64)[0] - 1 976 elif axis == 1: 977 return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0) 978 979 splits_shape = array_ops.shape(self.row_splits, out_type=dtypes.int64) 980 flat_values_shape = array_ops.shape(rt_flat_values, out_type=dtypes.int64) 981 982 ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [ 983 math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) 984 for splits in nested_splits 985 ]) 986 inner_dimensions = flat_values_shape[1:] 987 988 bbox = array_ops.concat([ragged_dimensions, inner_dimensions], axis=0) 989 return bbox if axis is None else array_ops.gather(bbox, axis) 990 991 #============================================================================= 992 # Transformation 993 #============================================================================= 994 995 def with_values(self, new_values): 996 """Returns a copy of `self` with `values` replaced by `new_value`. 997 998 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 999 `self.cached_value_rowids` if they have values. 1000 1001 Args: 1002 new_values: Potentially ragged tensor to use as the `values` for the 1003 returned `RaggedTensor`. Must have `rank > 0`, and must have the same 1004 number of rows as `self.values`. 1005 1006 Returns: 1007 A `RaggedTensor`. `result.rank = 1 + new_values.rank`. 1008 `result.ragged_rank = 1 + new_values.ragged_rank` 1009 """ 1010 new_values.shape.with_rank_at_least(1) 1011 self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) 1012 return RaggedTensor( 1013 new_values, 1014 self._row_splits, 1015 self._cached_row_lengths, 1016 self._cached_value_rowids, 1017 self._cached_nrows, 1018 internal=True) 1019 1020 def with_flat_values(self, new_values): 1021 """Returns a copy of `self` with `flat_values` replaced by `new_value`. 1022 1023 Preserves cached row-partitioning tensors such as `self.cached_nrows` and 1024 `self.cached_value_rowids` if they have values. 1025 1026 Args: 1027 new_values: Potentially ragged tensor that should replace 1028 `self.flat_values`. Must have `rank > 0`, and must have the same 1029 number of rows as `self.flat_values`. 1030 1031 Returns: 1032 A `RaggedTensor`. 1033 `result.rank = self.ragged_rank + new_values.rank`. 1034 `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`. 1035 """ 1036 if isinstance(self._values, ops.Tensor): 1037 return self.with_values(new_values) 1038 else: 1039 return self.with_values(self.values.with_flat_values(new_values)) 1040 1041 #============================================================================= 1042 # Tensor Type Conversions 1043 #============================================================================= 1044 1045 @classmethod 1046 def from_tensor(cls, 1047 tensor, 1048 lengths=None, 1049 padding=None, 1050 ragged_rank=1, 1051 name=None): 1052 """Converts a `tf.Tensor` into a `RaggedTensor`. 1053 1054 The set of absent/default values may be specified using a vector of lengths 1055 or a padding value (but not both). If `lengths` is specified, then the 1056 output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If 1057 'lengths' is a list of lists or tuple of lists, those lists will be used 1058 as nested row lengths. If `padding` is specified, then any row *suffix* 1059 consisting entirely of `padding` will be excluded from the returned 1060 `RaggedTensor`. If neither `lengths` nor `padding` is specified, then the 1061 returned `RaggedTensor` will have no absent/default values. 1062 1063 Examples: 1064 1065 ```python 1066 >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]]) 1067 >>> tf.RaggedTensor.from_tensor(dt) 1068 <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]> 1069 >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]) 1070 <tf.RaggedTensor [[5], [], [6, 0, 0]]> 1071 1072 >>> tf.RaggedTensor.from_tensor(dt, padding=0) 1073 <tf.RaggedTensor [[5, 7], [0, 3], [6]]> 1074 1075 >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]], 1076 [[0, 0], [3, 0], [0, 0]], 1077 [[6, 0], [0, 0], [0, 0]]]) 1078 >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1])) 1079 <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]> 1080 ``` 1081 1082 Args: 1083 tensor: The `Tensor` to convert. Must have rank `ragged_rank + 1` or 1084 higher. 1085 lengths: An optional set of row lengths, specified using a 1-D integer 1086 `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows 1087 in `tensor`). If specified, then `output[row]` will contain 1088 `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You 1089 may optionally pass a list or tuple of lengths to this argument, which 1090 will be used as nested row lengths to construct a ragged tensor with 1091 multiple ragged dimensions. 1092 padding: An optional padding value. If specified, then any row suffix 1093 consisting entirely of `padding` will be excluded from the returned 1094 RaggedTensor. `padding` is a `Tensor` with the same dtype as `tensor` 1095 and with `shape=tensor.shape[ragged_rank + 1:]`. 1096 ragged_rank: Integer specifying the ragged rank for the returned 1097 `RaggedTensor`. Must be greater than zero. 1098 name: A name prefix for the returned tensors (optional). 1099 1100 Returns: 1101 A `RaggedTensor` with the specified `ragged_rank`. The shape of the 1102 returned ragged tensor is compatible with the shape of `tensor`. 1103 Raises: 1104 ValueError: If both `lengths` and `padding` are specified. 1105 """ 1106 if lengths is not None and padding is not None: 1107 raise ValueError("Specify lengths or padding, but not both") 1108 if not isinstance(ragged_rank, int): 1109 raise TypeError("ragged_rank expected int, got %r" % ragged_rank) 1110 if ragged_rank <= 0: 1111 raise ValueError( 1112 "ragged_rank must be greater than 0; got %s" % ragged_rank) 1113 1114 with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): 1115 tensor = ops.convert_to_tensor(tensor, name="tensor") 1116 tensor.shape.with_rank_at_least(ragged_rank + 1) 1117 input_shape = array_ops.shape(tensor, out_type=dtypes.int64) 1118 ncols = input_shape[1] 1119 1120 # Handle ragged_rank>1 via recursion: 1121 # If the output should have multiple ragged dimensions, then first 1122 # flatten the tensor to eliminate all but the last ragged dimension, 1123 # and recursively convert that flattened tensor. Then add on the splits 1124 # for the dimensions that we flattened out. 1125 if ragged_rank > 1: 1126 # Flatten `tensor` to eliminate all but the last ragged dimension. 1127 new_shape = array_ops.concat([ 1128 constant_op.constant([-1], dtypes.int64), input_shape[ragged_rank:] 1129 ], 1130 axis=0) 1131 flattened = array_ops.reshape(tensor, new_shape) 1132 # Recursively convert the flattened tensor. 1133 values = cls.from_tensor(flattened, lengths, padding) 1134 # The total number of elements in each dimension. E.g., if 1135 # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. 1136 dim_size = math_ops.cumprod(input_shape) 1137 # Construct splits tensors for the dimensions that were flattened. 1138 new_splits = [ 1139 math_ops.range(0, dim_size[dim - 1] + 1) * input_shape[dim] 1140 for dim in range(1, ragged_rank) 1141 ] 1142 return cls.from_nested_row_splits(values, new_splits) 1143 1144 # If padding was specified, then use it to find row lengths. 1145 if padding is not None: 1146 padding = ops.convert_to_tensor( 1147 padding, name="padding", dtype=tensor.dtype) 1148 padding.shape.assert_is_compatible_with(tensor.shape[2:]) 1149 1150 # Find places where the padding is equal to the tensor. (This will 1151 # broadcast `padding` across the outermost 2 dimensions of `tensor`, 1152 # so `has_default_value.shape = tensor.shape`.) 1153 has_default_value = math_ops.equal(padding, tensor) 1154 1155 # If the padding isn't a scalar, then require that all values in the 1156 # padding match each item in the tensor. After this block of code, 1157 # `has_default.shape = tensor.shape[:2]`. (Unfortunately, we can't just 1158 # use reduce_all for both cases, becaue when you pass an empty `axis` 1159 # list to reduce_all, it reduces all axes; but we want it to reduce no 1160 # axes -- i.e., to be a no-op.) 1161 tensor_rank = array_ops.rank(tensor) 1162 reduce_axis = math_ops.range(2, tensor_rank) 1163 has_default = control_flow_ops.cond( 1164 tensor_rank > 2, 1165 lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis), 1166 lambda: has_default_value) 1167 has_default.set_shape(tensor_shape.TensorShape([None, None])) 1168 has_default.set_shape(tensor.shape[:2]) 1169 1170 # Use has_default it to find the length of each row: for each 1171 # non-default item in a row, calculate the length that the row needs to 1172 # have to include that item; and then take the max of those values 1173 # (across each row). 1174 has_nondefault = math_ops.logical_not(has_default) 1175 has_nondefault = math_ops.cast(has_nondefault, dtypes.int64) 1176 length_for_nondefault_value = ( 1177 has_nondefault * array_ops.expand_dims( 1178 math_ops.range(1, ncols + 1), 0)) 1179 lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1) 1180 1181 if lengths is not None: 1182 if isinstance(lengths, 1183 (list, tuple)) and len(lengths) and not isinstance( 1184 lengths[0], (int, float)): 1185 # In this case, we've been given nested row lengths. Rather than 1186 # reconstructing the tensor mask directly, we can recreate it as 1187 # a boolean RaggedTensor, then densify that and use that as the 1188 # mask to clear out the unused data in the passed tensor. 1189 tensor.shape.with_rank_at_least(len(lengths) + 1) 1190 num_tokens = math_ops.reduce_sum(lengths[-1]) 1191 ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool) 1192 ragged_mask = cls.from_nested_row_lengths(ones_mask, lengths) 1193 dense_ragged_mask = ragged_mask.to_tensor(default_value=False) 1194 masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask) 1195 return cls.from_nested_row_lengths(masked_data, lengths) 1196 else: 1197 # If we have lengths (either directly supplied, or computed from 1198 # paddings), then use those to construct splits; and then use masking 1199 # to get the corresponding values. 1200 lengths = ragged_util.convert_to_int_tensor(lengths, "lengths", 1201 dtypes.int64) 1202 lengths.shape.assert_has_rank(1) 1203 lengths = math_ops.minimum(lengths, ncols) 1204 lengths = math_ops.maximum(lengths, 0) 1205 limits = math_ops.cumsum(lengths) 1206 splits = array_ops.concat( 1207 [array_ops.zeros([1], dtypes.int64), limits], axis=0) 1208 mask = array_ops.sequence_mask(lengths, maxlen=ncols) 1209 values = array_ops.boolean_mask(tensor, mask) 1210 return cls.from_row_splits(values, splits) 1211 1212 # If neither padding nor lengths were specified, then create a splits 1213 # vector that contains no default values, and reshape the input tensor 1214 # to form the values for the RaggedTensor. 1215 nrows = input_shape[0] 1216 nvals = nrows * ncols 1217 splits = math_ops.range(nrows + 1) * ncols 1218 values_shape = array_ops.concat([[nvals], input_shape[2:]], axis=0) 1219 values = array_ops.reshape(tensor, values_shape) 1220 return cls.from_row_splits(values, splits) 1221 1222 def to_tensor(self, default_value=None, name=None): 1223 """Converts this `RaggedTensor` into a `tf.Tensor`. 1224 1225 Example: 1226 1227 ```python 1228 >>> rt = ragged.constant([[9, 8, 7], [], [6, 5], [4]]) 1229 >>> print rt.to_tensor() 1230 [[9 8 7] 1231 [0 0 0] 1232 [6 5 0] 1233 [4 0 0]] 1234 ``` 1235 1236 Args: 1237 default_value: Value to set for indices not specified in `self`. Defaults 1238 to zero. `default_value` must be broadcastable to 1239 `self.shape[self.ragged_rank + 1:]`. 1240 name: A name prefix for the returned tensors (optional). 1241 1242 Returns: 1243 A `Tensor` with shape `ragged.bounding_shape(self)` and the 1244 values specified by the non-empty values in `self`. Empty values are 1245 assigned `default_value`. 1246 """ 1247 with ops.name_scope(name, "RaggedToTensor", [self, default_value]): 1248 if default_value is not None: 1249 default_value = ops.convert_to_tensor( 1250 default_value, name="default_value", dtype=self.dtype) 1251 1252 # If ragged_rank > 1, then recursively convert the ragged values into a 1253 # `Tensor` before we proceed. 1254 values = self.values 1255 if is_ragged(values): 1256 values = values.to_tensor(default_value) 1257 1258 # Tile the default value, if necessary. 1259 if default_value is not None: 1260 if values.shape.ndims is not None: 1261 default_value.shape.with_rank_at_most(values.shape.ndims - 1) 1262 if (values.shape.ndims is None or default_value.shape.ndims is None or 1263 values.shape.ndims != default_value.shape.ndims + 1): 1264 value_shape = array_ops.shape(values)[1:] 1265 default_value = array_ops.broadcast_to(default_value, value_shape) 1266 default_value.shape.assert_is_compatible_with(values.shape[1:]) 1267 1268 # Get the expected dense shape ([nrows, ncols] + value_shape). 1269 rt_row_lengths = [self.row_splits[1:] - self.row_splits[:-1]] 1270 nrows = array_ops.shape(self.row_splits, out_type=dtypes.int64)[0] - 1 1271 ncols = math_ops.maximum(math_ops.reduce_max(rt_row_lengths), 0) 1272 values_shape = array_ops.shape(values, out_type=dtypes.int64) 1273 value_shape = values_shape[1:] 1274 nvals = values_shape[0] 1275 1276 # Build a default value if none was supplied. 1277 if default_value is None: 1278 default_value = array_ops.zeros(value_shape, dtype=values.dtype) 1279 default_value.shape.assert_is_compatible_with(values.shape[1:]) 1280 default_value.set_shape(values.shape[1:]) 1281 1282 # Get the row start indices, and expand to shape=[nrows, 1]. 1283 starts = array_ops.expand_dims(self.row_splits[:-1], 1) 1284 1285 # Get the row limit indices, and expand to shape=[nrows, 1]. 1286 limits = array_ops.expand_dims(self.row_splits[1:], 1) 1287 1288 # Get the column indices, and expand to shape=[1, ncols]. 1289 columns = array_ops.expand_dims(math_ops.range(0, ncols), 0) 1290 1291 # Build a list containing the values plus the default value. We will use 1292 # tf.gather to collect values from this list for the `Tensor` (using 1293 # nvals as the index for the default value). 1294 values_and_default = array_ops.concat( 1295 [values, array_ops.stack([default_value])], axis=0) 1296 1297 # Construct a matrix "indices" pointing into values_and_default. I.e., 1298 # output[r, c] = values_and_default[indices[r, c]. 1299 nondefault_index = starts + columns 1300 has_value = nondefault_index < limits 1301 default_index = array_ops.fill(array_ops.stack([nrows, ncols]), nvals) 1302 indices = array_ops.where(has_value, nondefault_index, default_index) 1303 1304 # Gather the results into a `Tensor`. 1305 return array_ops.gather(values_and_default, indices) 1306 1307 @classmethod 1308 def from_sparse(cls, st_input, name=None): 1309 """Converts a 2D `tf.SparseTensor` to a `RaggedTensor`. 1310 1311 Each row of the `output` `RaggedTensor` will contain the explicit values 1312 from the same row in `st_input`. `st_input` must be ragged-right. If not 1313 it is not ragged-right, then an error will be generated. 1314 1315 Example: 1316 1317 ```python 1318 >>> st = SparseTensor(indices=[[0, 1], [0, 2], [0, 3], [1, 0], [3, 0]], 1319 ... values=[1, 2, 3, 4, 5], 1320 ... dense_shape=[4, 3]) 1321 >>> rt.RaggedTensor.from_sparse(st).eval().tolist() 1322 [[1, 2, 3], [4], [], [5]] 1323 ``` 1324 1325 Currently, only two-dimensional `SparseTensors` are supported. 1326 1327 Args: 1328 st_input: The sparse tensor to convert. Must have rank 2. 1329 name: A name prefix for the returned tensors (optional). 1330 1331 Returns: 1332 A `RaggedTensor` with the same values as `st_input`. 1333 `output.ragged_rank = rank(st_input) - 1`. 1334 `output.shape = [st_input.dense_shape[0], None]`. 1335 Raises: 1336 ValueError: If the number of dimensions in `st_input` is not known 1337 statically, or is not two. 1338 """ 1339 if not sparse_tensor.is_sparse(st_input): 1340 raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__) 1341 with ops.name_scope(name, "RaggedFromSparse", [st_input]): 1342 st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor( 1343 st_input, name="st_input") 1344 1345 if st_input.dense_shape.shape.ndims is None: 1346 static_rank_from_dense_shape = None 1347 else: 1348 static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value 1349 1350 if st_input.indices.shape.ndims is None: 1351 static_rank_from_indices = None 1352 else: 1353 static_rank_from_indices = st_input.indices.shape.dims[1].value 1354 1355 if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2: 1356 raise ValueError("rank(st_input) must be 2") 1357 1358 with ops.control_dependencies( 1359 _assert_sparse_indices_are_ragged_right(st_input.indices)): 1360 # Treat sparse row indices as segment ids to generate a splits tensor 1361 # thta we can pair with the sparse tensor values. (Ignore sparse column 1362 # indices.) 1363 segment_ids = st_input.indices[:, 0] 1364 num_segments = st_input.dense_shape[0] 1365 return cls.from_value_rowids(st_input.values, segment_ids, num_segments) 1366 1367 def to_sparse(self, name=None): 1368 """Converts this `RaggedTensor` into a `tf.SparseTensor`. 1369 1370 Example: 1371 1372 ```python 1373 >>> rt = ragged.constant([[1, 2, 3], [4], [], [5, 6]]) 1374 >>> rt.to_sparse().eval() 1375 SparseTensorValue(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [3, 1]], 1376 values=[1, 2, 3, 4, 5, 6], 1377 dense_shape=[4, 3]) 1378 ``` 1379 1380 Args: 1381 name: A name prefix for the returned tensors (optional). 1382 1383 Returns: 1384 A SparseTensor with the same values as `self`. 1385 """ 1386 with ops.name_scope(name, "RaggedToSparse", [self]): 1387 result = gen_ragged_conversion_ops.ragged_tensor_to_sparse( 1388 self.nested_row_splits, self.flat_values, name=name) 1389 return sparse_tensor.SparseTensor(result.sparse_indices, 1390 result.sparse_values, 1391 result.sparse_dense_shape) 1392 1393 #============================================================================= 1394 # String Encoding 1395 #============================================================================= 1396 def __str__(self): 1397 if self._is_eager(): 1398 return "<tf.RaggedTensor %s>" % self.to_list() 1399 else: 1400 return self.__repr__() 1401 1402 def __repr__(self): 1403 return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self._values, 1404 self._row_splits) 1405 1406 #============================================================================= 1407 # Eager Execution Mode 1408 #============================================================================= 1409 1410 def to_list(self): 1411 """Returns a nested Python `list` with the values for this `RaggedTensor`. 1412 1413 Requires that `rt` was constructed in eager execution mode. 1414 1415 Returns: 1416 A nested Python `list`. 1417 """ 1418 if self._is_eager(): 1419 return self._eager_value().to_list() 1420 else: 1421 raise ValueError("RaggedTensor.to_list() is only supported in eager " 1422 "mode; in graph mode, evaluate the RaggedTensor first " 1423 "and then use RaggedTensorValue.to_list().") 1424 1425 def _eager_value(self): 1426 """Returns a RaggedTensorValue for self. Requires self._is_eager()=true.""" 1427 value = self.flat_values.numpy() 1428 for row_splits in reversed(self.nested_row_splits): 1429 value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy()) 1430 return value 1431 1432 def _is_eager(self): 1433 """Returns True if values & row_splits Tensors are all `EagerTensor`s.""" 1434 rt = self 1435 while isinstance(rt, RaggedTensor): 1436 if not isinstance(rt.row_splits, ops.EagerTensor): 1437 return False 1438 rt = rt.values 1439 return isinstance(rt, ops.EagerTensor) 1440 1441 #============================================================================= 1442 # Indexing & Slicing 1443 #============================================================================= 1444 def __getitem__(self, key): 1445 """Returns the specified piece of this RaggedTensor.""" 1446 # See ragged_getitem.py for the documentation and implementation of this 1447 # method. 1448 # 1449 # Note: the imports in ragged/__init__.py ensure that this method always 1450 # gets overridden before it is called. 1451 1452 #============================================================================= 1453 # Name Scope 1454 #============================================================================= 1455 1456 # This private function is used by ops.name_scope to ensure that all of the 1457 # input tensors for the scope belong to the same graph. Defining this means 1458 # that you may include `RaggedTensor` objects in the name_scope `values` 1459 # list. 1460 def _as_graph_element(self): 1461 """Convert `self` to a graph element.""" 1462 values = self.values 1463 while isinstance(values, RaggedTensor): 1464 values = values.values 1465 return values 1466 1467 #============================================================================= 1468 # Composite Tensor 1469 #============================================================================= 1470 1471 def _to_components(self): 1472 return (self.flat_values,) + self.nested_row_splits 1473 1474 @classmethod 1475 def _from_components(cls, components): 1476 return cls.from_nested_row_splits(components[0], components[1:]) 1477 1478 def _shape_invariant_to_components(self, shape=None): 1479 ragged_rank = self.ragged_rank 1480 flat_values = self.flat_values 1481 1482 if shape is None: 1483 # Default shape invariant 1484 value_shape = flat_values.shape[1:] 1485 values_shape = tensor_shape.TensorShape([None]).concatenate(value_shape) 1486 return ((values_shape, self._row_splits.shape) + 1487 tuple(tensor_shape.TensorShape([None]) 1488 for i in range(1, ragged_rank))) 1489 else: 1490 # Explicitly specified shape invariant 1491 if shape.ndims is not None and shape.ndims <= ragged_rank: 1492 raise ValueError("Shape invariant %s does not have sufficient rank " 1493 "for a RaggedTensor with %d ragged dimensions." % 1494 (shape, self.ragged_rank)) 1495 if any(tensor_shape.dimension_value(shape[dim]) is not None 1496 for dim in range(1, self.ragged_rank + 1)): 1497 raise ValueError("Shape invariant dimension size must be None for " 1498 "ragged dimenions.") 1499 nrows = tensor_shape.dimension_value(shape[0]) 1500 value_shape = shape[self.ragged_rank + 1:] 1501 values_shape = tensor_shape.TensorShape([None]).concatenate(value_shape) 1502 if nrows is None: 1503 outer_splits_shape = tensor_shape.TensorShape([None]) 1504 else: 1505 outer_splits_shape = tensor_shape.TensorShape([nrows + 1]) 1506 return ((values_shape, outer_splits_shape) + 1507 tuple(tensor_shape.TensorShape([None]) 1508 for i in range(1, ragged_rank))) 1509 1510 @property 1511 def _is_graph_tensor(self): 1512 return hasattr(self._values, 'graph') 1513 1514 1515def is_ragged(value): 1516 """Returns true if `value` is a ragged tensor or ragged tensor value.""" 1517 return isinstance(value, 1518 (RaggedTensor, ragged_tensor_value.RaggedTensorValue)) 1519 1520 1521#=============================================================================== 1522# Convert value -> tensor 1523#=============================================================================== 1524def convert_to_tensor_or_ragged_tensor(value, 1525 dtype=None, 1526 preferred_dtype=None, 1527 name=None): 1528 """Converts value to a `RaggedTensor` or `Tensor`. 1529 1530 * If `value` is a `RaggedTensor`, then return it as-is. 1531 * If `value` is a `RaggedTensorValue`, return a corresponding constant 1532 `RaggedTensor`. 1533 * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`. 1534 1535 Args: 1536 value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has 1537 a registered `Tensor` conversion function. 1538 dtype: Optional element type for the returned tensor. If missing the type 1539 is inferred from the type of `value`. 1540 preferred_dtype: Optional element type for the returned tensor, used when 1541 dtype is None. This argument has no effect if `value` is already a 1542 tensor, or when conversion is not possible. 1543 name: Optional name to use if a new `Tensor` is created. 1544 1545 Returns: 1546 A `Tensor` or `RaggedTensor`. 1547 """ 1548 if isinstance(value, RaggedTensor): 1549 if dtype and not dtype.is_compatible_with(value.dtype): 1550 raise ValueError("Tensor conversion requested dtype %s for " 1551 "RaggedTensor with dtype %s: %r" % 1552 (dtype.name, value.dtype.name, value)) 1553 return value 1554 elif isinstance(value, ragged_tensor_value.RaggedTensorValue): 1555 with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []): 1556 flat_values = ops.convert_to_tensor( 1557 value=value.flat_values, 1558 dtype=dtype, 1559 preferred_dtype=preferred_dtype, 1560 name="flat_values") 1561 return RaggedTensor.from_nested_row_splits(flat_values, 1562 value.nested_row_splits) 1563 else: 1564 return ops.convert_to_tensor( 1565 value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name) 1566 1567 1568#=============================================================================== 1569# Register RaggedTensor for use with session.run. 1570#=============================================================================== 1571def _ragged_tensor_value_from_components(components): 1572 components = list(components) 1573 value = components.pop() 1574 while components: 1575 value = ragged_tensor_value.RaggedTensorValue(value, components.pop()) 1576 return value 1577 1578 1579def _ragged_tensor_session_fetch(rt): 1580 components = rt.nested_row_splits + (rt.flat_values,) 1581 return (components, _ragged_tensor_value_from_components) 1582 1583 1584def _ragged_tensor_session_feed(feed_key, feed_val): 1585 key_components = feed_key.nested_row_splits + (feed_key.flat_values,) 1586 val_components = feed_val.nested_row_splits + (feed_val.flat_values,) 1587 return zip(key_components, val_components) 1588 1589 1590def _ragged_tensor_session_feed_for_partial_run(feed_key): 1591 return feed_key.nested_row_splits + (feed_key.flat_values,) 1592 1593 1594session.register_session_run_conversion_functions( 1595 RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed, 1596 _ragged_tensor_session_feed_for_partial_run) 1597 1598 1599#=============================================================================== 1600# RaggedTensorType 1601#=============================================================================== 1602class RaggedTensorType(object): 1603 """Encoding of a static type for a `RaggedTensor`. 1604 1605 Use this type to express/declare that an output must have the type of 1606 `RaggedTensor`. 1607 """ 1608 1609 def __init__(self, dtype, ragged_rank): 1610 """Initializes a RaggedTensorType object. 1611 1612 Args: 1613 dtype: data type of the `RaggedTensor`'s inner values. 1614 ragged_rank: ragged_rank of the declared `RaggedTensor`. 1615 """ 1616 self._dtype = dtype 1617 self._ragged_rank = ragged_rank 1618 1619 dtype = property(lambda self: self._dtype) 1620 ragged_rank = property(lambda self: self._ragged_rank) 1621 1622 1623#=============================================================================== 1624# Helper Functions 1625#=============================================================================== 1626def _assert_sparse_indices_are_ragged_right(indices): 1627 """Checks that the given SparseTensor.indices tensor is ragged-right. 1628 1629 Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right 1630 because the entry `[3, 1]` skips a cell. 1631 1632 Args: 1633 indices: The SparseTensor indices to check. 1634 1635 Returns: 1636 A list of control dependency op tensors. 1637 """ 1638 index_prefix = indices[:, :-1] 1639 index_suffix = indices[:, -1] 1640 1641 # Check whether each index is starting a new row in the innermost dimension 1642 # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]). 1643 # (Note: this skips the first index; we will check that separately below.) 1644 index_prefix_changed = math_ops.reduce_any( 1645 math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1) 1646 1647 # Check two cases: 1648 # * For indices that start a new row: index_suffix[i] must be zero. 1649 # * For indices that continue a row: index_suffix[i] must be equal to 1650 # index_suffix[i-1]+1. 1651 index_ok = array_ops.where( 1652 index_prefix_changed, math_ops.equal(index_suffix[1:], 0), 1653 math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1)) 1654 1655 # Also check that the very first index didn't skip any cells. The first 1656 # index starts a new row (by definition), so its suffix should be zero. 1657 sparse_indices_are_ragged_right = math_ops.logical_and( 1658 math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)), 1659 math_ops.reduce_all(index_ok)) 1660 1661 message = [ 1662 "SparseTensor is not right-ragged", "SparseTensor.indices =", indices 1663 ] 1664 return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)] 1665 1666 1667@ops.RegisterGradient("RaggedTensorToSparse") 1668def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad, 1669 sparse_values_grad, 1670 unused_sparse_shape_grad): 1671 """Gradient for RaggedTensorToSparse.""" 1672 op_inputs_nested_row_splits = op.inputs[:-1] 1673 op_inputs_flat_values = op.inputs[-1] 1674 1675 # No gradient for the RaggedTensor's nested_row_splits. 1676 nested_row_splits_gradient = [None] * len(op_inputs_nested_row_splits) 1677 1678 # Gradient for the RaggedTensor's flat_values is formed by reshaping 1679 # the gradient for the SparseTensor's values. 1680 flat_values_shape = array_ops.shape(op_inputs_flat_values) 1681 flat_values_gradient = array_ops.reshape(sparse_values_grad, 1682 flat_values_shape) 1683 1684 return nested_row_splits_gradient + [flat_values_gradient] 1685