1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A class used to partition a sequence into contiguous subsequences ("rows"). 16""" 17 18 19# TODO(edloper): Make into a ExtensionType (if possible) 20 21 22import numpy as np 23 24from tensorflow.python.framework import composite_tensor 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_spec 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.framework import type_spec 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import gen_ragged_math_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops.ragged import segment_id_ops 38from tensorflow.python.util.tf_export import tf_export 39 40#=============================================================================== 41# RowPartition 42#=============================================================================== 43# TODO(edloper): Consider removing row_starts and row_limits factory methods 44# and accessors from RowPartition. 
# In particular, these two encodings are
# "second-class citizens": we never cache them, and if you do construct a
# RowPartition from them then it may be more expensive than you might expect
# (because we append a value to the beginning/end to transform them into
# splits). If we do remove them from RowPartition, then we would still keep
# the from_row_starts and from_row_limits factory methods in RaggedTensor.


@tf_export("experimental.RowPartition")
class RowPartition(composite_tensor.CompositeTensor):
  """Partitioning of a sequence of values into contiguous subsequences ("rows").

  A `RowPartition` describes how a sequence with `nvals` items should be
  divided into `nrows` contiguous subsequences ("rows"). For example, a
  `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into
  subsequences `[[1, 2], [3], [], [4, 5]]`. Note that `RowPartition` stores
  information about how values are partitioned, but does not include the
  partitioned values themselves. `tf.RaggedTensor` is used to pair a `values`
  tensor with one or more `RowPartition`s, providing a complete encoding for a
  ragged tensor (i.e. a tensor with variable-length dimensions).

  `RowPartition`s may be defined using several different schemes:

    * `row_lengths`: an integer vector with shape `[nrows]`, which specifies
      the length of each row.

    * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the
      "split points" between each row.

    * `row_starts`: an integer vector with shape `[nrows]`, which specifies
      the start offset for each row. Equivalent to `row_splits[:-1]`.

    * `row_limits`: an integer vector with shape `[nrows]`, which specifies
      the stop offset for each row. Equivalent to `row_splits[1:]`.

    * `value_rowids` is an integer vector with shape `[nvals]`, corresponding
      one-to-one with sequence values, which specifies the row that each value
      belongs to. If the partition has empty trailing rows, then `nrows`
      must also be specified.

    * `uniform_row_length` is an integer scalar, specifying the length of every
      row. This scheme may only be used if all rows have the same length.

  For example, the following `RowPartition`s all represent the partitioning of
  8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`.

  >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
  >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
  >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
  >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8])
  >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)

  For more information about each scheme, see the documentation for its
  factory method. For additional examples, see the documentation on
  `tf.RaggedTensor`.

  ### Precomputed Encodings

  `RowPartition` always stores at least one encoding of the partitioning, but
  it can be configured to cache additional encodings as well. This can
  avoid unnecessary recomputation in eager mode. (In graph mode, optimizations
  such as common subexpression elimination will typically prevent these
  unnecessary recomputations.) To check which encodings are precomputed, use
  `RowPartition.has_precomputed_<encoding>`. To cache an additional
  encoding, use `RowPartition.with_precomputed_<encoding>`.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               row_splits,
               row_lengths=None,
               value_rowids=None,
               nrows=None,
               uniform_row_length=None,
               nvals=None,
               internal=False):
    """Creates a `RowPartition` from the specified encoding tensor(s).

    This constructor is private -- please use one of the following ops to
    build `RowPartition`s:

      * `RowPartition.from_row_lengths`
      * `RowPartition.from_value_rowids`
      * `RowPartition.from_row_splits`
      * `RowPartition.from_row_starts`
      * `RowPartition.from_row_limits`
      * `RowPartition.from_uniform_row_length`

    If row_splits has a constant value, then all other arguments should
    have a constant value.

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.
      row_lengths: A 1-D integer tensor with shape `[nrows]`
      value_rowids: A 1-D integer tensor with shape `[nvals]`.
      nrows: A 1-D integer scalar tensor.
      uniform_row_length: A scalar tensor.
      nvals: A scalar tensor.
      internal: Private key value, required to ensure that this private
        constructor is *only* called from the factory methods.

    Raises:
      TypeError: If a row partitioning tensor has an inappropriate dtype.
      TypeError: If exactly one row partitioning argument was not specified.
      ValueError: If a row partitioning tensor has an inappropriate shape.
      ValueError: If multiple partitioning arguments are specified.
      ValueError: If nrows is specified but value_rowids is not None.
    """
    # The `internal` token gates construction to the factory methods only.
    if internal is not _row_partition_factory_key:
      raise ValueError("RowPartition constructor is private; please use one "
                       "of the factory methods instead (e.g., "
                       "RowPartition.from_row_lengths())")

    # Validate the arguments.
    if not isinstance(row_splits, ops.Tensor):
      raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
                      row_splits)
    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("Row-partitioning argument must be int32 or int64")

    # Validate shapes & dtypes.
    row_splits.shape.assert_has_rank(1)
    row_splits.set_shape([None])
    self._row_splits = row_splits

    # Store any cached tensors. These are used to avoid unnecessary
    # round-trip conversions when a RowPartition is constructed from
    # lengths or rowids, and we later want those lengths/rowids back.
    # All cached encodings must share row_splits' dtype.
    for tensor in [row_lengths, value_rowids, nrows, uniform_row_length, nvals]:
      if tensor is not None:
        if not isinstance(tensor, ops.Tensor):
          raise TypeError("Cached value must be a Tensor or None.")
        elif tensor.dtype != row_splits.dtype:
          raise ValueError(f"Inconsistent dtype for encoding tensors: "
                           f"{tensor} vs {row_splits}")
    self._row_lengths = row_lengths
    self._value_rowids = value_rowids
    self._nrows = nrows
    self._uniform_row_length = uniform_row_length
    self._nvals = nvals

  #=============================================================================
  # Factory Methods
  #=============================================================================

  @classmethod
  def from_value_rowids(cls,
                        value_rowids,
                        nrows=None,
                        validate=True,
                        dtype=None,
                        dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `value_rowids`.

    This `RowPartition` divides a sequence `values` into rows by specifying
    which row each value should be added to:

    ```python
    partitioned_rows = [[] for _ in nrows]
    for (value, rowid) in zip(values, value_rowids):
      partitioned_rows[rowid].append(value)
    ```

    Args:
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index. Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows. This should be
        specified if the `RowPartition` may contain empty trailing rows. Must
        be greater than `value_rowids[-1]` (or greater than or equal to zero if
        `value_rowids` is empty). Defaults to `value_rowids[-1] + 1` (or zero
        if `value_rowids` is empty).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `value_rowids`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `nrows` is incompatible with `value_rowids`.

    #### Example:

    >>> print(RowPartition.from_value_rowids(
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=4))
    tf.RowPartition(row_splits=[0 4 4 7 8])
    """
    # Local import bincount_ops to avoid import-cycle since bincount_ops
    # imports ragged_tensor.
    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromValueRowIds",
                        [value_rowids, nrows]):
      value_rowids = cls._convert_row_partition(
          value_rowids, "value_rowids", dtype_hint=dtype_hint, dtype=dtype)
      if nrows is None:
        const_rowids = tensor_util.constant_value(value_rowids)
        if const_rowids is None:
          # Dynamic default: largest row id + 1 (the concat with [-1] makes an
          # empty value_rowids yield nrows == 0).
          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
          const_nrows = None
        else:
          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
          nrows = ops.convert_to_tensor(
              const_nrows, value_rowids.dtype, name="nrows")
      else:
        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
        const_nrows = tensor_util.constant_value(nrows)
        if const_nrows is not None:
          if const_nrows < 0:
            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
          const_rowids = tensor_util.constant_value(value_rowids)
          if const_rowids is not None and const_rowids.size > 0:
            if not const_nrows >= const_rowids[-1] + 1:
              raise ValueError(
                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))

      value_rowids.shape.assert_has_rank(1)
      nrows.shape.assert_has_rank(0)

      if validate:
        msg = ("Arguments to from_value_rowids do not form a valid "
               "RowPartition")
        checks = [
            check_ops.assert_rank(value_rowids, 1, message=msg),
            check_ops.assert_rank(nrows, 0, message=msg),
            check_ops.assert_non_negative(value_rowids[:1], message=msg),
            _assert_monotonic_increasing(value_rowids, message=msg),
            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
        ]
        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)

      # Convert value_rowids & nrows to row_splits.
      # Note: we don't use segment_ids_to_row_splits() here because we want
      # to save the intermediate value `row_lengths`, so we can cache it.
      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
      # cast.
      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
      row_lengths = bincount_ops.bincount(
          value_rowids_int32,
          minlength=nrows_int32,
          maxlength=nrows_int32,
          dtype=value_rowids.dtype)
      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
      if const_nrows is not None:
        row_lengths.set_shape([const_nrows])
        row_splits.set_shape([const_nrows + 1])

      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          value_rowids=value_rowids,
          nrows=nrows,
          internal=_row_partition_factory_key)

  @classmethod
  def from_row_splits(cls,
                      row_splits,
                      validate=True,
                      dtype=None,
                      dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `row_splits`.

    This `RowPartition` divides a sequence `values` into rows by indicating
    where each row begins and ends:

    ```python
    partitioned_rows = []
    for i in range(len(row_splits) - 1):
      row_start = row_splits[i]
      row_end = row_splits[i + 1]
      partitioned_rows.append(values[row_start:row_end])
    ```

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
        empty, and must be sorted in ascending order. `row_splits[0]` must be
        zero.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `row_splits`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `row_splits` is an empty list.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    if isinstance(row_splits, (list, tuple)) and not row_splits:
      raise ValueError("row_splits tensor may not be empty.")
    # A TensorSpec is passed through unconverted (used when building specs).
    if isinstance(row_splits, tensor_spec.TensorSpec):
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

    with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
      row_splits = cls._convert_row_partition(
          row_splits, "row_splits", dtype_hint=dtype_hint, dtype=dtype)
      row_splits.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
        checks = [
            check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
            _assert_zero(row_splits[0], message=(msg + "zero")),
            _assert_monotonic_increasing(
                row_splits, message=(msg + "monotonic")),
        ]
        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

  @classmethod
  def from_row_lengths(cls,
                       row_lengths,
                       validate=True,
                       dtype=None,
                       dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `row_lengths`.

    This `RowPartition` divides a sequence `values` into rows by indicating
    the length of each row:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(length)]
                        for length in row_lengths]
    ```

    Args:
      row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
        nonnegative.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `row_lengths`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None.
        In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
      row_lengths = cls._convert_row_partition(
          row_lengths, "row_lengths", dtype_hint=dtype_hint, dtype=dtype)
      row_lengths.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_lengths do not form a valid RowPartition"
        checks = [
            check_ops.assert_rank(row_lengths, 1, message=msg),
            check_ops.assert_non_negative(row_lengths, message=msg),
        ]
        row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)

      # row_splits is [0] followed by the cumulative sum of the lengths.
      row_limits = math_ops.cumsum(row_lengths)
      row_splits = array_ops.concat([[0], row_limits], axis=0)
      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          internal=_row_partition_factory_key)

  @classmethod
  def from_row_starts(cls,
                      row_starts,
                      nvals,
                      validate=True,
                      dtype=None,
                      dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `row_starts`.

    Equivalent to: `from_row_splits(concat([row_starts, nvals], axis=0))`.

    Args:
      row_starts: A 1-D integer tensor with shape `[nrows]`. Must be
        nonnegative and sorted in ascending order. If `nrows>0`, then
        `row_starts[0]` must be zero.
      nvals: A scalar tensor indicating the number of values.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `row_starts`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None.
        In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]):
      row_starts = cls._convert_row_partition(
          row_starts, "row_starts", dtype_hint=dtype_hint, dtype=dtype)
      row_starts.shape.assert_has_rank(1)
      # TODO(martinz): nvals and row_starts could be inconsistent at call time,
      # even though they eventually end up the same type.
      nvals = math_ops.cast(nvals, row_starts.dtype)
      if validate:
        msg = "Arguments to from_row_starts do not form a valid RaggedTensor"
        checks = [
            check_ops.assert_rank(row_starts, 1, message=msg),
            _assert_zero(row_starts[:1], message=msg),
            _assert_monotonic_increasing(row_starts, message=msg),
            check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
        ]
        row_starts = control_flow_ops.with_dependencies(checks, row_starts)

      # Appending nvals to row_starts yields row_splits.
      row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
      return cls(row_splits=row_splits, nvals=nvals,
                 internal=_row_partition_factory_key)

  @classmethod
  def from_row_limits(cls,
                      row_limits,
                      validate=True,
                      dtype=None,
                      dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `row_limits`.

    Equivalent to: `from_row_splits(concat([[0], row_limits], axis=0))`.

    Args:
      row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
        ascending order.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `row_limits`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
      row_limits = cls._convert_row_partition(
          row_limits, "row_limits", dtype_hint=dtype_hint, dtype=dtype)
      row_limits.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_limits do not form a valid RaggedTensor"
        checks = [
            check_ops.assert_rank(row_limits, 1, message=msg),
            check_ops.assert_non_negative(row_limits[:1], message=msg),
            _assert_monotonic_increasing(row_limits, message=msg),
        ]
        row_limits = control_flow_ops.with_dependencies(checks, row_limits)

      # Prepending a zero to row_limits yields row_splits.
      zero = array_ops.zeros([1], row_limits.dtype)
      row_splits = array_ops.concat([zero, row_limits], axis=0)
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)

  @classmethod
  def from_uniform_row_length(cls,
                              uniform_row_length,
                              nvals=None,
                              nrows=None,
                              validate=True,
                              dtype=None,
                              dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

    This `RowPartition` divides a sequence `values` into rows that all have
    the same length:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
                        for _ in range(nrows)]
    ```

    Note that either or both of nvals and nrows must be specified.

    Args:
      uniform_row_length: A scalar integer tensor. Must be nonnegative. The
        size of the outer axis of `values` must be evenly divisible by
        `uniform_row_length`.
      nvals: a non-negative scalar integer tensor for the number of values.
        Must be specified if nrows is not specified. If not specified,
        defaults to uniform_row_length*nrows
      nrows: The number of rows in the constructed RowPartition. If not
        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
        `uniform_row_length==0`). `nrows` only needs to be specified if
        `uniform_row_length` might be zero. `uniform_row_length*nrows` must
        be `nvals`.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `uniform_row_length`, dtype_hint,
        or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    if nrows is None and nvals is None:
      raise ValueError("Either (or both) of nvals and nrows must be specified")
    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
                        [uniform_row_length, nrows]):
      [uniform_row_length, nvals, nrows
      ] = _convert_all_to_tensors([(uniform_row_length, "uniform_row_length"),
                                   (nvals, "nvals"), (nrows, "nrows")],
                                  dtype=dtype,
                                  dtype_hint=dtype_hint)

      uniform_row_length.shape.assert_has_rank(0)

      # Find nrows.
      const_row_length = tensor_util.constant_value(uniform_row_length)
      if nrows is None:
        if const_row_length is None:
          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
          rowlen_or_1 = math_ops.maximum(
              uniform_row_length,
              constant_op.constant(1, uniform_row_length.dtype))
          nrows = nvals // rowlen_or_1
        elif const_row_length == 0:
          nrows = constant_op.constant(0, dtype=uniform_row_length.dtype)
        else:
          nrows = nvals // const_row_length
      const_nrows = None if nrows is None else tensor_util.constant_value(nrows)
      const_nvals = None if nvals is None else tensor_util.constant_value(nvals)
      const_uniform_row_length = tensor_util.constant_value(uniform_row_length)

      checks = []

      # If nvals was not given but can be derived statically, derive it (and,
      # when validating, assert it agrees with any dynamic nvals supplied).
      if const_nvals is None and const_nrows is not None and const_uniform_row_length is not None:
        const_nvals = const_nrows * const_uniform_row_length
        if nvals is not None and validate:
          checks.append(check_ops.assert_equal(nvals, const_nvals))
        nvals = constant_op.constant(const_nvals, uniform_row_length.dtype)

      if nvals is None:
        nvals = nrows * uniform_row_length

      # Find row_splits.
      if const_nrows is not None and const_row_length is not None:
        # Fully static case: build the splits as a constant.
        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
      else:
        row_splits = math_ops.range(
            nrows + 1, dtype=uniform_row_length.dtype) * uniform_row_length

      if validate:

        # nvals must equal nrows * uniform_row_length; check statically when
        # possible, otherwise defer to a runtime assertion.
        if (const_nrows is None or const_row_length is None or
            const_nvals is None):
          checks.append(
              check_ops.assert_equal(
                  nrows * uniform_row_length, nvals,
                  ("uniform_row_length", uniform_row_length, "times nrows",
                   nrows, "must equal nvals", nvals)))
        else:
          if const_nrows * const_row_length != const_nvals:
            raise ValueError(
                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
                (const_row_length, const_nrows, const_nvals))

        if uniform_row_length.shape.rank is None:
          checks.append(
              check_ops.assert_rank(
                  uniform_row_length,
                  0,
                  message="uniform_row_length must be a scalar."))

        const_row_length = tensor_util.constant_value(uniform_row_length)
        if const_row_length is None:
          checks.append(
              check_ops.assert_greater_equal(
                  uniform_row_length,
                  constant_op.constant(0, uniform_row_length.dtype),
                  message="uniform_row_length must be >= 0."))
        else:
          if const_row_length < 0:
            raise ValueError("uniform_row_length must be >= 0.")

        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(
          row_splits=row_splits,
          uniform_row_length=uniform_row_length,
          nrows=nrows,
          nvals=nvals,
          internal=_row_partition_factory_key)

  @classmethod
  def _convert_row_partition(cls, partition, name, dtype=None, dtype_hint=None):
    """Converts `partition` to Tensors.

    Args:
      partition: A row-partitioning tensor for the `RowPartition` being
        constructed. I.e., one of: row_splits, row_lengths, row_starts,
        row_limits, value_rowids, uniform_row_length.
      name: The name of the row-partitioning tensor.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `partition`, dtype_hint,
        or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A tensor equivalent to partition.

    Raises:
      ValueError: if dtype is not int32 or int64.
688 """ 689 if dtype_hint is None: 690 dtype_hint = dtypes.int64 691 if (isinstance(partition, np.ndarray) and 692 partition.dtype == np.int32 and dtype is None): 693 partition = ops.convert_to_tensor(partition, name=name) 694 else: 695 partition = ops.convert_to_tensor_v2( 696 partition, dtype_hint=dtype_hint, dtype=dtype, name=name) 697 if partition.dtype not in (dtypes.int32, dtypes.int64): 698 raise ValueError("%s must have dtype int32 or int64" % name) 699 700 return partition 701 702 def _with_dependencies(self, dependencies): 703 """Returns a new RowPartition equal to self with control dependencies. 704 705 Specifically, self._row_splits is gated by the given control dependencies. 706 Used to add sanity checks to the constructors. 707 708 Args: 709 dependencies: a list of tensors to use as dependencies. 710 711 Returns: 712 A new RowPartition object. 713 """ 714 new_row_splits = control_flow_ops.with_dependencies(dependencies, 715 self._row_splits) 716 return RowPartition( 717 row_splits=new_row_splits, 718 row_lengths=self._row_lengths, 719 value_rowids=self._value_rowids, 720 nrows=self._nrows, 721 uniform_row_length=self._uniform_row_length, 722 internal=_row_partition_factory_key) 723 724 #============================================================================= 725 # Accessors 726 #============================================================================= 727 728 @property 729 def dtype(self): 730 """The `DType` used to encode the row partition (either int32 or int64).""" 731 return self._row_splits.dtype 732 733 def row_splits(self): 734 """Returns the row-split indices for this row partition. 735 736 `row_splits` specifies where the values for each row begin and end. 737 In particular, the values for row `i` are stored in the slice 738 `values[row_splits[i]:row_splits[i+1]]`. 739 740 Returns: 741 A 1-D integer `Tensor` with shape `[self.nrows+1]`. 742 The returned tensor is non-empty, and is sorted in ascending order. 
      `self.row_splits()[0] == 0`.
      `self.row_splits()[-1] == self.nvals()`.
    """
    return self._row_splits

  def value_rowids(self):
    """Returns the row indices for this row partition.

    `value_rowids` specifies the row index for each value. In particular,
    `value_rowids[i]` is the row index for `values[i]`.

    Returns:
      A 1-D integer `Tensor` with shape `[self.nvals()]`.
      The returned tensor is nonnegative, and is sorted in ascending order.
    """
    # Prefer the cached encoding; otherwise derive rowids from the splits.
    if self._value_rowids is not None:
      return self._value_rowids
    return segment_id_ops.row_splits_to_segment_ids(self._row_splits)

  def nvals(self):
    """Returns the number of values partitioned by this `RowPartition`.

    If the sequence partitioned by this `RowPartition` is a tensor, then
    `nvals` is the size of that tensor's outermost dimension -- i.e.,
    `nvals == values.shape[0]`.

    Returns:
      scalar integer Tensor
    """
    # TODO(martinz): Uncomment these lines.
    # if self._nvals is not None:
    #   return self._nvals
    return self._row_splits[-1]

  def nrows(self):
    """Returns the number of rows created by this `RowPartition`.

    Returns:
      scalar integer Tensor
    """
    if self._nrows is not None:
      return self._nrows
    nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
    if nsplits.value is None:
      # Static shape unknown: fall back to a dynamic shape computation.
      return array_ops.shape(self._row_splits, out_type=self.dtype)[0] - 1
    else:
      return constant_op.constant(nsplits.value - 1, dtype=self.dtype)

  def uniform_row_length(self):
    """Returns the length of each row in this partition, if rows are uniform.

    If all rows in this `RowPartition` have the same length, then this returns
    that length as a scalar integer `Tensor`. Otherwise, it returns `None`.

    Returns:
      scalar Tensor with `type=self.dtype`, or `None`.
799 """ 800 return self._uniform_row_length 801 802 def row_starts(self): 803 """Returns the start indices for rows in this row partition. 804 805 These indices specify where the values for each row begin. 806 `partition.row_starts()` is equal to `partition.row_splits()[:-1]`. 807 808 Returns: 809 A 1-D integer Tensor with shape `[self.nrows()]`. 810 The returned tensor is nonnegative, and is sorted in ascending order. 811 `self.row_starts()[0] == 0`. 812 `self.row_starts()[-1] <= self.nvals()`. 813 """ 814 return self._row_splits[:-1] 815 816 def row_limits(self): 817 """Returns the limit indices for rows in this row partition. 818 819 These indices specify where the values for each row end. 820 `partition.row_limits()` is equal to `partition.row_splits()[:-1]`. 821 822 Returns: 823 A 1-D integer Tensor with shape `[self.nrows]`. 824 The returned tensor is nonnegative, and is sorted in ascending order. 825 `self.row_limits()[-1] == self.nvals()`. 826 """ 827 return self._row_splits[1:] 828 829 def row_lengths(self): 830 """Returns the lengths of rows in this `RowPartition`. 831 832 Returns: 833 A 1-D integer Tensor with shape `[self.nrows]`. 834 The returned tensor is nonnegative. 835 `tf.reduce_sum(self.row_lengths) == self.nvals()`. 836 """ 837 if self._row_lengths is not None: 838 return self._row_lengths 839 splits = self._row_splits 840 return splits[1:] - splits[:-1] 841 842 @property 843 def static_nrows(self): 844 """The number of rows in this partition, if statically known. 845 846 ```python 847 self.row_lengths().shape == [self.static_nrows] 848 self.row_starts().shape == [self.static_nrows] 849 self.row_limits().shape == [self.static_nrows] 850 self.row_splits().shape == [self.static_nrows + 1] 851 ``` 852 853 Returns: 854 The number of rows in this partition as an `int` (if statically known); 855 or `None` (otherwise). 
856 """ 857 if self._row_splits is not None: 858 nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0]) 859 if nrows_plus_one is not None: 860 return nrows_plus_one - 1 861 if self._row_lengths is not None: 862 nrows = tensor_shape.dimension_value(self._row_lengths.shape[0]) 863 if nrows is not None: 864 return nrows 865 if self._nrows is not None: 866 return tensor_util.constant_value(self._nrows) 867 return None 868 869 @property 870 def static_nvals(self): 871 """The number of values in this partition, if statically known. 872 873 ```python 874 self.value_rowids().shape == [self.static_vals] 875 ``` 876 877 Returns: 878 The number of values in this partition as an `int` (if statically known); 879 or `None` (otherwise). 880 """ 881 if self._nvals is not None: 882 nvals = tensor_util.constant_value(self._nvals) 883 if nvals is not None: 884 return nvals 885 if self._value_rowids is not None: 886 nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0) 887 if nvals.value is not None: 888 return nvals.value 889 return None 890 891 @property 892 def static_uniform_row_length(self): 893 """The number of values in each row of this partition, if statically known. 894 895 Returns: 896 The number of values in each row of this partition as an `int` (if 897 statically known); or `None` (otherwise). 898 """ 899 if self._uniform_row_length is not None: 900 return tensor_util.constant_value(self._uniform_row_length) 901 return None 902 903 def offsets_in_rows(self): 904 """Return the offset of each value. 905 906 RowPartition takes an array x and converts it into sublists. 907 offsets[i] is the index of x[i] in its sublist. 908 Given a shape, such as: 909 [*,*,*],[*,*],[],[*,*] 910 This returns: 911 0,1,2,0,1,0,1 912 913 Returns: 914 an offset for every value. 
915 """ 916 return gen_ragged_math_ops.ragged_range( 917 starts=constant_op.constant(0, self.dtype), 918 limits=self.row_lengths(), 919 deltas=constant_op.constant(1, self.dtype)).rt_dense_values 920 921 def is_uniform(self): 922 """Returns true if the partition is known to be uniform statically. 923 924 This is based upon the existence of self._uniform_row_length. For example: 925 RowPartition.from_row_lengths([3,3,3]).is_uniform()==false 926 RowPartition.from_uniform_row_length(5, nvals=20).is_uniform()==true 927 RowPartition.from_row_lengths([2,0,2]).is_uniform()==false 928 929 Returns: 930 Whether a RowPartition is known to be uniform statically. 931 """ 932 return self._uniform_row_length is not None 933 934 def _static_check(self): 935 """Checks if the object is internally consistent. 936 937 Raises: 938 ValueError if inconsistent. 939 """ 940 my_dtype = self.dtype 941 if self._uniform_row_length is not None: 942 if self._uniform_row_length.dtype != my_dtype: 943 raise ValueError("_uniform_row_length.dtype=" + 944 str(self._uniform_row_length.dtype) + ", not " + 945 str(my_dtype)) 946 947 if self._row_lengths is not None and self._row_lengths.dtype != my_dtype: 948 raise ValueError("_row_lengths.dtype=" + str(self._row_lengths.dtype) + 949 ", not " + str(my_dtype)) 950 951 if self._value_rowids is not None and self._value_rowids.dtype != my_dtype: 952 raise ValueError("_value_rowids.dtype=" + str(self._value_rowids.dtype) + 953 ", not " + str(my_dtype)) 954 955 if self._nrows is not None and self._nrows.dtype != my_dtype: 956 raise ValueError("_nrows.dtype=" + str(self._nrows.dtype) + ", not " + 957 str(my_dtype)) 958 959 #============================================================================= 960 # Transformation 961 #============================================================================= 962 963 def with_dtype(self, dtype): 964 """Returns a copy of this RowPartition with the given encoding dtype. 
965 966 Args: 967 dtype: The dtype for encoding tensors, such as `row_splits` and `nrows`. 968 One of `tf.int32` or `tf.int64`. 969 970 Returns: 971 A copy of this RowPartition, with the encoding tensors cast to the given 972 type. 973 """ 974 dtype = dtypes.as_dtype(dtype) 975 if dtype not in (dtypes.int32, dtypes.int64): 976 raise ValueError("dtype must be int32 or int64") 977 if self.dtype == dtype: 978 return self 979 980 return RowPartition( 981 row_splits=_cast_if_not_none(self._row_splits, dtype), 982 row_lengths=_cast_if_not_none(self._row_lengths, dtype), 983 value_rowids=_cast_if_not_none(self._value_rowids, dtype), 984 nrows=_cast_if_not_none(self._nrows, dtype), 985 uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype), 986 internal=_row_partition_factory_key) 987 988 #============================================================================= 989 # String Encoding 990 #============================================================================= 991 992 def __repr__(self): 993 if self._uniform_row_length is not None: 994 return (f"tf.RowPartition(nrows={self._nrows}, " 995 f"uniform_row_length={self._uniform_row_length})") 996 else: 997 return f"tf.RowPartition(row_splits={self._row_splits})" 998 999 #============================================================================= 1000 # Precomputed Encodings 1001 #============================================================================= 1002 1003 def _has_precomputed_row_splits(self): 1004 """Returns true if `row_splits` has already been computed. 1005 1006 If true, then `self.row_splits()` will return its value without calling 1007 any TensorFlow ops. 1008 """ 1009 return self._row_splits is not None 1010 1011 def _has_precomputed_row_lengths(self): 1012 """Returns true if `row_lengths` has already been computed. 1013 1014 If true, then `self.row_lengths()` will return its value without calling 1015 any TensorFlow ops. 
1016 """ 1017 return self._row_lengths is not None 1018 1019 def _has_precomputed_value_rowids(self): 1020 """Returns true if `value_rowids` has already been computed. 1021 1022 If true, then `self.value_rowids()` will return its value without calling 1023 any TensorFlow ops. 1024 """ 1025 return self._value_rowids is not None 1026 1027 def _has_precomputed_nrows(self): 1028 """Returns true if `nrows` has already been computed. 1029 1030 If true, then `self.nrows()` will return its value without calling 1031 any TensorFlow ops. 1032 """ 1033 return self._nrows is not None 1034 1035 def _has_precomputed_nvals(self): 1036 """Returns true if `nvals` has already been computed. 1037 1038 If true, then `self.nvals()` will return its value without calling 1039 any TensorFlow ops. 1040 """ 1041 return self._nvals is not None 1042 1043 def _with_precomputed_row_splits(self): 1044 """Returns a copy of `self` with `row_splits` precomputed.""" 1045 return RowPartition( 1046 row_splits=self.row_splits(), 1047 row_lengths=self._row_lengths, 1048 value_rowids=self._value_rowids, 1049 nrows=self._nrows, 1050 uniform_row_length=self._uniform_row_length, 1051 nvals=self._nvals, 1052 internal=_row_partition_factory_key) 1053 1054 def _with_precomputed_row_lengths(self): 1055 """Returns a copy of `self` with `row_lengths` precomputed.""" 1056 return RowPartition( 1057 row_splits=self._row_splits, 1058 row_lengths=self.row_lengths(), 1059 value_rowids=self._value_rowids, 1060 nrows=self._nrows, 1061 nvals=self._nvals, 1062 uniform_row_length=self._uniform_row_length, 1063 internal=_row_partition_factory_key) 1064 1065 def _with_precomputed_value_rowids(self): 1066 """Returns a copy of `self` with `value_rowids` precomputed.""" 1067 return RowPartition( 1068 row_splits=self._row_splits, 1069 row_lengths=self._row_lengths, 1070 value_rowids=self.value_rowids(), 1071 nrows=self._nrows, 1072 nvals=self._nvals, 1073 uniform_row_length=self._uniform_row_length, 1074 
internal=_row_partition_factory_key) 1075 1076 def _with_precomputed_nrows(self): 1077 """Returns a copy of `self` with `nrows` precomputed.""" 1078 return RowPartition( 1079 row_splits=self._row_splits, 1080 row_lengths=self._row_lengths, 1081 value_rowids=self._value_rowids, 1082 nrows=self.nrows(), 1083 nvals=self._nvals, 1084 uniform_row_length=self._uniform_row_length, 1085 internal=_row_partition_factory_key) 1086 1087 def _with_precomputed_nvals(self): 1088 """Returns a copy of `self` with `row_splits` precomputed.""" 1089 return RowPartition( 1090 row_splits=self.row_splits(), 1091 row_lengths=self._row_lengths, 1092 value_rowids=self._value_rowids, 1093 nrows=self._nrows, 1094 nvals=self.nvals(), 1095 uniform_row_length=self._uniform_row_length, 1096 internal=_row_partition_factory_key) 1097 1098 def _merge_with_spec(self, b): 1099 """Merge with a TypeSpec to create a new RowPartition.""" 1100 a_spec = self._type_spec 1101 if not a_spec.is_compatible_with(b): 1102 # TODO(martinz): Should a dynamic check be used here? 1103 raise ValueError("RowPartition and RowPartitionSpec are not compatible") 1104 nrows = constant_op.constant( 1105 b.nrows, self.dtype) if b.nrows is not None else self._nrows 1106 nvals = constant_op.constant( 1107 b.nvals, self.dtype) if b.nvals is not None else self._nvals 1108 uniform_row_length = constant_op.constant( 1109 b.uniform_row_length, self.dtype 1110 ) if b.uniform_row_length is not None else self._uniform_row_length 1111 return RowPartition( 1112 row_splits=self._row_splits, 1113 row_lengths=self._row_lengths, 1114 value_rowids=self._value_rowids, 1115 nvals=nvals, 1116 uniform_row_length=uniform_row_length, 1117 nrows=nrows, 1118 internal=_row_partition_factory_key) 1119 1120 def _merge_precomputed_encodings(self, other, validate=True): 1121 """Returns a RowPartition that merges encodings from `self` and `other`. 1122 1123 Requires that `self` and `other` describe the same partition. 
1124 1125 Args: 1126 other: A `RowPartition` that encodes the same partition as `self`. 1127 validate: If true, then add runtime checks to verify that `self` and 1128 `other` encode the same row partition. 1129 1130 Returns: 1131 A `RowPartition`. 1132 """ 1133 # pylint: disable=protected-access 1134 if (self is other or # Fast path if row partitions are equal. 1135 (self._row_splits is other._row_splits and 1136 self._row_lengths is other._row_lengths and 1137 self._value_rowids is other._value_rowids and 1138 self._nrows is other._nrows and 1139 self._nvals is other._nvals and 1140 self._uniform_row_length is other._uniform_row_length)): 1141 return self 1142 1143 # Merge the component tensors. We only need to validate one encoding. 1144 # We merge less-expensive encodings first (to avoid expensive validation). 1145 nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows", 1146 validate) 1147 nvals, _ = _merge_tensors(self._nvals, other._nvals, "nvals", validate) 1148 uniform_row_length, uniform_row_length_validated = _merge_tensors( 1149 self._uniform_row_length, other._uniform_row_length, 1150 "uniform_row_length", validate) 1151 if uniform_row_length_validated and nrows_validated: 1152 validate = False # Validation complete. 1153 row_splits, row_splits_validated = _merge_tensors(self._row_splits, 1154 other._row_splits, 1155 "row_splits", validate) 1156 if row_splits_validated: 1157 validate = False # Validation complete. 1158 row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths, 1159 other._row_lengths, 1160 "row_lengths", validate) 1161 if row_lengths_validated: 1162 validate = False # Validation complete. 1163 value_rowids, value_rowids_validated = _merge_tensors( 1164 self._value_rowids, other._value_rowids, "value_rowids", validate) 1165 if value_rowids_validated and nrows_validated: 1166 validate = False # Validation complete. 
1167 # TODO(edloper): If we make the row_splits encoding optional, then there 1168 # will be cases where we need to do validation at this point -- e.g. if 1169 # self has only row_splits and other has only value_rowids. But for 1170 # now, we are guaranteed to have done validation by this point. 1171 1172 # Avoid creating new RowPartition objects if we don't need to. 1173 if (row_splits is self._row_splits and row_lengths is self._row_lengths and 1174 value_rowids is self._value_rowids and nrows is self._nrows and 1175 uniform_row_length is self._uniform_row_length): 1176 return self 1177 if (row_splits is other._row_splits and 1178 row_lengths is other._row_lengths and 1179 value_rowids is other._value_rowids and nrows is other._nrows and 1180 uniform_row_length is other._uniform_row_length): 1181 return other 1182 1183 return RowPartition( 1184 row_splits=row_splits, 1185 row_lengths=row_lengths, 1186 value_rowids=value_rowids, 1187 nrows=nrows, 1188 uniform_row_length=uniform_row_length, 1189 nvals=nvals, 1190 internal=_row_partition_factory_key) 1191 1192 #============================================================================= 1193 # Composite Tensor 1194 #============================================================================= 1195 1196 @property 1197 def _type_spec(self): 1198 return RowPartitionSpec.from_value(self) 1199 1200 1201#=============================================================================== 1202# RowPartitionSpec 1203#=============================================================================== 1204# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination 1205# of precomputed row-partition encodings (rather than always using row_splits). 
@type_spec.register("tf.RowPartitionSpec")
class RowPartitionSpec(type_spec.TypeSpec):
  """Type specification for a `tf.RowPartition`."""

  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]

  value_type = property(lambda self: RowPartition)

  def __init__(self,
               nrows=None,
               nvals=None,
               uniform_row_length=None,
               dtype=dtypes.int64):
    """Constructs a new RowPartitionSpec.

    Args:
      nrows: The number of rows in the RowPartition, or `None` if unspecified.
      nvals: The number of values partitioned by the RowPartition, or `None` if
        unspecified.
      uniform_row_length: The number of values in each row for this
        RowPartition, or `None` if rows are ragged or row length is unspecified.
      dtype: The data type used to encode the partition.  One of `tf.int64` or
        `tf.int32`.
    """
    # Wrap dimension sizes in 1D TensorShapes so the default implementations
    # of TypeSpec methods such as `is_compatible_with` will work.
    nrows = tensor_shape.TensorShape([nrows])
    nvals = tensor_shape.TensorShape([nvals])
    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
    else:
      uniform_row_length = uniform_row_length.with_rank(1)

    self._nrows = nrows
    self._nvals = nvals
    self._uniform_row_length = uniform_row_length
    self._dtype = dtypes.as_dtype(dtype)
    if self._dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be tf.int32 or tf.int64")

    # Check dimension consistency, & infer dimensions when possible.
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0:  # no rows -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with nrows=%s" %
                         (nvals, nrows))
    if ncols == 0:  # there are no values in each row -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s" % (nvals, uniform_row_length))
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s (doesn't divide evenly)" % (nvals, ncols))
      if nrows is not None and nvals != ncols * nrows:
        raise ValueError("nvals=%s is not compatible with nrows=%s and "
                         "uniform_row_length=%s" % (nvals, nrows, ncols))
      if nrows is None and ncols != 0:
        # Infer nrows from nvals and the (nonzero) uniform row length.
        self._nrows = tensor_shape.TensorShape([nvals // ncols])
    if ncols is not None and nrows is not None and nvals is None:
      # Infer nvals from nrows and the uniform row length.
      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this spec."""
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)

  def _serialize(self):
    # Dimensions stay wrapped in TensorShapes; _deserialize unwraps them.
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    # Remove TensorShape wrappers from serialization.
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    # uniform_row_length is passed through as a TensorShape; the constructor
    # handles that case explicitly.
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    """The number of rows, as an `int` (or `None` if unspecified)."""
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    """The number of values, as an `int` (or `None` if unspecified)."""
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    """The length of each row, as an `int` (or `None` if unspecified)."""
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    """The `DType` used to encode the partition (int32 or int64)."""
    return self._dtype

  @property
  def _component_specs(self):
    # A RowPartition is represented by its row_splits tensor, which has
    # nrows + 1 elements (unknown when nrows is unknown).
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    return value.row_splits()

  def _from_components(self, tensor):
    return RowPartition.from_row_splits(tensor, validate=False)

  @classmethod
  def from_value(cls, value):
    """Returns the most specific spec for the given `RowPartition`."""
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)

  def __repr__(self):
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns true if the given dimensions are compatible.

    Args:
      nrows: rank-1 `TensorShape` wrapping the (possibly unknown) row count.
      nvals: rank-1 `TensorShape` wrapping the (possibly unknown) value count.
      uniform_row_length: rank-1 `TensorShape` wrapping the (possibly unknown)
        uniform row length.
    """
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True

  def _merge_with(self, other):
    """Merge two RowPartitionSpecs."""
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)

    if not RowPartitionSpec._dimensions_compatible(nrows, nvals, ncols):
      raise ValueError("Merging incompatible RowPartitionSpecs")

    # NOTE: if the dtypes are unequal, behavior is unspecified.
    if self.dtype != other.dtype:
      raise ValueError("Merging RowPartitionSpecs with incompatible dtypes")

    return RowPartitionSpec(nrows=nrows[0],
                            nvals=nvals[0],
                            uniform_row_length=ncols[0],
                            dtype=self.dtype)

  def with_dtype(self, dtype):
    """Returns a copy of this spec with the given encoding `dtype`."""
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    return RowPartitionSpec(nrows, nvals, self._uniform_row_length, dtype)

  def __deepcopy__(self, memo):
    # Specs are immutable value objects; rebuilding from the unwrapped
    # dimensions yields an equivalent deep copy.
    del memo
    dtype = self.dtype
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    uniform_row_length = (None if self._uniform_row_length is None else
                          tensor_shape.dimension_value(
                              self._uniform_row_length[0]))
    return RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)


#===============================================================================
# Helper Functions
#===============================================================================


def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that `tensor` is sorted in ascending (non-decreasing) order."""
  return check_ops.assert_non_negative(
      tensor[1:] - tensor[:-1], message=message)


