1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for ragged_range op.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22from absl.testing import parameterized 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops.ragged import ragged_factory_ops 29from tensorflow.python.ops.ragged import ragged_math_ops 30from tensorflow.python.ops.ragged import ragged_tensor 31from tensorflow.python.ops.ragged import ragged_test_util 32from tensorflow.python.platform import googletest 33 34 35def prod(values): 36 val = 1 37 for v in values: 38 val *= v 39 return val 40 # return reduce(lambda x, y: x * y, values, 1) 41 42 43def mean(values): 44 return 1.0 * sum(values) / len(values) 45 46 47def sqrt_n(values): 48 return 1.0 * sum(values) / math.sqrt(len(values)) 49 50 51@test_util.run_all_in_graph_and_eager_modes 52class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase, 53 parameterized.TestCase): 54 55 def expected_value(self, data, segment_ids, num_segments, combiner): 56 """Find the expected value for a call to ragged_segment_<aggregate>. 57 58 Args: 59 data: The input RaggedTensor, expressed as a nested python list. 60 segment_ids: The segment ids, as a python list of ints. 61 num_segments: The number of segments, as a python int. 62 combiner: The Python function used to combine values. 63 Returns: 64 The expected value, as a nested Python list. 65 """ 66 self.assertLen(data, len(segment_ids)) 67 68 # Build an empty (num_segments x ncols) "grouped" matrix 69 ncols = max(len(row) for row in data) 70 grouped = [[[] for _ in range(ncols)] for row in range(num_segments)] 71 72 # Append values from data[row] to grouped[segment_ids[row]] 73 for row in range(len(data)): 74 for col in range(len(data[row])): 75 grouped[segment_ids[row]][col].append(data[row][col]) 76 77 # Combine the values. 78 return [[combiner(values) 79 for values in grouped_row 80 if values] 81 for grouped_row in grouped] 82 83 @parameterized.parameters( 84 (ragged_math_ops.segment_sum, sum, [0, 0, 1, 1, 2, 2]), 85 (ragged_math_ops.segment_sum, sum, [0, 0, 0, 1, 1, 1]), 86 (ragged_math_ops.segment_sum, sum, [5, 4, 3, 2, 1, 0]), 87 (ragged_math_ops.segment_sum, sum, [0, 0, 0, 10, 10, 10]), 88 (ragged_math_ops.segment_prod, prod, [0, 0, 1, 1, 2, 2]), 89 (ragged_math_ops.segment_prod, prod, [0, 0, 0, 1, 1, 1]), 90 (ragged_math_ops.segment_prod, prod, [5, 4, 3, 2, 1, 0]), 91 (ragged_math_ops.segment_prod, prod, [0, 0, 0, 10, 10, 10]), 92 (ragged_math_ops.segment_min, min, [0, 0, 1, 1, 2, 2]), 93 (ragged_math_ops.segment_min, min, [0, 0, 0, 1, 1, 1]), 94 (ragged_math_ops.segment_min, min, [5, 4, 3, 2, 1, 0]), 95 (ragged_math_ops.segment_min, min, [0, 0, 0, 10, 10, 10]), 96 (ragged_math_ops.segment_max, max, [0, 0, 1, 1, 2, 2]), 97 (ragged_math_ops.segment_max, max, [0, 0, 0, 1, 1, 1]), 98 (ragged_math_ops.segment_max, max, [5, 4, 3, 2, 1, 0]), 99 (ragged_math_ops.segment_max, max, [0, 0, 0, 10, 10, 10]), 100 (ragged_math_ops.segment_mean, mean, [0, 0, 1, 1, 2, 2]), 101 (ragged_math_ops.segment_mean, mean, [0, 0, 0, 1, 1, 1]), 102 (ragged_math_ops.segment_mean, mean, [5, 4, 3, 2, 1, 0]), 103 (ragged_math_ops.segment_mean, mean, [0, 0, 0, 10, 10, 10]), 104 ) 105 def testRaggedSegment_Int(self, segment_op, combiner, segment_ids): 106 rt_as_list = [[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]] 107 rt = ragged_factory_ops.constant(rt_as_list) 108 num_segments = max(segment_ids) + 1 109 expected = self.expected_value(rt_as_list, segment_ids, num_segments, 110 combiner) 111 112 segmented = segment_op(rt, segment_ids, num_segments) 113 self.assertRaggedEqual(segmented, expected) 114 115 @parameterized.parameters( 116 (ragged_math_ops.segment_sum, sum, [0, 0, 1, 1, 2, 2]), 117 (ragged_math_ops.segment_sum, sum, [0, 0, 0, 1, 1, 1]), 118 (ragged_math_ops.segment_sum, sum, [5, 4, 3, 2, 1, 0]), 119 (ragged_math_ops.segment_sum, sum, [0, 0, 0, 10, 10, 10]), 120 (ragged_math_ops.segment_prod, prod, [0, 0, 1, 1, 2, 2]), 121 (ragged_math_ops.segment_prod, prod, [0, 0, 0, 1, 1, 1]), 122 (ragged_math_ops.segment_prod, prod, [5, 4, 3, 2, 1, 0]), 123 (ragged_math_ops.segment_prod, prod, [0, 0, 0, 10, 10, 10]), 124 (ragged_math_ops.segment_min, min, [0, 0, 1, 1, 2, 2]), 125 (ragged_math_ops.segment_min, min, [0, 0, 0, 1, 1, 1]), 126 (ragged_math_ops.segment_min, min, [5, 4, 3, 2, 1, 0]), 127 (ragged_math_ops.segment_min, min, [0, 0, 0, 10, 10, 10]), 128 (ragged_math_ops.segment_max, max, [0, 0, 1, 1, 2, 2]), 129 (ragged_math_ops.segment_max, max, [0, 0, 0, 1, 1, 1]), 130 (ragged_math_ops.segment_max, max, [5, 4, 3, 2, 1, 0]), 131 (ragged_math_ops.segment_max, max, [0, 0, 0, 10, 10, 10]), 132 (ragged_math_ops.segment_mean, mean, [0, 0, 1, 1, 2, 2]), 133 (ragged_math_ops.segment_mean, mean, [0, 0, 0, 1, 1, 1]), 134 (ragged_math_ops.segment_mean, mean, [5, 4, 3, 2, 1, 0]), 135 (ragged_math_ops.segment_mean, mean, [0, 0, 0, 10, 10, 10]), 136 (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 1, 1, 2, 2]), 137 (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 0, 1, 1, 1]), 138 (ragged_math_ops.segment_sqrt_n, sqrt_n, [5, 4, 3, 2, 1, 0]), 139 (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 0, 10, 10, 10]), 140 ) 141 def testRaggedSegment_Float(self, segment_op, combiner, segment_ids): 142 rt_as_list = [[0., 1., 2., 3.], [4.], [], [5., 6.], [7.], [8., 9.]] 143 rt = ragged_factory_ops.constant(rt_as_list) 144 num_segments = max(segment_ids) + 1 145 expected = self.expected_value(rt_as_list, segment_ids, num_segments, 146 combiner) 147 148 segmented = segment_op(rt, segment_ids, num_segments) 149 self.assertRaggedAlmostEqual(segmented, expected, places=5) 150 151 def testRaggedRankTwo(self): 152 rt = ragged_factory_ops.constant([ 153 [[111, 112, 113, 114], [121],], # row 0 154 [], # row 1 155 [[], [321, 322], [331]], # row 2 156 [[411, 412]] # row 3 157 ]) # pyformat: disable 158 segment_ids1 = [0, 2, 2, 2] 159 segmented1 = ragged_math_ops.segment_sum(rt, segment_ids1, 3) 160 expected1 = [[[111, 112, 113, 114], [121]], # row 0 161 [], # row 1 162 [[411, 412], [321, 322], [331]] # row 2 163 ] # pyformat: disable 164 self.assertRaggedEqual(segmented1, expected1) 165 166 segment_ids2 = [1, 2, 1, 1] 167 segmented2 = ragged_math_ops.segment_sum(rt, segment_ids2, 3) 168 expected2 = [[], 169 [[111+411, 112+412, 113, 114], [121+321, 322], [331]], 170 []] # pyformat: disable 171 self.assertRaggedEqual(segmented2, expected2) 172 173 def testRaggedSegmentIds(self): 174 rt = ragged_factory_ops.constant([ 175 [[111, 112, 113, 114], [121],], # row 0 176 [], # row 1 177 [[], [321, 322], [331]], # row 2 178 [[411, 412]] # row 3 179 ]) # pyformat: disable 180 segment_ids = ragged_factory_ops.constant([[1, 2], [], [1, 1, 2], [2]]) 181 segmented = ragged_math_ops.segment_sum(rt, segment_ids, 3) 182 expected = [[], 183 [111+321, 112+322, 113, 114], 184 [121+331+411, 412]] # pyformat: disable 185 self.assertRaggedEqual(segmented, expected) 186 187 def testShapeMismatchError1(self): 188 dt = constant_op.constant([1, 2, 3, 4, 5, 6]) 189 segment_ids = ragged_factory_ops.constant([[1, 2], []]) 190 self.assertRaisesRegexp( 191 ValueError, 'segment_ids.shape must be a prefix of data.shape, ' 192 'but segment_ids is ragged and data is not.', 193 ragged_math_ops.segment_sum, dt, segment_ids, 3) 194 195 def testShapeMismatchError2(self): 196 rt = ragged_factory_ops.constant([ 197 [[111, 112, 113, 114], [121]], # row 0 198 [], # row 1 199 [[], [321, 322], [331]], # row 2 200 [[411, 412]] # row 3 201 ]) # pyformat: disable 202 segment_ids = ragged_factory_ops.constant([[1, 2], [1], [1, 1, 2], [2]]) 203 204 # Error is raised at graph-building time if we can detect it then. 205 self.assertRaisesRegexp( 206 errors.InvalidArgumentError, 207 'segment_ids.shape must be a prefix of data.shape.*', 208 ragged_math_ops.segment_sum, rt, segment_ids, 3) 209 210 # Otherwise, error is raised when we run the graph. 211 segment_ids2 = ragged_tensor.RaggedTensor.from_row_splits( 212 array_ops.placeholder_with_default(segment_ids.values, None), 213 array_ops.placeholder_with_default(segment_ids.row_splits, None)) 214 with self.assertRaisesRegexp( 215 errors.InvalidArgumentError, 216 'segment_ids.shape must be a prefix of data.shape.*'): 217 self.evaluate(ragged_math_ops.segment_sum(rt, segment_ids2, 3)) 218 219 220if __name__ == '__main__': 221 googletest.main() 222