# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class used to partition a sequence into contiguous subsequences ("rows").
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import segment_id_ops


#===============================================================================
# RowPartition
#===============================================================================
# TODO(edloper): Consider removing row_starts and row_limits factory methods
# and accessors from RowPartition.  In particular, these two encodings are
# "second-class citizens": we never cache them, and if you do construct a
# RowPartition from them then it may be more expensive than you might expect
# (because we append a value to the beginning/end to transform them into
# splits).  If we do remove them from RowPartition, then we would still keep
# the from_row_starts and from_row_limits factory methods in RaggedTensor.


class RowPartition(composite_tensor.CompositeTensor):
  """Partitioning of a sequence of values into contiguous subsequences ("rows").

  A `RowPartition` describes how a sequence with `nvals` items should be
  divided into `nrows` contiguous subsequences ("rows").  For example, a
  `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into
  subsequences `[[1, 2], [3], [], [4, 5]]`.  Note that `RowPartition` stores
  information about how values are partitioned, but does not include the
  partitioned values themselves.  `tf.RaggedTensor` is used to pair a `values`
  tensor with one or more `RowPartition`s, providing a complete encoding for a
  ragged tensor (i.e. a tensor with variable-length dimensions).

  `RowPartition`s may be defined using several different schemes:

    * `row_lengths`: an integer vector with shape `[nrows]`, which specifies
      the length of each row.

    * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the
      "split points" between each row.

    * `row_starts`: an integer vector with shape `[nrows]`, which specifies
      the start offset for each row.  Equivalent to `row_splits[:-1]`.

    * `row_limits`: an integer vector with shape `[nrows]`, which specifies
      the stop offset for each row.  Equivalent to `row_splits[1:]`.

    * `value_rowids` is an integer vector with shape `[nvals]`, corresponding
      one-to-one with sequence values, which specifies the row that each value
      belongs to.  If the partition has empty trailing rows, then `nrows`
      must also be specified.

    * `uniform_row_length` is an integer scalar, specifying the length of every
      row.  This scheme may only be used if all rows have the same length.

  For example, the following `RowPartition`s all represent the partitioning of
  8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`.

  >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
  >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
  >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
  >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8])
  >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)

  For more information about each scheme, see the documentation for its
  factory method.  For additional examples, see the documentation on
  `tf.RaggedTensor`.

  ### Precomputed Encodings

  `RowPartition` always stores at least one encoding of the partitioning, but
  it can be configured to cache additional encodings as well.  This can
  avoid unnecessary recomputation in eager mode.  (In graph mode, optimizations
  such as common subexpression elimination will typically prevent these
  unnecessary recomputations.)  To check which encodings are precomputed, use
  `RowPartition.has_precomputed_<encoding>`.  To cache an additional
  encoding, use `RowPartition.with_precomputed_<encoding>`.
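
  For example (illustrative; assumes eager execution):

  >>> p = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
  >>> p.has_precomputed_row_lengths()
  True
  >>> p.has_precomputed_value_rowids()
  False
  >>> p = p.with_precomputed_value_rowids()
  >>> p.has_precomputed_value_rowids()
  True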
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               row_splits,
               row_lengths=None,
               value_rowids=None,
               nrows=None,
               uniform_row_length=None,
               internal=False):
    """Creates a `RowPartition` from the specified encoding tensor(s).

    This constructor is private -- please use one of the following ops to
    build `RowPartition`s:

      * `RowPartition.from_row_lengths`
      * `RowPartition.from_value_rowids`
      * `RowPartition.from_row_splits`
      * `RowPartition.from_row_starts`
      * `RowPartition.from_row_limits`

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.
      row_lengths: A 1-D integer tensor with shape `[nrows]`.
      value_rowids: A 1-D integer tensor with shape `[nvals]`.
      nrows: An integer scalar tensor.
      uniform_row_length: A scalar tensor.
      internal: Private key value, required to ensure that this private
        constructor is *only* called from the factory methods.

    Raises:
      TypeError: If a row partitioning tensor has an inappropriate dtype.
      TypeError: If exactly one row partitioning argument was not specified.
      ValueError: If a row partitioning tensor has an inappropriate shape.
      ValueError: If multiple partitioning arguments are specified.
      ValueError: If nrows is specified but value_rowids is not specified.
    """
    if internal is not _row_partition_factory_key:
      raise ValueError("RowPartition constructor is private; please use one "
                       "of the factory methods instead (e.g., "
                       "RowPartition.from_row_lengths())")

    # Validate the arguments.
    if not isinstance(row_splits, ops.Tensor):
      raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
                      row_splits)
    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("Row-partitioning argument must be int32 or int64")

    # Validate shapes & dtypes.
    row_splits.shape.assert_has_rank(1)
    row_splits.set_shape([None])
    self._row_splits = row_splits

    # Store any cached tensors.  These are used to avoid unnecessary
    # round-trip conversions when a RowPartition is constructed from
    # lengths or rowids, and we later want those lengths/rowids back.
    for tensor in [row_lengths, value_rowids, nrows]:
      if tensor is not None:
        if not isinstance(tensor, ops.Tensor):
          raise TypeError("Cached value must be a Tensor or None.")
        elif tensor.dtype not in (dtypes.int32, dtypes.int64):
          raise TypeError("Cached value must be int32 or int64.")
    self._row_lengths = row_lengths
    self._value_rowids = value_rowids
    self._nrows = nrows

    if uniform_row_length is not None:
      if not isinstance(uniform_row_length, ops.Tensor):
        raise TypeError("uniform_row_length must be a Tensor or None.")
      elif uniform_row_length.dtype not in (dtypes.int32, dtypes.int64):
        raise TypeError("uniform_row_length must be int32 or int64.")
    self._uniform_row_length = uniform_row_length

  #=============================================================================
  # Factory Methods
  #=============================================================================

  @classmethod
  def from_value_rowids(cls,
                        value_rowids,
                        nrows=None,
                        validate=True,
                        preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `value_rowids`.

    This `RowPartition` divides a sequence `values` into rows by specifying
    which row each value should be added to:

    ```python
    partitioned_rows = [[] for _ in range(nrows)]
    for (value, rowid) in zip(values, value_rowids):
      partitioned_rows[rowid].append(value)
    ```

    Args:
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index.  Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows.  This should be
        specified if the `RowPartition` may contain empty trailing rows.  Must
        be greater than `value_rowids[-1]` (or greater than or equal to zero if
        `value_rowids` is empty).  Defaults to `value_rowids[-1] + 1` (or zero
        if `value_rowids` is empty).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: The dtype to encode value_rowids if it doesn't already
        have one.  The default is tf.int64.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `nrows` is incompatible with `value_rowids`.

    #### Example:

    >>> print(RowPartition.from_value_rowids(
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=4))
    tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64))
    """
    # Local import bincount_ops to avoid import-cycle since bincount_ops
    # imports ragged_tensor.
    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromValueRowIds",
                        [value_rowids, nrows]):
      value_rowids = cls._convert_row_partition(value_rowids, "value_rowids",
                                                preferred_dtype)
      if nrows is None:
        const_rowids = tensor_util.constant_value(value_rowids)
        if const_rowids is None:
          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
          const_nrows = None
        else:
          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
          nrows = ops.convert_to_tensor(
              const_nrows, value_rowids.dtype, name="nrows")
      else:
        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
        const_nrows = tensor_util.constant_value(nrows)
        if const_nrows is not None:
          if const_nrows < 0:
            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
          const_rowids = tensor_util.constant_value(value_rowids)
          if const_rowids is not None and const_rowids.size > 0:
            if not const_nrows >= const_rowids[-1] + 1:
              raise ValueError(
                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))

      value_rowids.shape.assert_has_rank(1)
      nrows.shape.assert_has_rank(0)

      if validate:
        msg = ("Arguments to from_value_rowids do not form a valid "
               "RowPartition")
        checks = [
            check_ops.assert_rank(value_rowids, 1, message=msg),
            check_ops.assert_rank(nrows, 0, message=msg),
            check_ops.assert_non_negative(value_rowids[:1], message=msg),
            _assert_monotonic_increasing(value_rowids, message=msg),
            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
        ]
        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)

      # Convert value_rowids & nrows to row_splits.
      # Note: we don't use segment_ids_to_row_splits() here because we want
      # to save the intermediate value `row_lengths`, so we can cache it.
      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
      # cast.
      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
      row_lengths = bincount_ops.bincount(
          value_rowids_int32,
          minlength=nrows_int32,
          maxlength=nrows_int32,
          dtype=value_rowids.dtype)
      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)],
                                    axis=0)
      if const_nrows is not None:
        row_lengths.set_shape([const_nrows])
        row_splits.set_shape([const_nrows + 1])

      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          value_rowids=value_rowids,
          nrows=nrows,
          internal=_row_partition_factory_key)

  @classmethod
  def from_row_splits(cls, row_splits, validate=True, preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `row_splits`.

    This `RowPartition` divides a sequence `values` into rows by indicating
    where each row begins and ends:

    ```python
    partitioned_rows = []
    for i in range(len(row_splits) - 1):
      row_start = row_splits[i]
      row_end = row_splits[i + 1]
      partitioned_rows.append(values[row_start:row_end])
    ```

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
        empty, and must be sorted in ascending order.  `row_splits[0]` must be
        zero.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_splits has an unspecified type, use this one.  If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `row_splits` is an empty list.
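
    #### Example (illustrative; output assumes the default `int64` dtype):

    >>> print(RowPartition.from_row_splits([0, 4, 4, 7, 8, 8]))
    tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64))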
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    if isinstance(row_splits, (list, tuple)) and not row_splits:
      raise ValueError("row_splits tensor may not be empty.")
    if isinstance(row_splits, tensor_spec.TensorSpec):
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

    with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
      row_splits = cls._convert_row_partition(row_splits, "row_splits",
                                              preferred_dtype)
      row_splits.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_splits do not form a valid RowPartition: "
        checks = [
            check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
            _assert_zero(row_splits[0], message=(msg + "zero")),
            _assert_monotonic_increasing(
                row_splits, message=(msg + "monotonic")),
        ]
        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

  @classmethod
  def from_row_lengths(cls, row_lengths, validate=True, preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `row_lengths`.

    This `RowPartition` divides a sequence `values` into rows by indicating
    the length of each row:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(length)]
                        for length in row_lengths]
    ```

    Args:
      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
        nonnegative.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_lengths has an unspecified type, use this one.  If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
      row_lengths = cls._convert_row_partition(row_lengths, "row_lengths",
                                               preferred_dtype)
      row_lengths.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_lengths do not form a valid RowPartition"
        checks = [
            check_ops.assert_rank(row_lengths, 1, message=msg),
            check_ops.assert_non_negative(row_lengths, message=msg),
        ]
        row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)

      row_limits = math_ops.cumsum(row_lengths)
      row_splits = array_ops.concat([[0], row_limits], axis=0)
      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          internal=_row_partition_factory_key)

  @classmethod
  def from_row_starts(cls,
                      row_starts,
                      nvals,
                      validate=True,
                      preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `row_starts`.

    Equivalent to: `from_row_splits(concat([row_starts, [nvals]], axis=0))`.

    Args:
      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
        nonnegative and sorted in ascending order.  If `nrows>0`, then
        `row_starts[0]` must be zero.
      nvals: A scalar tensor indicating the number of values.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_starts has an unspecified type, use this one.  If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.
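
    #### Example (illustrative; output assumes the default `int64` dtype):

    >>> print(RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8))
    tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64))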
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]):
      row_starts = cls._convert_row_partition(row_starts, "row_starts",
                                              preferred_dtype)
      row_starts.shape.assert_has_rank(1)
      nvals = math_ops.cast(nvals, row_starts.dtype)
      if validate:
        msg = "Arguments to from_row_starts do not form a valid RowPartition"
        checks = [
            check_ops.assert_rank(row_starts, 1, message=msg),
            _assert_zero(row_starts[:1], message=msg),
            _assert_monotonic_increasing(row_starts, message=msg),
            check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
        ]
        row_starts = control_flow_ops.with_dependencies(checks, row_starts)

      row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

  @classmethod
  def from_row_limits(cls, row_limits, validate=True, preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `row_limits`.

    Equivalent to: `from_row_splits(concat([0, row_limits], axis=0))`.

    Args:
      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
        ascending order.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_limits has an unspecified type, use this one.  If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
      row_limits = cls._convert_row_partition(row_limits, "row_limits",
                                              preferred_dtype)
      row_limits.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_limits do not form a valid RowPartition"
        checks = [
            check_ops.assert_rank(row_limits, 1, message=msg),
            check_ops.assert_non_negative(row_limits[:1], message=msg),
            _assert_monotonic_increasing(row_limits, message=msg),
        ]
        row_limits = control_flow_ops.with_dependencies(checks, row_limits)

      zero = array_ops.zeros([1], row_limits.dtype)
      row_splits = array_ops.concat([zero, row_limits], axis=0)
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

  # TODO(edloper): Make nvals optional: user must specify at least one of
  # {nvals, nrows}, but they can pick which one to specify.
  @classmethod
  def from_uniform_row_length(cls,
                              uniform_row_length,
                              nvals,
                              nrows=None,
                              validate=True,
                              preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

    This `RowPartition` divides a sequence `values` into rows that all have
    the same length:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
                        for _ in range(nrows)]
    ```

    Args:
      uniform_row_length: A scalar integer tensor.  Must be nonnegative.  The
        size of the outer axis of `values` must be evenly divisible by
        `uniform_row_length`.
      nvals: A non-negative scalar integer tensor for the number of values.
      nrows: The number of rows in the constructed RowPartition.  If not
        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
        `uniform_row_length==0`).  `nrows` only needs to be specified if
        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must be
        `nvals`.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If uniform_row_length has no dtype, use this one.

    Returns:
      A `RowPartition`.
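
    #### Example (illustrative; output assumes the default `int64` dtype):

    >>> print(RowPartition.from_uniform_row_length(uniform_row_length=2,
    ...                                            nvals=8))
    tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,), dtype=int64))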
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
                        [uniform_row_length, nrows]):
      uniform_row_length = cls._convert_row_partition(uniform_row_length,
                                                      "uniform_row_length",
                                                      preferred_dtype)
      uniform_row_length.shape.assert_has_rank(0)

      # Find nrows.
      const_row_length = tensor_util.constant_value(uniform_row_length)
      if nrows is None:
        if const_row_length is None:
          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
          rowlen_or_1 = math_ops.maximum(
              uniform_row_length,
              constant_op.constant(1, uniform_row_length.dtype))
          nrows = nvals // rowlen_or_1
        elif const_row_length == 0:
          nrows = 0
        else:
          nrows = nvals // const_row_length
      nrows = ops.convert_to_tensor(
          nrows, uniform_row_length.dtype, name="nrows")
      const_nrows = tensor_util.constant_value(nrows)
      const_nvals = tensor_util.constant_value(nvals)

      # Find row_splits.
      if const_nrows is not None and const_row_length is not None:
        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
      else:
        row_splits = math_ops.range(nrows + 1) * uniform_row_length

      if validate:
        checks = []

        if (const_nrows is None or const_row_length is None or
            const_nvals is None):
          checks.append(
              check_ops.assert_equal(
                  nrows * uniform_row_length, nvals,
                  ("uniform_row_length", uniform_row_length, "times nrows",
                   nrows, "must equal nvals", nvals)))
        else:
          if const_nrows * const_row_length != const_nvals:
            raise ValueError(
                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
                (const_row_length, const_nrows, const_nvals))

        if uniform_row_length.shape.rank is None:
          checks.append(
              check_ops.assert_rank(
                  uniform_row_length,
                  0,
                  message="uniform_row_length must be a scalar."))

        const_row_length = tensor_util.constant_value(uniform_row_length)
        if const_row_length is None:
          checks.append(
              check_ops.assert_greater_equal(
                  uniform_row_length,
                  constant_op.constant(0, uniform_row_length.dtype),
                  message="uniform_row_length must be >= 0."))
        else:
          if const_row_length < 0:
            raise ValueError("uniform_row_length must be >= 0.")

        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(
          row_splits=row_splits,
          uniform_row_length=uniform_row_length,
          nrows=nrows,
          internal=_row_partition_factory_key)

  @classmethod
  def _convert_row_partition(cls, partition, name, preferred_dtype):
    """Converts `partition` to Tensors.

    Args:
      partition: A row-partitioning tensor for the `RowPartition` being
        constructed.  I.e., one of: row_splits, row_lengths, row_starts,
        row_limits, value_rowids, uniform_row_length.
      name: The name of the row-partitioning tensor.
      preferred_dtype: If `partition` has no dtype, give it this one.  If
        `preferred_dtype` is None, defaults to dtypes.int64.

    Returns:
      A tensor equivalent to `partition`.

    Raises:
      ValueError: if dtype is not int32 or int64.
    """
    if preferred_dtype is None:
      preferred_dtype = dtypes.int64
    if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
      partition = ops.convert_to_tensor(partition, name=name)
    else:
      partition = ops.convert_to_tensor(
          partition, preferred_dtype=preferred_dtype, name=name)
    if partition.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("%s must have dtype int32 or int64" % name)

    return partition

  def with_dependencies(self, dependencies):
    """Returns a new RowPartition equal to self with control dependencies.

    Specifically, self._row_splits is gated by the given control dependencies.
    Used to add sanity checks to the constructors.

    Args:
      dependencies: a list of tensors to use as dependencies.

    Returns:
      A new RowPartition object.
    """
    new_row_splits = control_flow_ops.with_dependencies(dependencies,
                                                        self._row_splits)
    return RowPartition(
        row_splits=new_row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  #=============================================================================
  # Accessors
  #=============================================================================

  @property
  def dtype(self):
    """The `DType` used to encode the row partition (either int32 or int64)."""
    return self._row_splits.dtype

  def row_splits(self):
    """Returns the row-split indices for this row partition.

    `row_splits` specifies where the values for each row begin and end.
    In particular, the values for row `i` are stored in the slice
    `values[row_splits[i]:row_splits[i+1]]`.

    Returns:
      A 1-D integer `Tensor` with shape `[self.nrows()+1]`.
      The returned tensor is non-empty, and is sorted in ascending order.
      `self.row_splits()[0] == 0`.
      `self.row_splits()[-1] == self.nvals()`.
    """
    return self._row_splits

  def value_rowids(self):
    """Returns the row indices for this row partition.

    `value_rowids` specifies the row index for each value.  In particular,
    `value_rowids[i]` is the row index for `values[i]`.

    Returns:
      A 1-D integer `Tensor` with shape `[self.nvals()]`.
      The returned tensor is nonnegative, and is sorted in ascending order.
    """
    if self._value_rowids is not None:
      return self._value_rowids
    return segment_id_ops.row_splits_to_segment_ids(self._row_splits)

  def nvals(self, out_type=None):
    """Returns the number of values partitioned by this `RowPartition`.

    If the sequence partitioned by this `RowPartition` is a tensor, then
    `nvals` is the size of that tensor's outermost dimension -- i.e.,
    `nvals == values.shape[0]`.

    Args:
      out_type: `dtype` for the returned tensor.  Defaults to `self.dtype`.

    Returns:
      A scalar integer Tensor.
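
    #### Example (illustrative; output assumes the default `int64` dtype):

    >>> p = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
    >>> print(p.nvals())
    tf.Tensor(8, shape=(), dtype=int64)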
    """
    if out_type is None:
      return self._row_splits[-1]
    else:
      out_type = dtypes.as_dtype(out_type)
      return math_ops.cast(self._row_splits[-1], dtype=out_type)

  def nrows(self, out_type=None):
    """Returns the number of rows created by this `RowPartition`.

    Args:
      out_type: `dtype` for the returned tensor.  Defaults to `self.dtype`.

    Returns:
      A scalar integer Tensor.
    """
    if out_type is None:
      out_type = self.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    if self._nrows is not None:
      return math_ops.cast(self._nrows, out_type)
    nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
    if nsplits.value is None:
      return array_ops.shape(self._row_splits, out_type=out_type)[0] - 1
    else:
      return constant_op.constant(nsplits.value - 1, dtype=out_type)

  def uniform_row_length(self):
    """Returns the length of each row in this partition, if rows are uniform.

    If all rows in this `RowPartition` have the same length, then this returns
    that length as a scalar integer `Tensor`.  Otherwise, it returns `None`.

    Returns:
      A scalar Tensor with `type=self.dtype`, or `None`.
    """
    return self._uniform_row_length

  def row_starts(self):
    """Returns the start indices for rows in this row partition.

    These indices specify where the values for each row begin.
    `partition.row_starts()` is equal to `partition.row_splits()[:-1]`.

    Returns:
      A 1-D integer Tensor with shape `[self.nrows()]`.
      The returned tensor is nonnegative, and is sorted in ascending order.
      `self.row_starts()[0] == 0`.
      `self.row_starts()[-1] <= self.nvals()`.
    """
    return self._row_splits[:-1]

  def row_limits(self):
    """Returns the limit indices for rows in this row partition.

    These indices specify where the values for each row end.
    `partition.row_limits()` is equal to `partition.row_splits()[1:]`.

    Returns:
      A 1-D integer Tensor with shape `[self.nrows()]`.
      The returned tensor is nonnegative, and is sorted in ascending order.
      `self.row_limits()[-1] == self.nvals()`.
    """
    return self._row_splits[1:]

  def row_lengths(self):
    """Returns the lengths of rows in this `RowPartition`.

    Returns:
      A 1-D integer Tensor with shape `[self.nrows()]`.
      The returned tensor is nonnegative.
      `tf.reduce_sum(self.row_lengths()) == self.nvals()`.
    """
    if self._row_lengths is not None:
      return self._row_lengths
    splits = self._row_splits
    return splits[1:] - splits[:-1]

  @property
  def static_nrows(self):
    """The number of rows in this partition, if statically known.

    ```python
    self.row_lengths().shape == [self.static_nrows]
    self.row_starts().shape == [self.static_nrows]
    self.row_limits().shape == [self.static_nrows]
    self.row_splits().shape == [self.static_nrows + 1]
    ```

    Returns:
      The number of rows in this partition as an `int` (if statically known);
      or `None` (otherwise).
    """
    if self._row_splits is not None:
      nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
      if nrows.value is not None:
        return nrows
    if self._row_lengths is not None:
      nrows = tensor_shape.dimension_at_index(self._row_lengths.shape, 0)
      if nrows.value is not None:
        return nrows
    if self._nrows is not None:
      return tensor_shape.Dimension(tensor_util.constant_value(self._nrows))
    return None

  @property
  def static_nvals(self):
    """The number of values in this partition, if statically known.

    ```python
    self.value_rowids().shape == [self.static_nvals]
    ```

    Returns:
      The number of values in this partition as an `int` (if statically known);
      or `None` (otherwise).
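
    #### Example (illustrative):

    >>> p = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    >>> p.static_nvals
    8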
    """
    if self._value_rowids is not None:
      nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0)
      if nvals.value is not None:
        return nvals.value
    return None

  @property
  def static_uniform_row_length(self):
    """The number of values in each row of this partition, if statically known.

    Returns:
      The number of values in each row of this partition as an `int` (if
      statically known); or `None` (otherwise).
    """
    if self._uniform_row_length is not None:
      return tensor_util.constant_value(self._uniform_row_length)
    return None

  #=============================================================================
  # Transformation
  #=============================================================================

  def with_row_splits_dtype(self, dtype):
    """Returns a copy of this RowPartition with the given `row_splits` dtype.

    All component tensors (`row_splits`, together with any cached encodings
    such as `row_lengths`, `value_rowids`, `nrows`, and `uniform_row_length`)
    are cast to the given dtype.

    Args:
      dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.

    Returns:
      A copy of this RowPartition, with the `row_splits` cast to the given
      type.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be int32 or int64")
    if self.dtype == dtype:
      return self

    return RowPartition(
        row_splits=_cast_if_not_none(self._row_splits, dtype),
        row_lengths=_cast_if_not_none(self._row_lengths, dtype),
        value_rowids=_cast_if_not_none(self._value_rowids, dtype),
        nrows=_cast_if_not_none(self._nrows, dtype),
        uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype),
        internal=_row_partition_factory_key)

  #=============================================================================
  # String Encoding
  #=============================================================================

  def __repr__(self):
    return "tf.RowPartition(row_splits=%s)" % (self._row_splits)

  #=============================================================================
  # Precomputed Encodings
  #=============================================================================

  def has_precomputed_row_splits(self):
    """Returns true if `row_splits` has already been computed.

    If true, then `self.row_splits()` will return its value without calling
    any TensorFlow ops.
    """
    return self._row_splits is not None

  def has_precomputed_row_lengths(self):
    """Returns true if `row_lengths` has already been computed.

    If true, then `self.row_lengths()` will return its value without calling
    any TensorFlow ops.
    """
    return self._row_lengths is not None

  def has_precomputed_value_rowids(self):
    """Returns true if `value_rowids` has already been computed.

    If true, then `self.value_rowids()` will return its value without calling
    any TensorFlow ops.
    """
    return self._value_rowids is not None

  def has_precomputed_nrows(self):
    """Returns true if `nrows` has already been computed.

    If true, then `self.nrows()` will return its value without calling
    any TensorFlow ops.
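
    #### Example (illustrative):

    >>> p = RowPartition.from_value_rowids([0, 0, 2, 2], nrows=3)
    >>> p.has_precomputed_nrows()
    True
    >>> RowPartition.from_row_splits([0, 2, 2, 4]).has_precomputed_nrows()
    False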
    """
    return self._nrows is not None

  def with_precomputed_row_splits(self):
    """Returns a copy of `self` with `row_splits` precomputed."""
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_row_lengths(self):
    """Returns a copy of `self` with `row_lengths` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self.row_lengths(),
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_value_rowids(self):
    """Returns a copy of `self` with `value_rowids` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self.value_rowids(),
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_nrows(self):
    """Returns a copy of `self` with `nrows` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self.nrows(),
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def merge_precomputed_encodings(self, other, validate=True):
    """Returns a RowPartition that merges encodings from `self` and `other`.

    Requires that `self` and `other` describe the same partition.

    Args:
      other: A `RowPartition` that encodes the same partition as `self`.
      validate: If true, then add runtime checks to verify that `self` and
        `other` encode the same row partition.

    Returns:
      A `RowPartition`.
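
    #### Example (illustrative; both partitions below encode `[[*, *], [], [*, *]]`):

    >>> a = RowPartition.from_row_lengths([2, 0, 2])
    >>> b = RowPartition.from_value_rowids([0, 0, 2, 2], nrows=3)
    >>> merged = a.merge_precomputed_encodings(b)
    >>> merged.has_precomputed_row_lengths()
    True
    >>> merged.has_precomputed_value_rowids()
    True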
    """
    # pylint: disable=protected-access
    if (self is other or  # Fast path if row partitions are equal.
        (self._row_splits is other._row_splits and
         self._row_lengths is other._row_lengths and
         self._value_rowids is other._value_rowids and
         self._nrows is other._nrows and
         self._uniform_row_length is other._uniform_row_length)):
      return self

    # Merge the component tensors.  We only need to validate one encoding.
    # We merge less-expensive encodings first (to avoid expensive validation).
    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
                                            validate)
    uniform_row_length, uniform_row_length_validated = _merge_tensors(
        self._uniform_row_length, other._uniform_row_length,
        "uniform_row_length", validate)
    if uniform_row_length_validated and nrows_validated:
      validate = False  # Validation complete.
    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
                                                      other._row_splits,
                                                      "row_splits", validate)
    if row_splits_validated:
      validate = False  # Validation complete.
    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
                                                        other._row_lengths,
                                                        "row_lengths", validate)
    if row_lengths_validated:
      validate = False  # Validation complete.
    value_rowids, value_rowids_validated = _merge_tensors(
        self._value_rowids, other._value_rowids, "value_rowids", validate)
    if value_rowids_validated and nrows_validated:
      validate = False  # Validation complete.
    # TODO(edloper): If we make the row_splits encoding optional, then there
    # will be cases where we need to do validation at this point -- e.g. if
    # self has only row_splits and other has only value_rowids.  But for
    # now, we are guaranteed to have done validation by this point.

    # Avoid creating new RowPartition objects if we don't need to.
    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
        value_rowids is self._value_rowids and nrows is self._nrows and
        uniform_row_length is self._uniform_row_length):
      return self
    if (row_splits is other._row_splits and
        row_lengths is other._row_lengths and
        value_rowids is other._value_rowids and nrows is other._nrows and
        uniform_row_length is other._uniform_row_length):
      return other

    return RowPartition(
        row_splits=row_splits,
        row_lengths=row_lengths,
        value_rowids=value_rowids,
        nrows=nrows,
        uniform_row_length=uniform_row_length,
        internal=_row_partition_factory_key)

  #=============================================================================
  # Composite Tensor
  #=============================================================================

  @property
  def _type_spec(self):
    return RowPartitionSpec.from_value(self)


#===============================================================================
# RowPartitionSpec
#===============================================================================
# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination
# of precomputed row-partition encodings (rather than always using row_splits).


class RowPartitionSpec(type_spec.TypeSpec):
  """Type specification for a `tf.RowPartition`."""

  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]

  value_type = property(lambda self: RowPartition)

  def __init__(self,
               nrows=None,
               nvals=None,
               uniform_row_length=None,
               dtype=dtypes.int64):
    """Constructs a new RowPartitionSpec.

    Args:
      nrows: The number of rows in the RowPartition, or `None` if unspecified.
      nvals: The number of values partitioned by the RowPartition, or `None` if
        unspecified.
      uniform_row_length: The number of values in each row for this
        RowPartition, or `None` if rows are ragged or row length is
        unspecified.
      dtype: The data type used to encode the partition.  One of `tf.int64` or
        `tf.int32`.
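
    #### Example (illustrative; a spec that constrains only `nrows` and `nvals`):

    >>> spec = RowPartitionSpec(nrows=3, nvals=8)
    >>> spec.nrows
    3
    >>> print(spec.uniform_row_length)
    None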
    """
    # Wrap dimension sizes in 1D TensorShapes so the default implementations
    # of TypeSpec methods such as `is_compatible_with` will work.
    nrows = tensor_shape.TensorShape([nrows])
    nvals = tensor_shape.TensorShape([nvals])
    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
    else:
      uniform_row_length = uniform_row_length.with_rank(1)

    self._nrows = nrows
    self._nvals = nvals
    self._uniform_row_length = uniform_row_length
    self._dtype = dtypes.as_dtype(dtype)
    if self._dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be tf.int32 or tf.int64")

    # Check dimension consistency, & infer dimensions when possible.
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0:  # no rows -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with nrows=%s" %
                         (nvals, nrows))
    if ncols == 0:  # there are no values in each row -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s" % (nvals, uniform_row_length))
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s (doesn't divide evenly)" % (nvals, ncols))
      if nrows is not None and nvals != ncols * nrows:
        raise ValueError("nvals=%s is not compatible with nrows=%s and "
                         "uniform_row_length=%s" % (nvals, nrows, ncols))
      if nrows is None and ncols != 0:
        self._nrows = tensor_shape.TensorShape([nvals // ncols])
    if ncols is not None and nrows is not None and nvals is None:
      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)

  def _serialize(self):
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    # Remove TensorShape wrappers from serialization.
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    return self._dtype

  @property
  def _component_specs(self):
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    return value.row_splits()

  def _from_components(self, tensor):
    return RowPartition.from_row_splits(tensor, validate=False)

  @classmethod
  def from_value(cls, value):
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)

  def __repr__(self):
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns true if the given dimensions are compatible."""
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True


#===============================================================================
# Helper Functions
#===============================================================================


def _assert_monotonic_increasing(tensor, message=None):
  return check_ops.assert_non_negative(
      tensor[1:] - tensor[:-1], message=message)


def _assert_zero(tensor, message=None):
  return check_ops.assert_equal(
      tensor, constant_op.constant(0, dtype=tensor.dtype), message=message)


def _cast_if_not_none(tensor, dtype):
  return None if tensor is None else math_ops.cast(tensor, dtype)


def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both are
      non-None).

  Returns:
    A pair `(merged_value, validated)`:
      * `merged_value` is `t1` if it is not None; or `t2` otherwise.
      * `validated` is true if we validated that t1 and t2 are equal (either
        by adding a check, or because t1 is t2).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    return t1, True
  else:
    err_msg = ("RowPartition.merge_precomputed_encodings: partitions "
               "have incompatible %s" % name)
    if not t1.shape.is_compatible_with(t2.shape):
      raise ValueError(err_msg)
    if validate:
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False


_row_partition_factory_key = object()  # unique private object