• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for converting between row_splits and segment_ids."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.ragged import ragged_util
28from tensorflow.python.util.tf_export import tf_export
29
30
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
def row_splits_to_segment_ids(splits, name=None):
  """Converts a RaggedTensor `row_splits` vector into a segmentation.

  The result `segment_ids` holds one row index per value covered by
  `splits`: `segment_ids[i] == j` exactly when
  `splits[j] <= i < splits[j+1]`.  For example:

  ```python
  >>> ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]).eval()
  [ 0 0 0 2 2 3 4 4 4 ]
  ```

  Args:
    splits: A sorted 1-D int64 Tensor whose first element is zero.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[splits[-1]]`.

  Raises:
    ValueError: If `splits` is invalid (for example, statically empty).
  """
  with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]):
    splits_t = ops.convert_to_tensor(splits, dtype=dtypes.int64, name="splits")
    splits_t.shape.assert_has_rank(1)

    # A well-formed `splits` always contains at least the leading zero, so a
    # vector whose static length is 0 can be rejected before building ops.
    static_len = tensor_shape.dimension_value(splits_t.shape[0])
    if static_len == 0:
      raise ValueError("Invalid row_splits: []")

    # Row j spans splits[j]..splits[j+1], so its length is the difference of
    # adjacent split points.
    lengths_per_row = splits_t[1:] - splits_t[:-1]
    num_rows = array_ops.shape(splits_t, out_type=dtypes.int64)[-1] - 1

    # Emit row index j once for every value that belongs to row j.
    return ragged_util.repeat(
        math_ops.range(num_rows), repeats=lengths_per_row, axis=0)
64
65
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
  """Generates the RaggedTensor `row_splits` corresponding to a segmentation.

  Returns an integer vector `splits`, where `splits[0] = 0` and, for
  `i >= 1`, `splits[i] = splits[i-1] + count(segment_ids == i-1)`.  In other
  words, `splits` is a leading zero followed by the running total of the
  per-segment counts.  Example:

  ```python
  >>> ragged.segment_ids_to_row_splits([0, 0, 0, 2, 2, 3, 4, 4, 4]).eval()
  [ 0 3 3 5 6 9 ]
  ```

  Args:
    segment_ids: A 1-D integer Tensor of non-negative segment ids (the
      result describes a RaggedTensor row partition when the ids are sorted).
    num_segments: A scalar integer indicating the number of segments.  Defaults
      to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).  When
      given, the counts are padded to this length and ids that are
      `>= num_segments` are not counted.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[num_segments + 1]`.

  Raises:
    ValueError: If `segment_ids` is statically known not to be rank 1, or
      `num_segments` is statically known not to be a scalar.
  """
  with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
    segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids")
    segment_ids.shape.assert_has_rank(1)
    if num_segments is not None:
      num_segments = ragged_util.convert_to_int_tensor(num_segments,
                                                       "num_segments")
      num_segments.shape.assert_has_rank(0)

    # Per-segment counts.  Passing `num_segments` as both `minlength` and
    # `maxlength` pins the result to exactly `num_segments` entries; when it
    # is None, bincount sizes the result as max(segment_ids) + 1.
    row_lengths = math_ops.bincount(
        segment_ids,
        minlength=num_segments,
        maxlength=num_segments,
        dtype=dtypes.int64)
    # Prepend the leading zero; the cumulative sum turns counts into the end
    # offset of each segment.
    splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)

    # Update shape information, if possible.
    if num_segments is not None:
      const_num_segments = tensor_util.constant_value(num_segments)
      if const_num_segments is not None:
        splits.set_shape(tensor_shape.TensorShape([const_num_segments + 1]))

    return splits
111