1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class used to partition a sequence into contiguous subsequences ("rows").
16"""
17
18
19# TODO(edloper):  Make into an ExtensionType (if possible)
20
21
22import numpy as np
23
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.framework import type_spec
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import gen_ragged_math_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops.ragged import segment_id_ops
38from tensorflow.python.util.tf_export import tf_export
39
40#===============================================================================
41# RowPartition
42#===============================================================================
43# TODO(edloper): Consider removing row_starts and row_limits factory methods
44# and accessors from RowPartition.  In particular, these two encodings are
45# "second-class citizens": we never cache them, and if you do construct a
46# RowPartition from them then it may be more expensive than you might expect
47# (because we append a value to the beginning/end to transform them into
48# splits).  If we do remove them from RowPartition, then we would still keep
49# the from_row_starts and from_row_limits factory methods in RaggedTensor.
50
51
52@tf_export("experimental.RowPartition")
53class RowPartition(composite_tensor.CompositeTensor):
54  """Partitioning of a sequence of values into contiguous subsequences ("rows").
55
56  A `RowPartition` describes how a sequence with `nvals` items should be
57  divided into `nrows` contiguous subsequences ("rows").  For example, a
58  `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into
59  subsequences `[[1, 2], [3], [], [4, 5]]`.  Note that `RowPartition` stores
60  information about how values are partitioned, but does not include the
61  partitioned values themselves.  `tf.RaggedTensor` is used to pair a `values`
62  tensor with one or more `RowPartition`s, providing a complete encoding for a
63  ragged tensor (i.e. a tensor with variable-length dimensions).
64
65  `RowPartition`s may be defined using several different schemes:
66
67    * `row_lengths`: an integer vector with shape `[nrows]`, which specifies
68      the length of each row.
69
70    * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the
71      "split points" between each row.
72
73    * `row_starts`: an integer vector with shape `[nrows]`, which specifies
74      the start offset for each row.  Equivalent to `row_splits[:-1]`.
75
76    * `row_limits`: an integer vector with shape `[nrows]`, which specifies
77      the stop offset for each row.  Equivalent to `row_splits[1:]`.
78
79    * `value_rowids` is an integer vector with shape `[nvals]`, corresponding
80      one-to-one with sequence values, which specifies the row that each value
81      belongs to.  If the partition has empty trailing rows, then `nrows`
82      must also be specified.
83
84    * `uniform_row_length` is an integer scalar, specifying the length of every
85      row.  This scheme may only be used if all rows have the same length.
86
87  For example, the following `RowPartition`s all represent the partitioning of
88  8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`.
89
90  >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
91  >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
92  >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
93  >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8])
94  >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
95
96  For more information about each scheme, see the documentation for its
97  factory method.  For additional examples, see the documentation on
98  `tf.RaggedTensor`.
99
100  ### Precomputed Encodings
101
102  `RowPartition` always stores at least one encoding of the partitioning, but
103  it can be configured to cache additional encodings as well.  This can
104  avoid unnecessary recomputation in eager mode.  (In graph mode, optimizations
105  such as common subexpression elimination will typically prevent these
106  unnecessary recomputations.)  To check which encodings are precomputed, use
107  `RowPartition.has_precomputed_<encoding>`.  To cache an additional
108  encoding, use `RowPartition.with_precomputed_<encoding>`.
109  """
110
111  #=============================================================================
112  # Constructor (private)
113  #=============================================================================
114  def __init__(self,
115               row_splits,
116               row_lengths=None,
117               value_rowids=None,
118               nrows=None,
119               uniform_row_length=None,
120               nvals=None,
121               internal=False):
122    """Creates a `RowPartition` from the specified encoding tensor(s).
123
124    This constructor is private -- please use one of the following ops to
125    build `RowPartition`s:
126
127      * `RowPartition.from_row_lengths`
128      * `RowPartition.from_value_rowids`
129      * `RowPartition.from_row_splits`
130      * `RowPartition.from_row_starts`
131      * `RowPartition.from_row_limits`
132      * `RowPartition.from_uniform_row_length`
133
134    If row_splits has a constant value, then all other arguments should
135    have a constant value.
136
137    Args:
138      row_splits: A 1-D integer tensor with shape `[nrows+1]`.
139      row_lengths: A 1-D integer tensor with shape `[nrows]`.
140      value_rowids: A 1-D integer tensor with shape `[nvals]`.
141      nrows: An integer scalar tensor.
142      uniform_row_length: A scalar tensor.
143      nvals: A scalar tensor.
144      internal: Private key value, required to ensure that this private
145        constructor is *only* called from the factory methods.
146
147    Raises:
148      TypeError: If a row partitioning tensor has an inappropriate dtype.
149      TypeError: If exactly one row partitioning argument was not specified.
150      ValueError: If a row partitioning tensor has an inappropriate shape.
151      ValueError: If multiple partitioning arguments are specified.
152      ValueError: If nrows is specified but value_rowids is not None.
153    """
154    if internal is not _row_partition_factory_key:
155      raise ValueError("RowPartition constructor is private; please use one "
156                       "of the factory methods instead (e.g., "
157                       "RowPartition.from_row_lengths())")
158
159    # Validate the arguments.
160    if not isinstance(row_splits, ops.Tensor):
161      raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
162                      row_splits)
163    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
164      raise ValueError("Row-partitioning argument must be int32 or int64")
165
166    # Validate shapes & dtypes.
167    row_splits.shape.assert_has_rank(1)
168    row_splits.set_shape([None])
169    self._row_splits = row_splits
170
171    # Store any cached tensors.  These are used to avoid unnecessary
172    # round-trip conversions when a RowPartition is constructed from
173    # lengths or rowids, and we later want those lengths/rowids back.
174    for tensor in [row_lengths, value_rowids, nrows, uniform_row_length, nvals]:
175      if tensor is not None:
176        if not isinstance(tensor, ops.Tensor):
177          raise TypeError("Cached value must be a Tensor or None.")
178        elif tensor.dtype != row_splits.dtype:
179          raise ValueError(f"Inconsistent dtype for encoding tensors: "
180                           f"{tensor} vs {row_splits}")
181    self._row_lengths = row_lengths
182    self._value_rowids = value_rowids
183    self._nrows = nrows
184    self._uniform_row_length = uniform_row_length
185    self._nvals = nvals
186
187  #=============================================================================
188  # Factory Methods
189  #=============================================================================
190
191  @classmethod
192  def from_value_rowids(cls,
193                        value_rowids,
194                        nrows=None,
195                        validate=True,
196                        dtype=None,
197                        dtype_hint=None):
198    """Creates a `RowPartition` with rows partitioned by `value_rowids`.
199
200    This `RowPartition` divides a sequence `values` into rows by specifying
201    which row each value should be added to:
202
203    ```python
204    partitioned_rows = [[] for _ in range(nrows)]
205    for (value, rowid) in zip(values, value_rowids):
206      partitioned_rows[rowid].append(value)
207    ```
208
209    Args:
210      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
211        one-to-one with `values`, and specifies each value's row index.  Must be
212        nonnegative, and must be sorted in ascending order.
213      nrows: An integer scalar specifying the number of rows.  This should be
214        specified if the `RowPartition` may contain empty trailing rows. Must
215        be greater than `value_rowids[-1]` (or greater than or equal to zero if
216        `value_rowids` is empty). Defaults to `value_rowids[-1] + 1` (or zero if
217        `value_rowids` is empty).
218      validate: If true, then use assertions to check that the arguments form a
219        valid `RowPartition`.
220      dtype: Optional dtype for the RowPartition. If missing, the type
221        is inferred from the type of `value_rowids`, dtype_hint, or tf.int64.
222      dtype_hint: Optional dtype for the RowPartition, used when dtype
223        is None. In some cases, a caller may not have a dtype in mind when
224        converting to a tensor, so dtype_hint can be used as a soft preference.
225        If the conversion to `dtype_hint` is not possible, this argument has no
226        effect.
227
228    Returns:
229      A `RowPartition`.
230
231    Raises:
232      ValueError: If `nrows` is incompatible with `value_rowids`.
233
234    #### Example:
235
236    >>> print(RowPartition.from_value_rowids(
237    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
238    ...     nrows=4))
239    tf.RowPartition(row_splits=[0 4 4 7 8])
240    """
241    # Local import bincount_ops to avoid import-cycle since bincount_ops
242    # imports ragged_tensor.
243    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
244    if not isinstance(validate, bool):
245      raise TypeError("validate must have type bool")
246    with ops.name_scope(None, "RowPartitionFromValueRowIds",
247                        [value_rowids, nrows]):
248      value_rowids = cls._convert_row_partition(
249          value_rowids, "value_rowids", dtype_hint=dtype_hint, dtype=dtype)
250      if nrows is None:
251        const_rowids = tensor_util.constant_value(value_rowids)
252        if const_rowids is None:
253          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
254          const_nrows = None
255        else:
256          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
257          nrows = ops.convert_to_tensor(
258              const_nrows, value_rowids.dtype, name="nrows")
259      else:
260        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
261        const_nrows = tensor_util.constant_value(nrows)
262        if const_nrows is not None:
263          if const_nrows < 0:
264            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
265          const_rowids = tensor_util.constant_value(value_rowids)
266          if const_rowids is not None and const_rowids.size > 0:
267            if not const_nrows >= const_rowids[-1] + 1:
268              raise ValueError(
269                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
270                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))
271
272      value_rowids.shape.assert_has_rank(1)
273      nrows.shape.assert_has_rank(0)
274
275      if validate:
276        msg = ("Arguments to from_value_rowids do not form a valid "
277               "RowPartition")
278        checks = [
279            check_ops.assert_rank(value_rowids, 1, message=msg),
280            check_ops.assert_rank(nrows, 0, message=msg),
281            check_ops.assert_non_negative(value_rowids[:1], message=msg),
282            _assert_monotonic_increasing(value_rowids, message=msg),
283            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
284        ]
285        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)
286
287      # Convert value_rowids & nrows to row_splits.
288      # Note: we don't use segment_ids_to_row_splits() here because we want
289      # to save the intermediate value `row_lengths`, so we can cache it.
290      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
291      # cast.
292      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
293      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
294      row_lengths = bincount_ops.bincount(
295          value_rowids_int32,
296          minlength=nrows_int32,
297          maxlength=nrows_int32,
298          dtype=value_rowids.dtype)
299      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
300      if const_nrows is not None:
301        row_lengths.set_shape([const_nrows])
302        row_splits.set_shape([const_nrows + 1])
303
304      return cls(
305          row_splits=row_splits,
306          row_lengths=row_lengths,
307          value_rowids=value_rowids,
308          nrows=nrows,
309          internal=_row_partition_factory_key)
310
311  @classmethod
312  def from_row_splits(cls,
313                      row_splits,
314                      validate=True,
315                      dtype=None,
316                      dtype_hint=None):
317    """Creates a `RowPartition` with rows partitioned by `row_splits`.
318
319    This `RowPartition` divides a sequence `values` into rows by indicating
320    where each row begins and ends:
321
322    ```python
323    partitioned_rows = []
324    for i in range(len(row_splits) - 1):
325      row_start = row_splits[i]
326      row_end = row_splits[i + 1]
327      partitioned_rows.append(values[row_start:row_end])
328    ```
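
    For example, this (illustrative) construction builds the partition
    `[[*, *, *, *], [], [*, *, *], [*], []]` described in the class
    docstring:

    ```python
    p = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
    # 8 values split into rows of lengths [4, 0, 3, 1, 0].
    ```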
329
330    Args:
331      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
332        empty, and must be sorted in ascending order.  `row_splits[0]` must be
333        zero.
334      validate: If true, then use assertions to check that the arguments form a
335        valid `RowPartition`.
336      dtype: Optional dtype for the RowPartition. If missing, the type
337        is inferred from the type of `row_splits`, dtype_hint, or tf.int64.
338      dtype_hint: Optional dtype for the RowPartition, used when dtype
339        is None. In some cases, a caller may not have a dtype in mind when
340        converting to a tensor, so dtype_hint can be used as a soft preference.
341        If the conversion to `dtype_hint` is not possible, this argument has no
342        effect.
343
344    Returns:
345      A `RowPartition`.
346
347    Raises:
348      ValueError: If `row_splits` is an empty list.
349    """
350    if not isinstance(validate, bool):
351      raise TypeError("validate must have type bool")
352    if isinstance(row_splits, (list, tuple)) and not row_splits:
353      raise ValueError("row_splits tensor may not be empty.")
354    if isinstance(row_splits, tensor_spec.TensorSpec):
355      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
356
357    with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
358      row_splits = cls._convert_row_partition(
359          row_splits, "row_splits", dtype_hint=dtype_hint, dtype=dtype)
360      row_splits.shape.assert_has_rank(1)
361
362      if validate:
363        msg = "Arguments to from_row_splits do not form a valid RowPartition:"
364        checks = [
365            check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
366            _assert_zero(row_splits[0], message=(msg + "zero")),
367            _assert_monotonic_increasing(
368                row_splits, message=(msg + "monotonic")),
369        ]
370        row_splits = control_flow_ops.with_dependencies(checks, row_splits)
371
372      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
373
374  @classmethod
375  def from_row_lengths(cls,
376                       row_lengths,
377                       validate=True,
378                       dtype=None,
379                       dtype_hint=None):
380    """Creates a `RowPartition` with rows partitioned by `row_lengths`.
381
382    This `RowPartition` divides a sequence `values` into rows by indicating
383    the length of each row:
384
385    ```python
386    partitioned_rows = [[values.pop(0) for _ in range(length)]
387                        for length in row_lengths]
388    ```
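
    For example, this (illustrative) construction builds the partition
    `[[*, *, *, *], [], [*, *, *], [*], []]` described in the class
    docstring:

    ```python
    p = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
    # Equivalent to RowPartition.from_row_splits([0, 4, 4, 7, 8, 8]).
    ```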
389
390    Args:
391      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
392        nonnegative.
393      validate: If true, then use assertions to check that the arguments form a
394        valid `RowPartition`.
395
396      dtype: Optional dtype for the RowPartition. If missing, the type
397        is inferred from the type of `row_lengths`, dtype_hint, or tf.int64.
398      dtype_hint: Optional dtype for the RowPartition, used when dtype
399        is None. In some cases, a caller may not have a dtype in mind when
400        converting to a tensor, so dtype_hint can be used as a soft preference.
401        If the conversion to `dtype_hint` is not possible, this argument has no
402        effect.
403
404    Returns:
405      A `RowPartition`.
406    """
407    if not isinstance(validate, bool):
408      raise TypeError("validate must have type bool")
409    with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
410      row_lengths = cls._convert_row_partition(
411          row_lengths, "row_lengths", dtype_hint=dtype_hint, dtype=dtype)
412      row_lengths.shape.assert_has_rank(1)
413
414      if validate:
415        msg = "Arguments to from_row_lengths do not form a valid RowPartition"
416        checks = [
417            check_ops.assert_rank(row_lengths, 1, message=msg),
418            check_ops.assert_non_negative(row_lengths, message=msg),
419        ]
420        row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)
421
422      row_limits = math_ops.cumsum(row_lengths)
423      row_splits = array_ops.concat([[0], row_limits], axis=0)
424      return cls(
425          row_splits=row_splits,
426          row_lengths=row_lengths,
427          internal=_row_partition_factory_key)
428
429  @classmethod
430  def from_row_starts(cls,
431                      row_starts,
432                      nvals,
433                      validate=True,
434                      dtype=None,
435                      dtype_hint=None):
436    """Creates a `RowPartition` with rows partitioned by `row_starts`.
437
438    Equivalent to: `from_row_splits(concat([row_starts, [nvals]], axis=0))`.
439
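    For example, this (illustrative) construction builds the same partition
    as `RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])`:

    ```python
    p = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
    # Rows of lengths [4, 0, 3, 1, 0].
    ```
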
440    Args:
441      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
442        nonnegative and sorted in ascending order.  If `nrows>0`, then
443        `row_starts[0]` must be zero.
444      nvals: A scalar tensor indicating the number of values.
445      validate: If true, then use assertions to check that the arguments form a
446        valid `RowPartition`.
447      dtype: Optional dtype for the RowPartition. If missing, the type
448        is inferred from the type of `row_starts`, dtype_hint, or tf.int64.
449      dtype_hint: Optional dtype for the RowPartition, used when dtype
450        is None. In some cases, a caller may not have a dtype in mind when
451        converting to a tensor, so dtype_hint can be used as a soft preference.
452        If the conversion to `dtype_hint` is not possible, this argument has no
453        effect.
454
455    Returns:
456      A `RowPartition`.
457    """
458    if not isinstance(validate, bool):
459      raise TypeError("validate must have type bool")
460    with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]):
461      row_starts = cls._convert_row_partition(
462          row_starts, "row_starts", dtype_hint=dtype_hint, dtype=dtype)
463      row_starts.shape.assert_has_rank(1)
464      # TODO(martinz): nvals and row_starts could be inconsistent at call time,
465      # even though they eventually end up the same type.
466      nvals = math_ops.cast(nvals, row_starts.dtype)
467      if validate:
468        msg = "Arguments to from_row_starts do not form a valid RowPartition"
469        checks = [
470            check_ops.assert_rank(row_starts, 1, message=msg),
471            _assert_zero(row_starts[:1], message=msg),
472            _assert_monotonic_increasing(row_starts, message=msg),
473            check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
474        ]
475        row_starts = control_flow_ops.with_dependencies(checks, row_starts)
476
477      row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
478      return cls(row_splits=row_splits, nvals=nvals,
479                 internal=_row_partition_factory_key)
480
481  @classmethod
482  def from_row_limits(cls,
483                      row_limits,
484                      validate=True,
485                      dtype=None,
486                      dtype_hint=None):
487    """Creates a `RowPartition` with rows partitioned by `row_limits`.
488
489    Equivalent to: `from_row_splits(concat([[0], row_limits], axis=0))`.
490
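    For example, this (illustrative) construction builds the same partition
    as `RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])`:

    ```python
    p = RowPartition.from_row_limits([4, 4, 7, 8, 8])
    # Rows of lengths [4, 0, 3, 1, 0].
    ```
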
491    Args:
492      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
493        ascending order.
494      validate: If true, then use assertions to check that the arguments form a
495        valid `RowPartition`.
496      dtype: Optional dtype for the RowPartition. If missing, the type
497        is inferred from the type of `row_limits`, dtype_hint, or tf.int64.
498      dtype_hint: Optional dtype for the RowPartition, used when dtype
499        is None. In some cases, a caller may not have a dtype in mind when
500        converting to a tensor, so dtype_hint can be used as a soft preference.
501        If the conversion to `dtype_hint` is not possible, this argument has no
502        effect.
503
504    Returns:
505      A `RowPartition`.
506    """
507    if not isinstance(validate, bool):
508      raise TypeError("validate must have type bool")
509    with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
510      row_limits = cls._convert_row_partition(
511          row_limits, "row_limits", dtype_hint=dtype_hint, dtype=dtype)
512      row_limits.shape.assert_has_rank(1)
513
514      if validate:
515        msg = "Arguments to from_row_limits do not form a valid RowPartition"
516        checks = [
517            check_ops.assert_rank(row_limits, 1, message=msg),
518            check_ops.assert_non_negative(row_limits[:1], message=msg),
519            _assert_monotonic_increasing(row_limits, message=msg),
520        ]
521        row_limits = control_flow_ops.with_dependencies(checks, row_limits)
522
523      zero = array_ops.zeros([1], row_limits.dtype)
524      row_splits = array_ops.concat([zero, row_limits], axis=0)
525      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
526
527  @classmethod
528  def from_uniform_row_length(cls,
529                              uniform_row_length,
530                              nvals=None,
531                              nrows=None,
532                              validate=True,
533                              dtype=None,
534                              dtype_hint=None):
535    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.
536
537    This `RowPartition` divides a sequence `values` into rows that all have
538    the same length:
539
540    ```python
541    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
542                        for _ in range(nrows)]
543    ```
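
    For example, this (illustrative) construction builds 4 rows of length 2:

    ```python
    p = RowPartition.from_uniform_row_length(2, nvals=8)
    # Equivalent to RowPartition.from_row_splits([0, 2, 4, 6, 8]).
    ```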
544
545    Note that at least one of `nvals` and `nrows` must be specified.
546
547    Args:
548      uniform_row_length: A scalar integer tensor.  Must be nonnegative. The
549        size of the outer axis of `values` must be evenly divisible by
550        `uniform_row_length`.
551      nvals: A non-negative scalar integer tensor specifying the number of
552        values.  Must be specified if `nrows` is not specified.  If omitted,
553        defaults to `uniform_row_length*nrows`.
554      nrows: The number of rows in the constructed RowPartition.  If not
555        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
556        `uniform_row_length==0`).  `nrows` only needs to be specified if
557        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must be
558        `nvals`.
559      validate: If true, then use assertions to check that the arguments form a
560        valid `RowPartition`.
561      dtype: Optional dtype for the RowPartition. If missing, the type
562        is inferred from the type of `uniform_row_length`, dtype_hint,
563        or tf.int64.
564      dtype_hint: Optional dtype for the RowPartition, used when dtype
565        is None. In some cases, a caller may not have a dtype in mind when
566        converting to a tensor, so dtype_hint can be used as a soft preference.
567        If the conversion to `dtype_hint` is not possible, this argument has no
568        effect.
569
570    Returns:
571      A `RowPartition`.
572    """
573    if not isinstance(validate, bool):
574      raise TypeError("validate must have type bool")
575    if nrows is None and nvals is None:
576      raise ValueError("Either (or both) of nvals and nrows must be specified")
577    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
578                        [uniform_row_length, nrows]):
579      [uniform_row_length, nvals, nrows
580      ] = _convert_all_to_tensors([(uniform_row_length, "uniform_row_length"),
581                                   (nvals, "nvals"), (nrows, "nrows")],
582                                  dtype=dtype,
583                                  dtype_hint=dtype_hint)
584
585      uniform_row_length.shape.assert_has_rank(0)
586
587      # Find nrows.
588      const_row_length = tensor_util.constant_value(uniform_row_length)
589      if nrows is None:
590        if const_row_length is None:
591          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
592          rowlen_or_1 = math_ops.maximum(
593              uniform_row_length,
594              constant_op.constant(1, uniform_row_length.dtype))
595          nrows = nvals // rowlen_or_1
596        elif const_row_length == 0:
597          nrows = constant_op.constant(0, dtype=uniform_row_length.dtype)
598        else:
599          nrows = nvals // const_row_length
600      const_nrows = None if nrows is None else tensor_util.constant_value(nrows)
601      const_nvals = None if nvals is None else tensor_util.constant_value(nvals)
602      const_uniform_row_length = tensor_util.constant_value(uniform_row_length)
603
604      checks = []
605
606      if const_nvals is None and const_nrows is not None and const_uniform_row_length is not None:
607        const_nvals = const_nrows * const_uniform_row_length
608        if nvals is not None and validate:
609          checks.append(check_ops.assert_equal(nvals, const_nvals))
610        nvals = constant_op.constant(const_nvals, uniform_row_length.dtype)
611
612      if nvals is None:
613        nvals = nrows * uniform_row_length
614
615      # Find row_splits.
616      if const_nrows is not None and const_row_length is not None:
617        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
618        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
619      else:
620        row_splits = math_ops.range(
621            nrows + 1, dtype=uniform_row_length.dtype) * uniform_row_length
622
623      if validate:
624
625        if (const_nrows is None or const_row_length is None or
626            const_nvals is None):
627          checks.append(
628              check_ops.assert_equal(
629                  nrows * uniform_row_length, nvals,
630                  ("uniform_row_length", uniform_row_length, "times nrows",
631                   nrows, "must equal nvals", nvals)))
632        else:
633          if const_nrows * const_row_length != const_nvals:
634            raise ValueError(
635                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
636                (const_row_length, const_nrows, const_nvals))
637
638        if uniform_row_length.shape.rank is None:
639          checks.append(
640              check_ops.assert_rank(
641                  uniform_row_length,
642                  0,
643                  message="uniform_row_length must be a scalar."))
644
645        const_row_length = tensor_util.constant_value(uniform_row_length)
646        if const_row_length is None:
647          checks.append(
648              check_ops.assert_greater_equal(
649                  uniform_row_length,
650                  constant_op.constant(0, uniform_row_length.dtype),
651                  message="uniform_row_length must be >= 0."))
652        else:
653          if const_row_length < 0:
654            raise ValueError("uniform_row_length must be >= 0.")
655
656        row_splits = control_flow_ops.with_dependencies(checks, row_splits)
657
658      return cls(
659          row_splits=row_splits,
660          uniform_row_length=uniform_row_length,
661          nrows=nrows,
662          nvals=nvals,
663          internal=_row_partition_factory_key)
664
665  @classmethod
666  def _convert_row_partition(cls, partition, name, dtype=None, dtype_hint=None):
667    """Converts `partition` to Tensors.
668
669    Args:
670      partition: A row-partitioning tensor for the `RowPartition` being
671        constructed.  I.e., one of: row_splits, row_lengths, row_starts,
672        row_limits, value_rowids, uniform_row_length.
673      name: The name of the row-partitioning tensor.
674      dtype: Optional dtype for the RowPartition. If missing, the type
675        is inferred from the type of `partition`, dtype_hint,
676        or tf.int64.
677      dtype_hint: Optional dtype for the RowPartition, used when dtype
678        is None. In some cases, a caller may not have a dtype in mind when
679        converting to a tensor, so dtype_hint can be used as a soft preference.
680        If the conversion to `dtype_hint` is not possible, this argument has no
681        effect.
682
683    Returns:
684      A tensor equivalent to partition.
685
686    Raises:
687      ValueError: if dtype is not int32 or int64.
688    """
689    if dtype_hint is None:
690      dtype_hint = dtypes.int64
691    if (isinstance(partition, np.ndarray) and
692        partition.dtype == np.int32 and dtype is None):
693      partition = ops.convert_to_tensor(partition, name=name)
694    else:
695      partition = ops.convert_to_tensor_v2(
696          partition, dtype_hint=dtype_hint, dtype=dtype, name=name)
697    if partition.dtype not in (dtypes.int32, dtypes.int64):
698      raise ValueError("%s must have dtype int32 or int64" % name)
699
700    return partition
701
702  def _with_dependencies(self, dependencies):
703    """Returns a new RowPartition equal to self with control dependencies.
704
705    Specifically, self._row_splits is gated by the given control dependencies.
706    Used to add sanity checks to the constructors.
707
708    Args:
709      dependencies: a list of tensors to use as dependencies.
710
711    Returns:
712      A new RowPartition object.
713    """
714    new_row_splits = control_flow_ops.with_dependencies(dependencies,
715                                                        self._row_splits)
716    return RowPartition(
717        row_splits=new_row_splits,
718        row_lengths=self._row_lengths,
719        value_rowids=self._value_rowids,
720        nrows=self._nrows,
721        uniform_row_length=self._uniform_row_length,
722        internal=_row_partition_factory_key)
723
724  #=============================================================================
725  # Accessors
726  #=============================================================================
727
728  @property
729  def dtype(self):
730    """The `DType` used to encode the row partition (either int32 or int64)."""
731    return self._row_splits.dtype
732
733  def row_splits(self):
734    """Returns the row-split indices for this row partition.
735
736    `row_splits` specifies where the values for each row begin and end.
737    In particular, the values for row `i` are stored in the slice
738    `values[row_splits[i]:row_splits[i+1]]`.
739
740    Returns:
741      A 1-D integer `Tensor` with shape `[self.nrows+1]`.
742      The returned tensor is non-empty, and is sorted in ascending order.
743      `self.row_splits()[0] == 0`.
744      `self.row_splits()[-1] == self.nvals()`.
745    """
746    return self._row_splits
747
748  def value_rowids(self):
749    """Returns the row indices for this row partition.
750
751    `value_rowids` specifies the row index for each value.  In particular,
752    `value_rowids[i]` is the row index for `values[i]`.
753
754    Returns:
755      A 1-D integer `Tensor` with shape `[self.nvals()]`.
756      The returned tensor is nonnegative, and is sorted in ascending order.
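
    For example (illustrative; values shown assuming eager execution):

    ```python
    RowPartition.from_row_splits([0, 4, 4, 7, 8, 8]).value_rowids()
    # -> [0, 0, 0, 0, 2, 2, 2, 3]
    ```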
757    """
758    if self._value_rowids is not None:
759      return self._value_rowids
760    return segment_id_ops.row_splits_to_segment_ids(self._row_splits)
761
762  def nvals(self):
763    """Returns the number of values partitioned by this `RowPartition`.
764
765    If the sequence partitioned by this `RowPartition` is a tensor, then
766    `nvals` is the size of that tensor's outermost dimension -- i.e.,
767    `nvals == values.shape[0]`.
768
769    Returns:
770      scalar integer Tensor
771    """
772    # TODO(martinz): Uncomment these lines.
773    # if self._nvals is not None:
774    #   return self._nvals
775    return self._row_splits[-1]
776
777  def nrows(self):
778    """Returns the number of rows created by this `RowPartition`.
779
780    Returns:
781      scalar integer Tensor
782    """
783    if self._nrows is not None:
784      return self._nrows
785    nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
786    if nsplits.value is None:
787      return array_ops.shape(self._row_splits, out_type=self.dtype)[0] - 1
788    else:
789      return constant_op.constant(nsplits.value - 1, dtype=self.dtype)
790
791  def uniform_row_length(self):
792    """Returns the length of each row in this partition, if rows are uniform.
793
794    If all rows in this `RowPartition` have the same length, then this returns
795    that length as a scalar integer `Tensor`.  Otherwise, it returns `None`.
796
797    Returns:
798      scalar Tensor with `type=self.dtype`, or `None`.
799    """
800    return self._uniform_row_length
801
802  def row_starts(self):
803    """Returns the start indices for rows in this row partition.
804
805    These indices specify where the values for each row begin.
806    `partition.row_starts()` is equal to `partition.row_splits()[:-1]`.
807
808    Returns:
809      A 1-D integer Tensor with shape `[self.nrows()]`.
810      The returned tensor is nonnegative, and is sorted in ascending order.
811      `self.row_starts()[0] == 0`.
812      `self.row_starts()[-1] <= self.nvals()`.
813    """
814    return self._row_splits[:-1]
815
816  def row_limits(self):
817    """Returns the limit indices for rows in this row partition.
818
819    These indices specify where the values for each row end.
820    `partition.row_limits()` is equal to `partition.row_splits()[1:]`.
821
822    Returns:
823      A 1-D integer Tensor with shape `[self.nrows]`.
824      The returned tensor is nonnegative, and is sorted in ascending order.
825      `self.row_limits()[-1] == self.nvals()`.
826    """
827    return self._row_splits[1:]
828
829  def row_lengths(self):
830    """Returns the lengths of rows in this `RowPartition`.
831
832    Returns:
833      A 1-D integer Tensor with shape `[self.nrows]`.
834      The returned tensor is nonnegative.
835      `tf.reduce_sum(self.row_lengths()) == self.nvals()`.
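
    For example (illustrative; values shown assuming eager execution):

    ```python
    RowPartition.from_row_splits([0, 4, 4, 7, 8, 8]).row_lengths()
    # -> [4, 0, 3, 1, 0]
    ```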
836    """
837    if self._row_lengths is not None:
838      return self._row_lengths
839    splits = self._row_splits
840    return splits[1:] - splits[:-1]
841
842  @property
843  def static_nrows(self):
844    """The number of rows in this partition, if statically known.
845
846    ```python
847    self.row_lengths().shape == [self.static_nrows]
848    self.row_starts().shape == [self.static_nrows]
849    self.row_limits().shape == [self.static_nrows]
850    self.row_splits().shape == [self.static_nrows + 1]
851    ```
852
853    Returns:
854      The number of rows in this partition as an `int` (if statically known);
855      or `None` (otherwise).
856    """
857    if self._row_splits is not None:
858      nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0])
859      if nrows_plus_one is not None:
860        return nrows_plus_one - 1
861    if self._row_lengths is not None:
862      nrows = tensor_shape.dimension_value(self._row_lengths.shape[0])
863      if nrows is not None:
864        return nrows
865    if self._nrows is not None:
866      return tensor_util.constant_value(self._nrows)
867    return None
868
869  @property
870  def static_nvals(self):
871    """The number of values in this partition, if statically known.
872
873    ```python
874    self.value_rowids().shape == [self.static_nvals]
875    ```
876
877    Returns:
878      The number of values in this partition as an `int` (if statically known);
879      or `None` (otherwise).
880    """
881    if self._nvals is not None:
882      nvals = tensor_util.constant_value(self._nvals)
883      if nvals is not None:
884        return nvals
885    if self._value_rowids is not None:
886      nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0)
887      if nvals.value is not None:
888        return nvals.value
889    return None
890
891  @property
892  def static_uniform_row_length(self):
893    """The number of values in each row of this partition, if statically known.
894
895    Returns:
896      The number of values in each row of this partition as an `int` (if
897      statically known); or `None` (otherwise).
898    """
899    if self._uniform_row_length is not None:
900      return tensor_util.constant_value(self._uniform_row_length)
901    return None
902
903  def offsets_in_rows(self):
904    """Return the offset of each value.
905
906    RowPartition takes an array x and converts it into sublists.
907    offsets[i] is the index of x[i] in its sublist.
908    Given a shape, such as:
909    [*,*,*],[*,*],[],[*,*]
910    This returns:
911    0,1,2,0,1,0,1
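
    A minimal sketch of the same example (values shown assuming eager
    execution):

    ```python
    p = RowPartition.from_row_lengths([3, 2, 0, 2])
    p.offsets_in_rows()  # -> [0, 1, 2, 0, 1, 0, 1]
    ```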
912
913    Returns:
914      an offset for every value.
915    """
916    return gen_ragged_math_ops.ragged_range(
917        starts=constant_op.constant(0, self.dtype),
918        limits=self.row_lengths(),
919        deltas=constant_op.constant(1, self.dtype)).rt_dense_values
920
921  def is_uniform(self):
922    """Returns true if the partition is known to be uniform statically.
923
924    This is based upon the existence of self._uniform_row_length. For example:
925    RowPartition.from_row_lengths([3, 3, 3]).is_uniform() == False
926    RowPartition.from_uniform_row_length(5, nvals=20).is_uniform() == True
927    RowPartition.from_row_lengths([2, 0, 2]).is_uniform() == False
928
929    Returns:
930      Whether a RowPartition is known to be uniform statically.
931    """
932    return self._uniform_row_length is not None
933
934  def _static_check(self):
935    """Checks if the object is internally consistent.
936
937    Raises:
938      ValueError if inconsistent.
939    """
940    my_dtype = self.dtype
941    if self._uniform_row_length is not None:
942      if self._uniform_row_length.dtype != my_dtype:
943        raise ValueError("_uniform_row_length.dtype=" +
944                         str(self._uniform_row_length.dtype) + ", not " +
945                         str(my_dtype))
946
947    if self._row_lengths is not None and self._row_lengths.dtype != my_dtype:
948      raise ValueError("_row_lengths.dtype=" + str(self._row_lengths.dtype) +
949                       ", not " + str(my_dtype))
950
951    if self._value_rowids is not None and self._value_rowids.dtype != my_dtype:
952      raise ValueError("_value_rowids.dtype=" + str(self._value_rowids.dtype) +
953                       ", not " + str(my_dtype))
954
955    if self._nrows is not None and self._nrows.dtype != my_dtype:
956      raise ValueError("_nrows.dtype=" + str(self._nrows.dtype) + ", not " +
957                       str(my_dtype))
958
959  #=============================================================================
960  # Transformation
961  #=============================================================================
962
963  def with_dtype(self, dtype):
964    """Returns a copy of this RowPartition with the given encoding dtype.
965
966    Args:
967      dtype: The dtype for encoding tensors, such as `row_splits` and `nrows`.
968        One of `tf.int32` or `tf.int64`.
969
970    Returns:
971      A copy of this RowPartition, with the encoding tensors cast to the given
972      type.
973    """
974    dtype = dtypes.as_dtype(dtype)
975    if dtype not in (dtypes.int32, dtypes.int64):
976      raise ValueError("dtype must be int32 or int64")
977    if self.dtype == dtype:
978      return self
979
980    return RowPartition(
981        row_splits=_cast_if_not_none(self._row_splits, dtype),
982        row_lengths=_cast_if_not_none(self._row_lengths, dtype),
983        value_rowids=_cast_if_not_none(self._value_rowids, dtype),
984        nrows=_cast_if_not_none(self._nrows, dtype),
985        uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype),
986        internal=_row_partition_factory_key)
987
988  #=============================================================================
989  # String Encoding
990  #=============================================================================
991
992  def __repr__(self):
993    if self._uniform_row_length is not None:
994      return (f"tf.RowPartition(nrows={self._nrows}, "
995              f"uniform_row_length={self._uniform_row_length})")
996    else:
997      return f"tf.RowPartition(row_splits={self._row_splits})"
998
999  #=============================================================================
1000  # Precomputed Encodings
1001  #=============================================================================
1002
1003  def _has_precomputed_row_splits(self):
1004    """Returns true if `row_splits` has already been computed.
1005
1006    If true, then `self.row_splits()` will return its value without calling
1007    any TensorFlow ops.
1008    """
1009    return self._row_splits is not None
1010
1011  def _has_precomputed_row_lengths(self):
1012    """Returns true if `row_lengths` has already been computed.
1013
1014    If true, then `self.row_lengths()` will return its value without calling
1015    any TensorFlow ops.
1016    """
1017    return self._row_lengths is not None
1018
1019  def _has_precomputed_value_rowids(self):
1020    """Returns true if `value_rowids` has already been computed.
1021
1022    If true, then `self.value_rowids()` will return its value without calling
1023    any TensorFlow ops.
1024    """
1025    return self._value_rowids is not None
1026
1027  def _has_precomputed_nrows(self):
1028    """Returns true if `nrows` has already been computed.
1029
1030    If true, then `self.nrows()` will return its value without calling
1031    any TensorFlow ops.
1032    """
1033    return self._nrows is not None
1034
1035  def _has_precomputed_nvals(self):
1036    """Returns true if `nvals` has already been computed.
1037
1038    If true, then `self.nvals()` will return its value without calling
1039    any TensorFlow ops.
1040    """
1041    return self._nvals is not None
1042
1043  def _with_precomputed_row_splits(self):
1044    """Returns a copy of `self` with `row_splits` precomputed."""
1045    return RowPartition(
1046        row_splits=self.row_splits(),
1047        row_lengths=self._row_lengths,
1048        value_rowids=self._value_rowids,
1049        nrows=self._nrows,
1050        uniform_row_length=self._uniform_row_length,
1051        nvals=self._nvals,
1052        internal=_row_partition_factory_key)
1053
1054  def _with_precomputed_row_lengths(self):
1055    """Returns a copy of `self` with `row_lengths` precomputed."""
1056    return RowPartition(
1057        row_splits=self._row_splits,
1058        row_lengths=self.row_lengths(),
1059        value_rowids=self._value_rowids,
1060        nrows=self._nrows,
1061        nvals=self._nvals,
1062        uniform_row_length=self._uniform_row_length,
1063        internal=_row_partition_factory_key)
1064
1065  def _with_precomputed_value_rowids(self):
1066    """Returns a copy of `self` with `value_rowids` precomputed."""
1067    return RowPartition(
1068        row_splits=self._row_splits,
1069        row_lengths=self._row_lengths,
1070        value_rowids=self.value_rowids(),
1071        nrows=self._nrows,
1072        nvals=self._nvals,
1073        uniform_row_length=self._uniform_row_length,
1074        internal=_row_partition_factory_key)
1075
1076  def _with_precomputed_nrows(self):
1077    """Returns a copy of `self` with `nrows` precomputed."""
1078    return RowPartition(
1079        row_splits=self._row_splits,
1080        row_lengths=self._row_lengths,
1081        value_rowids=self._value_rowids,
1082        nrows=self.nrows(),
1083        nvals=self._nvals,
1084        uniform_row_length=self._uniform_row_length,
1085        internal=_row_partition_factory_key)
1086
1087  def _with_precomputed_nvals(self):
1088    """Returns a copy of `self` with `row_splits` precomputed."""
1089    return RowPartition(
1090        row_splits=self.row_splits(),
1091        row_lengths=self._row_lengths,
1092        value_rowids=self._value_rowids,
1093        nrows=self._nrows,
1094        nvals=self.nvals(),
1095        uniform_row_length=self._uniform_row_length,
1096        internal=_row_partition_factory_key)
1097
1098  def _merge_with_spec(self, b):
1099    """Merge with a TypeSpec to create a new RowPartition."""
1100    a_spec = self._type_spec
1101    if not a_spec.is_compatible_with(b):
1102      # TODO(martinz): Should a dynamic check be used here?
1103      raise ValueError("RowPartition and RowPartitionSpec are not compatible")
1104    nrows = constant_op.constant(
1105        b.nrows, self.dtype) if b.nrows is not None else self._nrows
1106    nvals = constant_op.constant(
1107        b.nvals, self.dtype) if b.nvals is not None else self._nvals
1108    uniform_row_length = constant_op.constant(
1109        b.uniform_row_length, self.dtype
1110    ) if b.uniform_row_length is not None else self._uniform_row_length
1111    return RowPartition(
1112        row_splits=self._row_splits,
1113        row_lengths=self._row_lengths,
1114        value_rowids=self._value_rowids,
1115        nvals=nvals,
1116        uniform_row_length=uniform_row_length,
1117        nrows=nrows,
1118        internal=_row_partition_factory_key)
1119
1120  def _merge_precomputed_encodings(self, other, validate=True):
1121    """Returns a RowPartition that merges encodings from `self` and `other`.
1122
1123    Requires that `self` and `other` describe the same partition.
1124
1125    Args:
1126      other: A `RowPartition` that encodes the same partition as `self`.
1127      validate: If true, then add runtime checks to verify that `self` and
1128        `other` encode the same row partition.
1129
1130    Returns:
1131      A `RowPartition`.
1132    """
1133    # pylint: disable=protected-access
1134    if (self is other or  # Fast path if row partitions are equal.
1135        (self._row_splits is other._row_splits and
1136         self._row_lengths is other._row_lengths and
1137         self._value_rowids is other._value_rowids and
1138         self._nrows is other._nrows and
1139         self._nvals is other._nvals and
1140         self._uniform_row_length is other._uniform_row_length)):
1141      return self
1142
1143    # Merge the component tensors.  We only need to validate one encoding.
1144    # We merge less-expensive encodings first (to avoid expensive validation).
1145    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
1146                                            validate)
1147    nvals, _ = _merge_tensors(self._nvals, other._nvals, "nvals", validate)
1148    uniform_row_length, uniform_row_length_validated = _merge_tensors(
1149        self._uniform_row_length, other._uniform_row_length,
1150        "uniform_row_length", validate)
1151    if uniform_row_length_validated and nrows_validated:
1152      validate = False  # Validation complete.
1153    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
1154                                                      other._row_splits,
1155                                                      "row_splits", validate)
1156    if row_splits_validated:
1157      validate = False  # Validation complete.
1158    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
1159                                                        other._row_lengths,
1160                                                        "row_lengths", validate)
1161    if row_lengths_validated:
1162      validate = False  # Validation complete.
1163    value_rowids, value_rowids_validated = _merge_tensors(
1164        self._value_rowids, other._value_rowids, "value_rowids", validate)
1165    if value_rowids_validated and nrows_validated:
1166      validate = False  # Validation complete.
1167    # TODO(edloper): If we make the row_splits encoding optional, then there
1168    # will be cases where we need to do validation at this point -- e.g. if
1169    # self has only row_splits and other has only value_rowids.  But for
1170    # now, we are guaranteed to have done validation by this point.
1171
1172    # Avoid creating new RowPartition objects if we don't need to.
1173    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
1174        value_rowids is self._value_rowids and nrows is self._nrows and
1175        uniform_row_length is self._uniform_row_length):
1176      return self
1177    if (row_splits is other._row_splits and
1178        row_lengths is other._row_lengths and
1179        value_rowids is other._value_rowids and nrows is other._nrows and
1180        uniform_row_length is other._uniform_row_length):
1181      return other
1182
1183    return RowPartition(
1184        row_splits=row_splits,
1185        row_lengths=row_lengths,
1186        value_rowids=value_rowids,
1187        nrows=nrows,
1188        uniform_row_length=uniform_row_length,
1189        nvals=nvals,
1190        internal=_row_partition_factory_key)
1191
1192  #=============================================================================
1193  # Composite Tensor
1194  #=============================================================================
1195
1196  @property
1197  def _type_spec(self):
1198    return RowPartitionSpec.from_value(self)
1199
1200
1201#===============================================================================
1202# RowPartitionSpec
1203#===============================================================================
1204# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination
1205# of precomputed row-partition encodings (rather than always using row_splits).
1206
1207
1208@type_spec.register("tf.RowPartitionSpec")
1209class RowPartitionSpec(type_spec.TypeSpec):
1210  """Type specification for a `tf.RowPartition`."""
1211
1212  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]
1213
1214  value_type = property(lambda self: RowPartition)
1215
1216  def __init__(self,
1217               nrows=None,
1218               nvals=None,
1219               uniform_row_length=None,
1220               dtype=dtypes.int64):
1221    """Constructs a new RowPartitionSpec.
1222
1223    Args:
1224      nrows: The number of rows in the RowPartition, or `None` if unspecified.
1225      nvals: The number of values partitioned by the RowPartition, or `None` if
1226        unspecified.
1227      uniform_row_length: The number of values in each row for this
1228        RowPartition, or `None` if rows are ragged or row length is unspecified.
1229      dtype: The data type used to encode the partition.  One of `tf.int64` or
1230        `tf.int32`.
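
    For example (an illustrative sketch):

    ```python
    spec = RowPartitionSpec(nrows=5, dtype=tf.int64)
    # Describes any RowPartition with exactly 5 rows, encoded with int64
    # tensors; nvals and uniform_row_length are left unspecified.
    ```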
1231    """
1232    # Wrap dimension sizes in 1D TensorShapes so the default implementations
1233    # of TypeSpec methods such as `is_compatible_with` will work.
1234    nrows = tensor_shape.TensorShape([nrows])
1235    nvals = tensor_shape.TensorShape([nvals])
1236    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
1237      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
1238    else:
1239      uniform_row_length = uniform_row_length.with_rank(1)
1240
1241    self._nrows = nrows
1242    self._nvals = nvals
1243    self._uniform_row_length = uniform_row_length
1244    self._dtype = dtypes.as_dtype(dtype)
1245    if self._dtype not in (dtypes.int32, dtypes.int64):
1246      raise ValueError("dtype must be tf.int32 or tf.int64")
1247
1248    # Check dimension consistency, & infer dimensions when possible.
1249    nrows = tensor_shape.dimension_value(nrows[0])
1250    nvals = tensor_shape.dimension_value(nvals[0])
1251    ncols = tensor_shape.dimension_value(uniform_row_length[0])
1252    if nrows == 0:  # no rows -> no values.
1253      if nvals is None:
1254        self._nvals = tensor_shape.TensorShape([0])
1255      elif nvals != 0:
1256        raise ValueError("nvals=%s is not compatible with nrows=%s" %
1257                         (nvals, nrows))
1258    if ncols == 0:  # there are no values in each row -> no values.
1259      if nvals is None:
1260        self._nvals = tensor_shape.TensorShape([0])
1261      elif nvals != 0:
1262        raise ValueError("nvals=%s is not compatible with uniform_row_length"
1263                         "=%s" % (nvals, uniform_row_length))
1264    if ncols is not None and nvals is not None:
1265      if ncols != 0 and nvals % ncols != 0:
1266        raise ValueError("nvals=%s is not compatible with uniform_row_length"
1267                         "=%s (doesn't divide evenly)" % (nvals, ncols))
1268      if nrows is not None and nvals != ncols * nrows:
1269        raise ValueError("nvals=%s is not compatible with nrows=%s and "
1270                         "uniform_row_length=%s" % (nvals, nrows, ncols))
1271      if nrows is None and ncols != 0:
1272        self._nrows = tensor_shape.TensorShape([nvals // ncols])
1273    if ncols is not None and nrows is not None and nvals is None:
1274      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)
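
  # Example (illustrative sketch): two specs are compatible when their known
  # dimensions can be merged consistently:
  #
  #   RowPartitionSpec(nrows=4).is_compatible_with(
  #       RowPartitionSpec(nvals=12, uniform_row_length=3))   # -> True
  #   RowPartitionSpec(uniform_row_length=2).is_compatible_with(
  #       RowPartitionSpec(nrows=3, nvals=7))                 # -> False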

  def _serialize(self):
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    # Unwrap nrows and nvals from their TensorShape wrappers; the constructor
    # accepts uniform_row_length as a TensorShape directly.
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    return self._dtype

  @property
  def _component_specs(self):
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    return value.row_splits()

  def _from_components(self, tensor):
    return RowPartition.from_row_splits(tensor, validate=False)
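
  # Example (illustrative sketch): the spec's single component is the
  # row_splits vector, whose length is nrows + 1 when nrows is statically
  # known.  A RowPartition with row_splits [0, 2, 3] has nrows=2, so (for the
  # default int64 dtype) its component spec is TensorSpec([3], tf.int64), and
  # _to_components/_from_components round-trip through that one tensor.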

  @classmethod
  def from_value(cls, value):
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)
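
  # Note (illustrative sketch): `from_value` copies the statically known
  # dimensions of a concrete RowPartition into a spec; any dimension that is
  # not statically known becomes None.  For example, a partition built with
  # RowPartition.from_uniform_row_length(3, nvals=12) would typically yield a
  # spec with nrows=4, nvals=12, and uniform_row_length=3.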

  def __repr__(self):
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns True if the given 1-D TensorShape dimensions are compatible."""
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True
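
  # Example (illustrative sketch): conceptually, nrows=4, nvals=12,
  # uniform_row_length=3 is compatible (12 == 4 * 3), while nrows=4,
  # nvals=10, uniform_row_length=3 is not (10 % 3 != 0).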

  def _merge_with(self, other):
    """Merge two RowPartitionSpecs."""
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)

    if not RowPartitionSpec._dimensions_compatible(nrows, nvals, ncols):
      raise ValueError("Merging incompatible RowPartitionSpecs")

    # Merging RowPartitionSpecs with different dtypes is not supported.
    if self.dtype != other.dtype:
      raise ValueError("Merging RowPartitionSpecs with incompatible dtypes")

    return RowPartitionSpec(nrows=nrows[0],
                            nvals=nvals[0],
                            uniform_row_length=ncols[0],
                            dtype=self.dtype)
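
  # Example (illustrative sketch): merging combines the known dimensions of
  # both specs, and the constructor then re-runs its inference:
  #
  #   RowPartitionSpec(nrows=4)._merge_with(
  #       RowPartitionSpec(uniform_row_length=3))
  #   # -> RowPartitionSpec(nrows=4, nvals=12, uniform_row_length=3,
  #   #                     dtype=tf.int64)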

  def with_dtype(self, dtype):
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    return RowPartitionSpec(nrows, nvals, self._uniform_row_length, dtype)

  def __deepcopy__(self, memo):
    del memo
    dtype = self.dtype
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    uniform_row_length = (None if self._uniform_row_length is None else
                          tensor_shape.dimension_value(
                              self._uniform_row_length[0]))
    return RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)


#===============================================================================
# Helper Functions
#===============================================================================


def _assert_monotonic_increasing(tensor, message=None):
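  """Asserts that `tensor[i+1] >= tensor[i]` for all `i` (non-decreasing)."""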
  return check_ops.assert_non_negative(
      tensor[1:] - tensor[:-1], message=message)


def _assert_zero(tensor, message=None):
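  """Asserts that every element of `tensor` is equal to zero."""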
  return check_ops.assert_equal(
      tensor, constant_op.constant(0, dtype=tensor.dtype), message=message)


def _cast_if_not_none(tensor, dtype):
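  """Casts `tensor` to `dtype`, or returns None if `tensor` is None."""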
  return None if tensor is None else math_ops.cast(tensor, dtype)


def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (used in error messages).
    validate: If True, then check that `t1` is compatible with `t2` (if both
      are non-None).

  Returns:
    A pair `(merged_value, validated)`:
      * `merged_value` is `t1` if it is not None; otherwise `t2`.
      * `validated` is True if we validated that `t1` and `t2` are equal
        (either by adding an assertion, or because `t1` is `t2`).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    return t1, True
  else:
    err_msg = ("RowPartition._merge_precomputed_encodings: partitions "
               "have incompatible %s" % name)
    if not t1.shape.is_compatible_with(t2.shape):
      raise ValueError(err_msg)
    if validate:
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False
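
# Example (illustrative sketch): `_merge_tensors` returns whichever input is
# non-None (preferring `t1` when both are), and reports `validated=True` only
# when equality is actually established.  Assuming `t`, `t1`, and `t2` are
# Tensors with compatible shapes:
#
#   _merge_tensors(None, t, "nrows", validate=True)    # -> (t, False)
#   _merge_tensors(t, t, "nrows", validate=True)       # -> (t, True)
#   _merge_tensors(t1, t2, "nrows", validate=False)    # -> (t1, False)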

_row_partition_factory_key = object()  # unique private object


def _get_dtype_or_none(value):
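  """Returns `value.dtype` if `value` is a `tf.Tensor`, otherwise None."""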
  if isinstance(value, ops.Tensor):
    return value.dtype
  return None


def _get_target_dtype(values, dtype=None, dtype_hint=None):
  """Gets the target dtype of a family of values."""
  if dtype is not None:
    return dtype

  for value in values:
    if isinstance(value, ops.Tensor):
      return value.dtype

  for value in values:
    if isinstance(value, np.ndarray):
      return dtypes.as_dtype(value.dtype)

  if dtype_hint is not None:
    return dtype_hint

  return dtypes.int64
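
# Example (illustrative sketch): the target dtype is resolved in priority
# order: an explicit `dtype`, then the dtype of the first `tf.Tensor` in
# `values`, then the dtype of the first numpy array, then `dtype_hint`, and
# finally tf.int64.  For instance:
#
#   _get_target_dtype([np.array([1], np.int64),
#                      constant_op.constant([2], dtypes.int32)],
#                     dtype_hint=dtypes.int64)   # -> tf.int32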


def _convert_all_to_tensors(values, dtype=None, dtype_hint=None):
  """Convert a list of objects to tensors of the same dtype."""
  target_dtype = _get_target_dtype([x for (x, _) in values], dtype, dtype_hint)

  # If dtype is None, we use convert behavior.
  # If dtype is not None, we use cast behavior.
  convert_behavior = dtype is None

  if convert_behavior:
    return [
        None if x is None else ops.convert_to_tensor(
            x, dtype=target_dtype, name=name) for (x, name) in values
    ]
  else:
    return [
        None if x is None else math_ops.cast(x, dtype=target_dtype, name=name)
        for (x, name) in values
    ]
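
# Example (illustrative sketch): with `dtype=None` each non-None value is
# converted to the inferred target dtype; with an explicit `dtype` it is cast
# instead.  Here `splits` stands for any integer Tensor or array-like:
#
#   _convert_all_to_tensors([(splits, "row_splits"), (None, "nrows")],
#                           dtype=dtypes.int32)
#   # -> [splits cast to int32, None]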