1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for segment reduction ops.""" 16 17import itertools 18 19import numpy as np 20 21from tensorflow.python.client import session 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes as dtypes_lib 24from tensorflow.python.framework import errors_impl 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import gradient_checker 28from tensorflow.python.ops import gradient_checker_v2 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import variables 31from tensorflow.python.platform import test 32 33 34class SegmentReductionHelper(test.TestCase): 35 36 def _input(self, input_shape, dtype=dtypes_lib.int32): 37 num_elem = 1 38 for x in input_shape: 39 num_elem *= x 40 values = np.arange(1, num_elem + 1) 41 np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype) 42 # Add a non-zero imaginary component to complex types. 43 if dtype.is_complex: 44 np_values -= 1j * np_values 45 return constant_op.constant( 46 np_values, shape=input_shape, dtype=dtype), np_values 47 48 def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None, 49 initial_value=0): 50 if not x.size: 51 return np.array([]) 52 indices = np.asarray(indices) 53 if num_segments is None: 54 num_segments = indices[-1] + 1 55 output = [None] * num_segments 56 slice_shape = x.shape[indices.ndim:] 57 x_flat = x.reshape((indices.size,) + slice_shape) 58 for i, index in enumerate(indices.ravel()): 59 if (output[index] is not None) and op1 == np.max: 60 for j in range(0, output[index].shape[0]): 61 output[index][j] = op1([output[index][j], x_flat[i][j]]) 62 elif output[index] is not None: 63 output[index] = op1(output[index], x_flat[i]) 64 else: 65 output[index] = x_flat[i] 66 # zero initialize values that are still uncalculated. 67 initial_value_slice = np.ones(slice_shape) * initial_value 68 output = [o if o is not None else initial_value_slice for o in output] 69 if op2 is not None: 70 output = [op2(o) for o in output] 71 output = [o.reshape(slice_shape) for o in output] 72 return np.array(output) 73 74 def _mean_cum_op(self, x, y): 75 return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2) 76 77 def _mean_reduce_op(self, x): 78 return x[0] / x[1] if isinstance(x, tuple) else x 79 80 def _sqrt_n_reduce_op(self, x): 81 return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x 82 83 84class SegmentReductionOpTest(SegmentReductionHelper): 85 86 def testValues(self): 87 dtypes = [ 88 dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, 89 dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 90 ] 91 92 # Each item is np_op1, np_op2, tf_op 93 ops_list = [(np.add, None, math_ops.segment_sum), 94 (self._mean_cum_op, self._mean_reduce_op, 95 math_ops.segment_mean), 96 (np.ndarray.__mul__, None, math_ops.segment_prod), 97 (np.minimum, None, math_ops.segment_min), 98 (np.maximum, None, math_ops.segment_max)] 99 100 # A subset of ops has been enabled for complex numbers 101 complex_ops_list = [(np.add, None, math_ops.segment_sum), 102 (np.ndarray.__mul__, None, math_ops.segment_prod), 103 (self._mean_cum_op, self._mean_reduce_op, 104 math_ops.segment_mean)] 105 106 n = 10 107 # Note that the GPU implem has different paths for different inner sizes. 108 for shape in [[n, 1], [n, 2], [n, 3], [n, 32]]: 109 indices = [i // 3 for i in range(n)] 110 for dtype in dtypes: 111 if dtype in (dtypes_lib.complex64, dtypes_lib.complex128): 112 curr_ops_list = complex_ops_list 113 else: 114 curr_ops_list = ops_list 115 for use_gpu in [True, False]: 116 with self.cached_session(use_gpu=use_gpu): 117 tf_x, np_x = self._input(shape, dtype=dtype) 118 for np_op1, np_op2, tf_op in curr_ops_list: 119 initial_value = 1 if tf_op is math_ops.segment_prod else 0 120 np_ans = self._segmentReduce( 121 indices, np_x, np_op1, np_op2, initial_value=initial_value) 122 s = tf_op(data=tf_x, segment_ids=indices) 123 tf_ans = self.evaluate(s) 124 self.assertAllClose(np_ans, tf_ans) 125 # NOTE(mrry): The static shape inference that computes 126 # `tf_ans.shape` can only infer that sizes from dimension 1 127 # onwards, because the size of dimension 0 is data-dependent 128 # and may therefore vary dynamically. 129 self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:]) 130 131 @test_util.run_deprecated_v1 132 def testSegmentIdsShape(self): 133 shape = [4, 4] 134 tf_x, _ = self._input(shape) 135 indices = constant_op.constant([0, 1, 2, 2], shape=[2, 2]) 136 with self.assertRaises(ValueError): 137 math_ops.segment_sum(data=tf_x, segment_ids=indices) 138 139 @test_util.run_deprecated_v1 140 def testSegmentIdsSize(self): 141 shape = [4, 4] 142 for use_gpu in [True, False]: 143 with self.cached_session(use_gpu=use_gpu): 144 tf_x, _ = self._input(shape) 145 indices = [0, 1] 146 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 147 with self.assertRaisesOpError("segment_ids should be the same size"): 148 self.evaluate(s) 149 150 @test_util.run_deprecated_v1 151 def testSegmentIdsValid(self): 152 # This is a baseline for the following SegmentIdsInvalid* tests. 153 shape = [4, 4] 154 for use_gpu in [True, False]: 155 with self.cached_session(use_gpu=use_gpu): 156 tf_x, _ = self._input(shape, dtype=dtypes_lib.float32) 157 indices = [0, 0, 0, 1] 158 result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval() 159 self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result) 160 161 def testSegmentIdsGreaterThanZero(self): 162 shape = [4, 4] 163 for use_gpu in [True, False]: 164 with self.cached_session(use_gpu=use_gpu): 165 tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32) 166 indices = [1, 1, 2, 2] 167 np_ans = self._segmentReduce(indices, np_x, np.add) 168 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 169 tf_ans = self.evaluate(s) 170 self.assertAllClose(np_ans, tf_ans) 171 172 def testSegmentIdsHole(self): 173 shape = [4, 4] 174 for use_gpu in [True, False]: 175 with self.cached_session(use_gpu=use_gpu): 176 tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32) 177 indices = [0, 0, 3, 3] 178 np_ans = self._segmentReduce(indices, np_x, np.add) 179 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 180 tf_ans = self.evaluate(s) 181 self.assertAllClose(np_ans, tf_ans) 182 183 @test_util.run_deprecated_v1 184 def testSegmentIdsInvalid1(self): 185 shape = [4, 4] 186 with self.cached_session(): 187 tf_x, _ = self._input(shape) 188 indices = [-1, -1, 0, 0] 189 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 190 with self.assertRaisesOpError( 191 r"Segment id -1 out of range \[0, 1\), possibly because " 192 "'segment_ids' input is not sorted."): 193 self.evaluate(s) 194 195 @test_util.run_deprecated_v1 196 def testSegmentIdsInvalid2(self): 197 shape = [4, 4] 198 with self.cached_session(): 199 tf_x, _ = self._input(shape) 200 indices = [0, 1, 0, 1] 201 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 202 with self.assertRaisesOpError("segment ids are not increasing"): 203 self.evaluate(s) 204 205 @test_util.run_deprecated_v1 206 def testSegmentIdsInvalid3(self): 207 shape = [4, 4] 208 with self.cached_session(): 209 tf_x, _ = self._input(shape) 210 indices = [0, 1, 2, 0] 211 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 212 with self.assertRaisesOpError( 213 r"Segment id 1 out of range \[0, 1\), possibly " 214 "because 'segment_ids' input is not sorted."): 215 self.evaluate(s) 216 217 @test_util.run_deprecated_v1 218 def testSegmentIdsInvalid4(self): 219 shape = [4, 4] 220 for use_gpu in [True, False]: 221 with self.cached_session(use_gpu=use_gpu): 222 tf_x, _ = self._input(shape, dtype=dtypes_lib.float32) 223 indices = [0, 0, 0, -1] 224 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 225 with self.assertRaisesOpError("segment ids must be >= 0"): 226 self.evaluate(s) 227 228 @test_util.run_deprecated_v1 229 def testSegmentIdsInvalid5(self): 230 shape = [4, 4] 231 for use_gpu in [True, False]: 232 with self.cached_session(use_gpu=use_gpu): 233 tf_x, _ = self._input(shape, dtype=dtypes_lib.float32) 234 indices = [0, 0, 0, -2] 235 s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 236 with self.assertRaisesOpError("segment ids must be >= 0"): 237 self.evaluate(s) 238 239 @test_util.run_deprecated_v1 240 def testGradient(self): 241 shape = [4, 4] 242 indices = [0, 1, 2, 2] 243 for tf_op in [ 244 math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min, 245 math_ops.segment_max 246 ]: 247 with self.cached_session(): 248 tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64) 249 s = tf_op(data=tf_x, segment_ids=indices) 250 jacob_t, jacob_n = gradient_checker.compute_gradient( 251 tf_x, 252 shape, 253 s, [3, 4], 254 x_init_value=np_x.astype(np.double), 255 delta=1) 256 self.assertAllClose(jacob_t, jacob_n) 257 258 def testDataInvalid(self): 259 # Test case for GitHub issue 40653. 260 for use_gpu in [True, False]: 261 with self.cached_session(use_gpu=use_gpu): 262 with self.assertRaisesRegex( 263 (ValueError, errors_impl.InvalidArgumentError), 264 "must be at least rank 1"): 265 s = math_ops.segment_mean( 266 data=np.uint16(10), segment_ids=np.array([]).astype("int64")) 267 self.evaluate(s) 268 269 def testInvalidIds(self): 270 # Test case for GitHub issue 46888. 271 for op in [ 272 math_ops.segment_max, 273 math_ops.segment_min, 274 math_ops.segment_mean, 275 math_ops.segment_sum, 276 math_ops.segment_prod, 277 ]: 278 with self.cached_session(use_gpu=False): 279 with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): 280 s = op(data=np.ones((1, 10, 1)), segment_ids=[1676240524292489355]) 281 self.evaluate(s) 282 283 284class UnsortedSegmentTest(SegmentReductionHelper): 285 286 def __init__(self, methodName='runTest'): 287 # Each item is np_op1, np_op2, tf_op, initial_value functor 288 self.ops_list = [(np.add, None, 289 math_ops.unsorted_segment_sum, lambda t: 0), 290 (self._mean_cum_op, self._mean_reduce_op, 291 math_ops.unsorted_segment_mean, lambda t: 0), 292 (self._mean_cum_op, self._sqrt_n_reduce_op, 293 math_ops.unsorted_segment_sqrt_n, lambda t: 0), 294 (np.ndarray.__mul__, None, 295 math_ops.unsorted_segment_prod, lambda t: 1), 296 (np.minimum, None, 297 math_ops.unsorted_segment_min, lambda t: t.max), 298 (np.maximum, None, 299 math_ops.unsorted_segment_max, lambda t: t.min)] 300 301 # A subset of ops has been enabled for complex numbers 302 self.complex_ops_list = [(np.add, None, 303 math_ops.unsorted_segment_sum, lambda t: 0), 304 (np.ndarray.__mul__, None, 305 math_ops.unsorted_segment_prod, lambda t: 1)] 306 self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, 307 dtypes_lib.float64] 308 self.all_dtypes = (self.differentiable_dtypes + 309 [dtypes_lib.bfloat16, 310 dtypes_lib.int64, dtypes_lib.int32, 311 dtypes_lib.complex64, dtypes_lib.complex128]) 312 super(UnsortedSegmentTest, self).__init__(methodName=methodName) 313 314 def testValues(self): 315 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) 316 num_segments = 12 317 for indices in indices_flat, indices_flat.reshape(5, 2): 318 # Note that the GPU implem has different paths for different inner sizes. 319 for inner_size in [1, 2, 3, 32]: 320 shape = indices.shape + (inner_size,) 321 for dtype in self.all_dtypes: 322 ops_list = ( 323 self.complex_ops_list if dtype.is_complex else self.ops_list) 324 tf_x, np_x = self._input(shape, dtype=dtype) 325 for use_gpu in [True, False]: 326 with self.cached_session(): 327 for np_op1, np_op2, tf_op, init_op in ops_list: 328 # sqrt_n doesn't support integers 329 if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer): 330 continue 331 # todo(philjd): enable this test once real_div supports bfloat16 332 if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and 333 dtype == dtypes_lib.bfloat16): 334 continue 335 np_ans = self._segmentReduce( 336 indices, 337 np_x, 338 np_op1, 339 np_op2, 340 num_segments=num_segments, 341 initial_value=init_op(dtype)) 342 s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments) 343 tf_ans = self.evaluate(s) 344 if dtype is dtypes_lib.bfloat16: 345 tf_ans = tf_ans.astype(np.float32) 346 self.assertAllCloseAccordingToType(np_ans, tf_ans) 347 self.assertShapeEqual(np_ans, s) 348 349 def testNumSegmentsTypes(self): 350 dtypes = [dtypes_lib.int32, dtypes_lib.int64] 351 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) 352 num_segments = 12 353 for indices in indices_flat, indices_flat.reshape(5, 2): 354 shape = indices.shape + (2,) 355 for dtype in dtypes: 356 with self.cached_session(): 357 tf_x, np_x = self._input(shape) 358 num_segments_constant = constant_op.constant( 359 num_segments, dtype=dtype) 360 np_ans = self._segmentReduce( 361 indices, np_x, np.add, op2=None, num_segments=num_segments) 362 s = math_ops.unsorted_segment_sum( 363 data=tf_x, 364 segment_ids=indices, 365 num_segments=num_segments_constant) 366 tf_ans = self.evaluate(s) 367 self.assertAllClose(np_ans, tf_ans) 368 self.assertShapeEqual(np_ans, s) 369 370 @test_util.run_deprecated_v1 371 def testGradientsTFGradients(self): 372 num_cols = 2 373 indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3]) 374 num_segments = max(indices_flat) + 3 375 for dtype in self.differentiable_dtypes: 376 ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list 377 for indices in indices_flat, indices_flat.reshape(5, 2): 378 shape = indices.shape + (num_cols,) 379 # test CPU and GPU as tf.gather behaves differently on each device 380 for use_gpu in [False, True]: 381 with self.cached_session(use_gpu=use_gpu): 382 for _, _, tf_op, _ in ops_list: 383 tf_x, np_x = self._input(shape, dtype=dtype) 384 s = tf_op(tf_x, indices, num_segments) 385 jacob_t, jacob_n = gradient_checker.compute_gradient( 386 tf_x, 387 shape, 388 s, [num_segments, num_cols], 389 x_init_value=np_x, 390 delta=1.) 391 self.assertAllCloseAccordingToType(jacob_t, jacob_n, 392 half_atol=1e-2) 393 394 @test_util.run_in_graph_and_eager_modes 395 def testGradientsGradientTape(self): 396 num_cols = 2 397 indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3]) 398 num_segments = max(indices_flat) + 3 399 for dtype in self.differentiable_dtypes: 400 ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list 401 for indices in indices_flat, indices_flat.reshape(5, 2): 402 shape = indices.shape + (num_cols,) 403 # test CPU and GPU as tf.gather behaves differently on each device 404 for use_gpu in [test_util.use_gpu, test_util.force_cpu]: 405 with use_gpu(): 406 for _, _, tf_op, _ in ops_list: 407 _, np_x = self._input(shape, dtype=dtype) 408 # pylint: disable=cell-var-from-loop 409 def f(x): 410 return tf_op(x, indices, num_segments) 411 gradient_tape_jacob_t, jacob_n = ( 412 gradient_checker_v2.compute_gradient( 413 f, [np_x], delta=1.)) 414 # pylint: enable=cell-var-from-loop 415 self.assertAllCloseAccordingToType(jacob_n, gradient_tape_jacob_t, 416 half_atol=1e-2) 417 418 @test_util.run_deprecated_v1 419 def testProdGrad(self): 420 # additional test for the prod gradient to ensure correct handling of zeros 421 values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32) 422 indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32) 423 indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32) 424 values_tf = constant_op.constant(values) 425 # ground truth partial derivatives 426 gradients_indices = np.zeros((9, 3), dtype=np.float32) 427 gradients_indices_neg = np.zeros((9, 3), dtype=np.float32) 428 # the derivative w.r.t. to the other segments is zero, so here we only 429 # explicitly set the grad values for the corresponding segment 430 gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9] 431 gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3] 432 for use_gpu in [False, True]: 433 with self.cached_session(use_gpu=use_gpu): 434 for ind, grad_gt in [(indices, gradients_indices), 435 (indices_neg, gradients_indices_neg)]: 436 s = math_ops.unsorted_segment_prod(values_tf, 437 constant_op.constant(ind), 3) 438 jacob_t, jacob_n = gradient_checker.compute_gradient( 439 values_tf, (9,), s, (3,), x_init_value=values, delta=1) 440 self.assertAllClose(jacob_t, jacob_n) 441 self.assertAllClose(jacob_t, grad_gt) 442 443 @test_util.run_deprecated_v1 444 def testGradientMatchesSegmentSum(self): 445 # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum 446 # and compare the outputs, which should be identical. 447 # NB: for this test to work, indices must be valid for SegmentSum, namely 448 # it must be sorted, the indices must be contiguous, and num_segments 449 # must be max(indices) + 1. 450 indices = [0, 0, 1, 1, 1, 2, 3, 4, 5] 451 n = len(indices) 452 num_cols = 2 453 shape = [n, num_cols] 454 num_segments = max(indices) + 1 455 for dtype in self.differentiable_dtypes: 456 with self.cached_session(): 457 tf_x, np_x = self._input(shape, dtype=dtype) 458 # Results from UnsortedSegmentSum 459 unsorted_s = math_ops.unsorted_segment_sum( 460 data=tf_x, segment_ids=indices, num_segments=num_segments) 461 unsorted_jacob_t, unsorted_jacob_n = ( 462 gradient_checker.compute_gradient(tf_x, shape, unsorted_s, 463 [num_segments, num_cols], 464 x_init_value=np_x, delta=1)) 465 466 # Results from SegmentSum 467 sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices) 468 sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient( 469 tf_x, 470 shape, 471 sorted_s, [num_segments, num_cols], 472 x_init_value=np_x, 473 delta=1) 474 self.assertAllClose(unsorted_jacob_t, sorted_jacob_t) 475 self.assertAllClose(unsorted_jacob_n, sorted_jacob_n) 476 477 @test_util.run_deprecated_v1 478 def testBadIndices(self): 479 # Note: GPU kernel does not return the out-of-range error needed for this 480 # test, so this test is marked as cpu-only. 481 # Note: With PR #13055 a negative index will be ignored silently. 482 with self.session(use_gpu=False): 483 for bad in [[2]], [[7]]: 484 unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2) 485 with self.assertRaisesOpError( 486 r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]): 487 self.evaluate(unsorted) 488 489 @test_util.run_deprecated_v1 490 def testEmptySecondDimension(self): 491 dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32, 492 np.complex64, np.complex128] 493 with self.session(): 494 for dtype in dtypes: 495 for itype in (np.int32, np.int64): 496 data = np.zeros((2, 0), dtype=dtype) 497 segment_ids = np.array([0, 1], dtype=itype) 498 unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) 499 self.assertAllEqual(unsorted, np.zeros((2, 0), dtype=dtype)) 500 501 def testDropNegatives(self): 502 # Note: the test is done by replacing segment_ids with 8 to -1 503 # for index and replace values generated by numpy with 0. 504 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) 505 num_segments = 12 506 for indices in indices_flat, indices_flat.reshape(5, 2): 507 shape = indices.shape + (2,) 508 for dtype in self.all_dtypes: 509 with self.session(): 510 tf_x, np_x = self._input(shape, dtype=dtype) 511 np_ans = self._segmentReduce( 512 indices, np_x, np.add, op2=None, num_segments=num_segments) 513 # Replace np_ans[8] with 0 for the value 514 np_ans[8:] = 0 515 # Replace 8 with -1 in indices 516 np.place(indices, indices == 8, [-1]) 517 s = math_ops.unsorted_segment_sum( 518 data=tf_x, segment_ids=indices, num_segments=num_segments) 519 tf_ans = self.evaluate(s) 520 self.assertAllClose(np_ans, tf_ans) 521 self.assertShapeEqual(np_ans, s) 522 523 @test_util.run_deprecated_v1 524 def testAllNegatives(self): 525 with self.session(use_gpu=False): 526 data = np.ones((2, 1), dtype=np.float32) 527 segment_ids = np.array([-1, -1], dtype=np.int32) 528 unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) 529 self.assertAllClose(unsorted.eval(), np.zeros((2, 1), dtype=np.float32)) 530 531 532class SparseSegmentReductionHelper(SegmentReductionHelper): 533 534 def _sparse_input(self, input_shape, num_indices, dtype=dtypes_lib.int32): 535 a, b = super(SparseSegmentReductionHelper, self)._input(input_shape, dtype) 536 indices = np.random.randint(0, input_shape[0], num_indices).astype(np.int32) 537 return (constant_op.constant( 538 indices, dtype=dtypes_lib.int32), indices, a, b) 539 540 def _sparseSegmentReduce(self, 541 x, 542 indices, 543 segment_indices, 544 op1, 545 op2=None, 546 num_segments=None): 547 return self._segmentReduce( 548 segment_indices, x[indices], op1, op2, num_segments=num_segments) 549 550 def _sparseSegmentReduceGrad(self, ygrad, indices, segment_ids, output_dim0, 551 mode): 552 assert mode in ("sum", "mean", "sqrtn") 553 if mode != "sum": 554 weights = np.zeros(ygrad.shape[0], ygrad.dtype) 555 for segment in segment_ids: 556 weights[segment] += 1 557 weights = 1. / weights if mode == "mean" else 1. / np.sqrt(weights) 558 xgrad = np.zeros([output_dim0, ygrad.shape[1]], ygrad.dtype) 559 for segment, index in zip(segment_ids, indices): 560 if mode == "sum": 561 xgrad[index] += ygrad[segment] 562 else: 563 xgrad[index] += ygrad[segment] * weights[segment] 564 return xgrad 565 566 567class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): 568 569 def testValues(self): 570 dtypes = [ 571 dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, 572 dtypes_lib.int32 573 ] 574 575 index_dtypes = [dtypes_lib.int32, dtypes_lib.int64] 576 segment_ids_dtypes = [dtypes_lib.int32, dtypes_lib.int64] 577 578 mean_dtypes = [dtypes_lib.float32, dtypes_lib.float64] 579 580 # Each item is np_op1, np_op2, tf_op 581 ops_list = [(np.add, None, math_ops.sparse_segment_sum), 582 (self._mean_cum_op, self._mean_reduce_op, 583 math_ops.sparse_segment_mean)] 584 585 n = 400 586 # Note that the GPU implem has different paths for different inner sizes. 587 for inner_size in [1, 2, 3, 32]: 588 shape = [n, inner_size] 589 segment_indices = [] 590 for i in range(20): 591 for _ in range(i + 1): 592 segment_indices.append(i) 593 num_indices = len(segment_indices) 594 for dtype in dtypes: 595 for index_dtype in index_dtypes: 596 for segment_ids_dtype in segment_ids_dtypes: 597 with self.cached_session(): 598 tf_indices, np_indices, tf_x, np_x = self._sparse_input( 599 shape, num_indices, dtype=dtype) 600 for np_op1, np_op2, tf_op in ops_list: 601 if (tf_op == math_ops.sparse_segment_mean and 602 dtype not in mean_dtypes): 603 continue 604 np_ans = self._sparseSegmentReduce(np_x, np_indices, 605 segment_indices, np_op1, 606 np_op2) 607 s = tf_op( 608 data=tf_x, 609 indices=math_ops.cast(tf_indices, index_dtype), 610 segment_ids=math_ops.cast(segment_indices, 611 segment_ids_dtype)) 612 tf_ans = self.evaluate(s) 613 self.assertAllClose(np_ans, tf_ans) 614 # NOTE(mrry): The static shape inference that computes 615 # `tf_ans.shape` can only infer that sizes from dimension 1 616 # onwards, because the size of dimension 0 is data-dependent 617 # and may therefore vary dynamically. 618 self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:]) 619 620 def testSegmentIdsHole(self): 621 tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) 622 ops_list = [(np.add, None, math_ops.sparse_segment_sum), ( 623 self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)] 624 segment_indices = [0, 2, 2, 2] 625 tf_indices = [8, 3, 0, 9] 626 with self.session(): 627 for np_op1, np_op2, tf_op in ops_list: 628 np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices, 629 np_op1, np_op2) 630 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 631 tf_ans = self.evaluate(s) 632 self.assertAllClose(np_ans, tf_ans) 633 634 def testWithNumSegments(self): 635 tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) 636 ops_list = [(np.add, None, math_ops.sparse_segment_sum_with_num_segments), 637 (self._mean_cum_op, self._mean_reduce_op, 638 math_ops.sparse_segment_mean_with_num_segments)] 639 segment_indices = [0, 2, 2, 2] 640 tf_indices = [8, 3, 0, 9] 641 num_segments = 5 642 with self.session(): 643 for np_op1, np_op2, tf_op in ops_list: 644 np_ans = self._sparseSegmentReduce( 645 np_x, 646 tf_indices, 647 segment_indices, 648 np_op1, 649 np_op2, 650 num_segments=num_segments) 651 s = tf_op( 652 data=tf_x, 653 indices=tf_indices, 654 segment_ids=segment_indices, 655 num_segments=num_segments) 656 tf_ans = self.evaluate(s) 657 self.assertAllClose(np_ans, tf_ans) 658 659 def testWithEmptySegments(self): 660 tf_x = constant_op.constant([], shape=[0, 4], dtype=dtypes_lib.float32) 661 ops_list = [ 662 math_ops.sparse_segment_sum_with_num_segments, 663 math_ops.sparse_segment_mean_with_num_segments 664 ] 665 segment_indices = [] 666 tf_indices = [] 667 num_segments = 5 668 with self.session(): 669 for tf_op in ops_list: 670 s = tf_op( 671 data=tf_x, 672 indices=tf_indices, 673 segment_ids=segment_indices, 674 num_segments=num_segments) 675 tf_ans = self.evaluate(s) 676 self.assertAllClose(np.zeros([5, 4]), tf_ans) 677 678 @test_util.run_in_graph_and_eager_modes 679 def testSegmentScalarIdiRaisesInvalidArgumentError(self): 680 """Test for github #46897.""" 681 ops_list = [ 682 math_ops.sparse_segment_sum, 683 math_ops.sparse_segment_mean, 684 math_ops.sparse_segment_sqrt_n, 685 ] 686 for op in ops_list: 687 with self.assertRaisesRegex( 688 (ValueError, errors_impl.InvalidArgumentError), 689 "Shape must be at least rank 1"): 690 op(data=1.0, indices=[0], segment_ids=[3]) 691 692 def testSegmentIdsGreaterThanZero(self): 693 tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) 694 ops_list = [(np.add, None, math_ops.sparse_segment_sum), ( 695 self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)] 696 segment_indices = [1, 2, 2, 2] 697 tf_indices = [8, 3, 0, 9] 698 with self.session(): 699 for np_op1, np_op2, tf_op in ops_list: 700 np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices, 701 np_op1, np_op2) 702 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 703 tf_ans = self.evaluate(s) 704 self.assertAllClose(np_ans, tf_ans) 705 706 def testValid(self): 707 # Baseline for the test*Invalid* methods below. 708 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 709 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 710 segment_indices = [0, 1, 2, 2] 711 tf_indices = [8, 3, 0, 9] 712 with self.session(): 713 for tf_op in ops_list: 714 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 715 self.evaluate(s) 716 717 @test_util.run_deprecated_v1 718 def testIndicesInvalid1(self): 719 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 720 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 721 segment_indices = [0, 1, 2, 2] 722 tf_indices = [8, -1, 0, 9] 723 with self.session(use_gpu=False): 724 for tf_op in ops_list: 725 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 726 with self.assertRaisesOpError( 727 r"indices\[1\] == -1 out of range \[0, 10\)"): 728 self.evaluate(s) 729 730 @test_util.run_deprecated_v1 731 def testIndicesInvalid2(self): 732 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 733 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 734 segment_indices = [0, 1, 2, 2] 735 tf_indices = [8, 3, 0, 10] 736 with self.session(use_gpu=False): 737 for tf_op in ops_list: 738 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 739 with self.assertRaisesOpError( 740 r"indices\[3\] == 10 out of range \[0, 10\)"): 741 self.evaluate(s) 742 743 @test_util.run_deprecated_v1 744 def testSegmentsInvalid2(self): 745 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 746 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 747 segment_indices = [0, 1, 0, 1] 748 tf_indices = [8, 3, 0, 9] 749 with self.session(use_gpu=False): 750 for tf_op in ops_list: 751 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 752 with self.assertRaisesOpError("segment ids are not increasing"): 753 self.evaluate(s) 754 755 @test_util.run_deprecated_v1 756 def testSegmentsInvalid3(self): 757 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 758 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 759 segment_indices = [0, 1, 2, 0] 760 tf_indices = [8, 3, 0, 9] 761 with self.session(use_gpu=False): 762 for tf_op in ops_list: 763 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 764 with self.assertRaisesOpError( 765 r"Segment id 1 out of range \[0, 1\), possibly because " 766 "'segment_ids' input is not sorted"): 767 self.evaluate(s) 768 769 @test_util.run_deprecated_v1 770 def testSegmentsInvalid4(self): 771 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 772 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 773 segment_indices = [-1, 0, 1, 1] 774 tf_indices = [8, 3, 0, 9] 775 with self.session(use_gpu=False): 776 for tf_op in ops_list: 777 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 778 with self.assertRaisesOpError( 779 r"Segment id -1 out of range \[0, 2\), possibly because " 780 "'segment_ids' input is not sorted"): 781 self.evaluate(s) 782 783 @test_util.run_deprecated_v1 784 def testSegmentsInvalid6(self): 785 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 786 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 787 segment_indices = [0, 0, 0, -1] 788 tf_indices = [8, 3, 0, 9] 789 with self.session(use_gpu=False): 790 for tf_op in ops_list: 791 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 792 with self.assertRaisesOpError("segment ids must be >= 0"): 793 self.evaluate(s) 794 795 @test_util.run_deprecated_v1 796 def testSegmentsInvalid7(self): 797 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 798 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 799 segment_indices = [0, 0, 0, -2] 800 tf_indices = [8, 3, 0, 9] 801 with self.session(use_gpu=False): 802 for tf_op in ops_list: 803 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 804 with self.assertRaisesOpError("segment ids must be >= 0"): 805 self.evaluate(s) 806 807 @test_util.run_deprecated_v1 808 def testSegmentsInvalid8(self): 809 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 810 ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean] 811 segment_indices = [2**62 - 1] 812 tf_indices = [2**62 - 1] 813 with self.session(use_gpu=False): 814 for tf_op in ops_list: 815 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 816 with self.assertRaisesOpError( 817 "Encountered overflow when multiplying"): 818 self.evaluate(s) 819 820 def testSegmentWithNumSegmentsValid(self): 821 # Baseline for the test*WithNumSegmentsInvalid* methods below. 822 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 823 ops_list = [ 824 math_ops.sparse_segment_sum_with_num_segments, 825 math_ops.sparse_segment_mean_with_num_segments, 826 ] 827 num_segments = 5 828 segment_indices = [0, 1, 3, 3] 829 tf_indices = [8, 3, 0, 9] 830 with self.session(): 831 for tf_op in ops_list: 832 s = tf_op( 833 data=tf_x, 834 indices=tf_indices, 835 segment_ids=segment_indices, 836 num_segments=num_segments) 837 self.evaluate(s) 838 839 @test_util.run_deprecated_v1 840 def testSegmentWithNumSegmentsInvalid1(self): 841 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 842 ops_list = [ 843 math_ops.sparse_segment_sum_with_num_segments, 844 math_ops.sparse_segment_mean_with_num_segments, 845 ] 846 num_segments = 5 847 segment_indices = [0, 1, 3, 5] 848 tf_indices = [8, 3, 0, 9] 849 with self.session(use_gpu=False): 850 for tf_op in ops_list: 851 s = tf_op( 852 data=tf_x, 853 indices=tf_indices, 854 segment_ids=segment_indices, 855 num_segments=num_segments) 856 with self.assertRaisesOpError("segment ids must be < num_segments"): 857 self.evaluate(s) 858 859 @test_util.run_deprecated_v1 860 def testSegmentWithNumSegmentsInvalid2(self): 861 tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) 862 ops_list = [ 863 math_ops.sparse_segment_sum_with_num_segments, 864 math_ops.sparse_segment_mean_with_num_segments, 865 ] 866 num_segments = -2 867 segment_indices = [0, 1, 3, 3] 868 tf_indices = [8, 3, 0, 9] 869 with self.session(use_gpu=False): 870 for tf_op in ops_list: 871 with self.assertRaisesRegex( 872 ValueError, "Cannot specify a negative value for num_segments"): 873 tf_op( 874 data=tf_x, 875 indices=tf_indices, 876 segment_ids=segment_indices, 877 num_segments=num_segments) 878 879 @test_util.run_deprecated_v1 880 def testGradient(self): 881 shape = [10, 4] 882 883 segment_indices = [0, 1, 2, 2] 884 num_indices = len(segment_indices) 885 for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]: 886 with self.cached_session(): 887 tf_indices, _, tf_x, np_x = self._sparse_input( 888 shape, num_indices, dtype=dtypes_lib.float64) 889 s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) 890 jacob_t, jacob_n = gradient_checker.compute_gradient( 891 tf_x, 892 shape, 893 s, [3, 4], 894 x_init_value=np_x.astype(np.double), 895 delta=1) 896 self.assertAllClose(jacob_t, jacob_n) 897 898 @test_util.run_deprecated_v1 899 def testGradientWithEmptySegmentsAtEnd(self): 900 shape = [10, 4] 901 902 num_segments = 5 903 segment_indices = [0, 1, 2, 2] 904 num_indices = len(segment_indices) 905 for tf_op in [ 906 math_ops.sparse_segment_sum_with_num_segments, 907 math_ops.sparse_segment_mean_with_num_segments, 908 ]: 909 with self.cached_session(): 910 tf_indices, _, tf_x, np_x = self._sparse_input( 911 shape, num_indices, dtype=dtypes_lib.float64) 912 s = tf_op( 913 data=tf_x, 914 indices=tf_indices, 915 segment_ids=segment_indices, 916 num_segments=num_segments) 917 jacob_t, jacob_n = gradient_checker.compute_gradient( 918 tf_x, 919 shape, 920 s, [5, 4], 921 x_init_value=np_x.astype(np.double), 922 delta=1) 923 self.assertAllClose(jacob_t, jacob_n) 924 925 def testGradientExplicit(self): 926 # Note that the GPU implem has different paths for different inner sizes. 927 for inner_size in (1, 2, 3, 32): 928 with self.session(): 929 tf_ygrad, np_ygrad = self._input([3, inner_size], 930 dtype=dtypes_lib.float32) 931 segment_ids = [0, 1, 2, 2, 2] 932 indices = [8, 3, 0, 9, 3] 933 output_dim0 = 10 934 ops_list = [ 935 (math_ops.sparse_segment_sum_grad, "sum"), 936 (math_ops.sparse_segment_mean_grad, "mean"), 937 (math_ops.sparse_segment_sqrt_n_grad, "sqrtn"), 938 ] 939 for tf_op, mode in ops_list: 940 np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices, 941 segment_ids, output_dim0, 942 mode) 943 tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0) 944 self.assertAllClose(tf_xgrad, np_xgrad) 945 946 def testGradientExplicitSingleOutput(self): 947 # The GPU implem has a special case when there is a single output. 948 for inner_size in (1, 2, 3, 32): 949 with self.session(): 950 tf_ygrad, np_ygrad = self._input([3, inner_size], 951 dtype=dtypes_lib.float32) 952 segment_ids = [0, 1, 2, 2, 2] 953 indices = [0, 0, 0, 0, 0] 954 output_dim0 = 1 955 ops_list = [ 956 (math_ops.sparse_segment_sum_grad, "sum"), 957 (math_ops.sparse_segment_mean_grad, "mean"), 958 (math_ops.sparse_segment_sqrt_n_grad, "sqrtn"), 959 ] 960 for tf_op, mode in ops_list: 961 np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices, 962 segment_ids, output_dim0, 963 mode) 964 tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0) 965 self.assertAllClose(tf_xgrad, np_xgrad) 966 967 def testGradientValid(self): 968 # Baseline for the testGradient*Invalid* methods below. 969 tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32) 970 ops_list = [ 971 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 972 math_ops.sparse_segment_sqrt_n_grad 973 ] 974 segment_indices = [0, 1, 2, 2] 975 tf_indices = [8, 3, 0, 9] 976 with self.session(use_gpu=False): 977 for tf_op in ops_list: 978 s = tf_op(tf_x, tf_indices, segment_indices, 10) 979 self.evaluate(s) 980 981 @test_util.run_deprecated_v1 982 def testGradientIndicesInvalid1(self): 983 tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32) 984 ops_list = [ 985 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 986 math_ops.sparse_segment_sqrt_n_grad 987 ] 988 segment_indices = [0, 1, 2, 2] 989 tf_indices = [8, 3, 0, 10] 990 with self.session(use_gpu=False): 991 for tf_op in ops_list: 992 s = tf_op(tf_x, tf_indices, segment_indices, 10) 993 with self.assertRaisesOpError(r"Index 10 out of range \[0, 10\)"): 994 self.evaluate(s) 995 996 @test_util.run_deprecated_v1 997 def testGradientIndicesInvalid2(self): 998 tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32) 999 ops_list = [ 1000 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 1001 math_ops.sparse_segment_sqrt_n_grad 1002 ] 1003 segment_indices = [0, 1, 2, 2] 1004 tf_indices = [8, 3, -1, 9] 1005 with self.session(use_gpu=False): 1006 for tf_op in ops_list: 1007 s = tf_op(tf_x, tf_indices, segment_indices, 10) 1008 with self.assertRaisesOpError(r"Index -1 out of range \[0, 10\)"): 1009 self.evaluate(s) 1010 1011 @test_util.run_deprecated_v1 1012 def testGradientSegmentsInvalid1(self): 1013 tf_x, _ = self._input( 1014 [3, 4], dtype=dtypes_lib.float32) # expecting 3 segments 1015 ops_list = [ 1016 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 1017 math_ops.sparse_segment_sqrt_n_grad 1018 ] 1019 segment_indices = [0, 1, 1, 4] # 5 segments 1020 tf_indices = [8, 3, 0, 9] 1021 with self.session(use_gpu=False): 1022 for tf_op in ops_list: 1023 s = tf_op(tf_x, tf_indices, segment_indices, 10) 1024 with self.assertRaisesOpError("Invalid number of segments"): 1025 self.evaluate(s) 1026 1027 @test_util.run_deprecated_v1 1028 def testGradientSegmentsInvalid2(self): 1029 tf_x, _ = self._input([1, 4], dtype=dtypes_lib.float32) 1030 ops_list = [ 1031 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 1032 math_ops.sparse_segment_sqrt_n_grad 1033 ] 1034 segment_indices = [0, 1, 2, 0] 1035 tf_indices = [8, 3, 0, 9] 1036 with self.session(use_gpu=False): 1037 for tf_op in ops_list: 1038 s = tf_op(tf_x, tf_indices, segment_indices, 10) 1039 with self.assertRaisesOpError(r"Segment id 1 out of range \[0, 1\)"): 1040 self.evaluate(s) 1041 1042 @test_util.run_deprecated_v1 1043 def testGradientSegmentsInvalid3(self): 1044 tf_x, _ = self._input([2, 4], dtype=dtypes_lib.float32) 1045 ops_list = [ 1046 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 1047 math_ops.sparse_segment_sqrt_n_grad 1048 ] 1049 segment_indices = [-1, 0, 1, 1] 1050 tf_indices = [8, 3, 0, 9] 1051 with self.session(use_gpu=False): 1052 for tf_op in ops_list: 1053 s = tf_op(tf_x, tf_indices, segment_indices, 10) 1054 with self.assertRaisesOpError(r"Segment id -1 out of range \[0, 2\)"): 1055 self.evaluate(s) 1056 1057 @test_util.run_deprecated_v1 1058 def testGradientSegmentsInvalid4(self): 1059 tf_x, _ = self._input([0, 4], dtype=dtypes_lib.float32) 1060 ops_list = [ 1061 math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad, 1062 math_ops.sparse_segment_sqrt_n_grad 1063 ] 1064 segment_indices = [0, 1, 2, -1] 1065 tf_indices = [8, 3, 0, 9] 1066 with self.session(use_gpu=False): 1067 for tf_op in ops_list: 1068 s = tf_op(tf_x, tf_indices, segment_indices, 10) 1069 with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"): 1070 self.evaluate(s) 1071 1072 1073class SegmentReductionOpBenchmark(test.Benchmark): 1074 outer_dim_options = [2**x for x in range(9, 14, 2)] 1075 ratio_options = [2**x for x in range(1, 6, 2)] 1076 inner_dim_options = [2**x for x in range(9, 14, 2)] 1077 # randomly generated sizes with less alignments 1078 inner_dim_options += [ 1079 1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584 1080 ] 1081 dtype_options = [np.float32, np.float64] 1082 options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options) 1083 # pylint: disable=g-long-lambda 1084 op_functors = [lambda vc, vs, seg_ids: 1085 ("sorted", math_ops.segment_sum(vc, vs)), 1086 lambda vc, vs, seg_ids: 1087 ("unsorted", 1088 math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))] 1089 # pylint: enable=g-long-lambda 1090 repeat = 10 1091 1092 def _npTypeToStr(self, t): 1093 if t == np.float32: 1094 return "fp32" 1095 if t == np.float64: 1096 return "fp64" 1097 1098 def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype): 1099 output_outer_dim = int(outer_dim / ratio) 1100 const = np.random.randint(5, size=(outer_dim, inner_dim)) 1101 seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim)) 1102 vs = variables.Variable(seg_ids.astype(np.int32)) 1103 with ops.device("/gpu:0"): 1104 vc = variables.Variable(const.astype(dtype)) 1105 name, op = op_functor(vc, vs, seg_ids) 1106 with session.Session() as sess: 1107 self.evaluate(variables.global_variables_initializer()) 1108 r = self.run_op_benchmark( 1109 sess, 1110 op, 1111 min_iters=self.repeat, 1112 name="_".join( 1113 map(str, 1114 [name, outer_dim, ratio, inner_dim, 1115 self._npTypeToStr(dtype)]))) 1116 return name, r["wall_time"] 1117 1118 def benchmarkSegmentSumGPU(self): 1119 if not test.is_gpu_available(cuda_only=True): 1120 return 1121 for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options): 1122 op_functor = self.op_functors[0] 1123 with ops.Graph().as_default(): 1124 self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype) 1125 1126 def benchmarkUnsortedSegmentSumGPU(self): 1127 if not test.is_gpu_available(cuda_only=True): 1128 return 1129 for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options): 1130 op_functor = self.op_functors[1] 1131 with ops.Graph().as_default(): 1132 self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype) 1133 1134 1135if __name__ == "__main__": 1136 test.main() 1137