def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  return check_ops.assert_equal(
      tensor, constant_op.constant(0, dtype=tensor.dtype), message=message)


def _cast_if_not_none(tensor, dtype):
  """Casts `tensor` to `dtype`, passing `None` through unchanged."""
  return None if tensor is None else math_ops.cast(tensor, dtype)


def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both are
      non-None).

  Returns:
    A pair `(merged_value, validated)`:
    * `merged_value` is `t1` if it is not None; or `t2` otherwise.
    * `validated` is true if we validated that t1 and t2 are equal (either
      by adding a check, or because t1 is t2).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    return t1, True
  else:
    err_msg = ("RowPartition._merge_precomputed_encodings: partitions "
               "have incompatible %s" % name)
    if not t1.shape.is_compatible_with(t2.shape):
      # Static shapes disagree: fail eagerly, no runtime check needed.
      raise ValueError(err_msg)
    if validate:
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False


_row_partition_factory_key = object()  # unique private object


def _get_dtype_or_none(value):
  """Returns `value.dtype` if `value` is a Tensor; otherwise `None`."""
  if isinstance(value, ops.Tensor):
    return value.dtype
  return None


def _get_target_dtype(values, dtype=None, dtype_hint=None):
  """Gets the target dtype of a family of values."""
  if dtype is not None:
    return dtype

  # Preference order: an existing Tensor's dtype, then a numpy array's dtype,
  # then the hint, then int64.
  for value in values:
    if isinstance(value, ops.Tensor):
      return value.dtype

  for value in values:
    if isinstance(value, np.ndarray):
      return dtypes.as_dtype(value.dtype)

  if dtype_hint is not None:
    return dtype_hint

  return dtypes.int64


def _convert_all_to_tensors(values, dtype=None, dtype_hint=None):
  """Convert a list of objects to tensors of the same dtype.

  Args:
    values: a list of `(value, name)` pairs, where `value` may be `None`.
    dtype: if set, every value is cast to this dtype.
    dtype_hint: fallback dtype used only when `dtype` is None and no value
      determines a dtype on its own.

  Returns:
    A list parallel to `values`, with `None` entries preserved.
  """
  target_dtype = _get_target_dtype([x for (x, _) in values], dtype, dtype_hint)

  # If dtype is None, we use convert behavior.
  # If dtype is not None, we use cast behavior.
  convert_behavior = dtype is None

  if convert_behavior:
    return [
        None if x is None else ops.convert_to_tensor(
            x, dtype=target_dtype, name=name) for (x, name) in values
    ]
  else:
    return [
        None if x is None else math_ops.cast(x, dtype=target_dtype, name=name)
        for (x, name) in values
    ]