# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for converting between row_splits and segment_ids."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util.tf_export import tf_export


# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
def row_splits_to_segment_ids(splits, name=None):
  """Generates the segmentation corresponding to a RaggedTensor `row_splits`.

  Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
  `splits[j] <= i < splits[j+1]`.  Example:

  ```python
  >>> ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]).eval()
  [ 0 0 0 2 2 3 4 4 4 ]
  ```

  Args:
    splits: A sorted 1-D int64 Tensor.  `splits[0]` must be zero (this
      precondition is not checked here).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[splits[-1]]`.

  Raises:
    ValueError: If `splits` is not 1-D, or if its length is statically
      known to be zero.
  """
  with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
    splits = ops.convert_to_tensor(splits, dtype=dtypes.int64, name="splits")
    splits.shape.assert_has_rank(1)
    # A valid row_splits vector always holds at least one element (the
    # leading zero); an empty one can only be rejected when its length is
    # known statically.
    if tensor_shape.dimension_value(splits.shape[0]) == 0:
      raise ValueError("Invalid row_splits: []")
    # Row j holds splits[j+1] - splits[j] values.
    row_lengths = splits[1:] - splits[:-1]
    # The number of rows is one less than the number of splits.  It is read
    # from the dynamic shape so that splits of unknown static length work.
    nrows = array_ops.shape(splits, out_type=dtypes.int64)[-1] - 1
    indices = math_ops.range(nrows)
    # Emit each row index j once for every value in row j.
    return ragged_util.repeat(indices, repeats=row_lengths, axis=0)


# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
  """Generates the RaggedTensor `row_splits` corresponding to a segmentation.

  Returns an integer vector `splits`, where `splits[0] = 0` and
  `splits[i] = splits[i-1] + count(segment_ids==i-1)`; i.e. `splits[i]` is
  the number of ids strictly less than `i`.  Example:

  ```python
  >>> ragged.segment_ids_to_row_splits([0, 0, 0, 2, 2, 3, 4, 4, 4]).eval()
  [ 0 3 3 5 6 9 ]
  ```

  Args:
    segment_ids: A 1-D integer Tensor.
    num_segments: A scalar integer indicating the number of segments.  Defaults
      to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[num_segments + 1]`.

  Raises:
    ValueError: If `segment_ids` is not 1-D, or if `num_segments` is given
      and is not a scalar.
  """
  with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
    segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids")
    segment_ids.shape.assert_has_rank(1)
    if num_segments is not None:
      num_segments = ragged_util.convert_to_int_tensor(num_segments,
                                                       "num_segments")
      num_segments.shape.assert_has_rank(0)

    # Count the ids that fall in each segment.  num_segments is passed as
    # both the minimum and maximum length so that, when it is given, the
    # number of counts (rows) equals num_segments; when it is None the
    # length is derived from the ids themselves.
    row_lengths = math_ops.bincount(
        segment_ids,
        minlength=num_segments,
        maxlength=num_segments,
        dtype=dtypes.int64)
    # Prefix-sum the row lengths and prepend the leading zero.
    splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)

    # If num_segments is a constant, record the static length of `splits`.
    if num_segments is not None:
      const_num_segments = tensor_util.constant_value(num_segments)
      if const_num_segments is not None:
        splits.set_shape(tensor_shape.TensorShape([const_num_segments + 1]))

    return splits