• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for segment reduction ops."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.client import session
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes as dtypes_lib
24from tensorflow.python.framework import errors_impl
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.ops import gradient_checker_v2
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import test
32
33
class SegmentReductionHelper(test.TestCase):
  """Shared input builders and NumPy reference reductions for the tests."""

  def _input(self, input_shape, dtype=dtypes_lib.int32):
    """Builds a deterministic input of the requested shape and dtype.

    The values are 1, 2, ..., N in row-major order.  For complex dtypes each
    value `v` is turned into `v - 1j * v` so the imaginary part is non-zero.

    Args:
      input_shape: Sequence of ints giving the shape of the input.
      dtype: dtype object exposing `as_numpy_dtype` and `is_complex`
        (defaults to `dtypes_lib.int32`).

    Returns:
      A `(tensor, np_values)` pair: the constant built from the values and the
      NumPy array holding the same values.
    """
    num_elem = 1
    for x in input_shape:
      num_elem *= x
    values = np.arange(1, num_elem + 1)
    np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype)
    # Add a non-zero imaginary component to complex types.
    if dtype.is_complex:
      np_values -= 1j * np_values
    return constant_op.constant(
        np_values, shape=input_shape, dtype=dtype), np_values

  def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None,
                     initial_value=0):
    """Computes a segment reduction with NumPy as a reference result.

    Args:
      indices: Segment ids (array-like).  Its shape must be a prefix of
        `x.shape`; the remaining dimensions of `x` form one data slice per id.
      x: NumPy array with the data to reduce.
      op1: Binary accumulator, called as `op1(accumulated, data_slice)`.
      op2: Optional unary function applied to every accumulated segment.
      num_segments: Number of output segments.  When None it is derived from
        the largest segment id.
      initial_value: Value used to fill segments that receive no data.

    Returns:
      A NumPy array with one slice per segment, or an empty array when `x`
      has no elements.
    """
    if not x.size:
      return np.array([])
    indices = np.asarray(indices)
    if num_segments is None:
      # Use the largest id instead of the last one so that the default is
      # also correct for unsorted or multi-dimensional segment ids.
      num_segments = indices.max() + 1
    output = [None] * num_segments
    slice_shape = x.shape[indices.ndim:]
    x_flat = x.reshape((indices.size,) + slice_shape)
    for i, index in enumerate(indices.ravel()):
      if (output[index] is not None) and op1 == np.max:
        # np.max reduces a sequence, so it is applied element by element.
        for j in range(0, output[index].shape[0]):
          output[index][j] = op1([output[index][j], x_flat[i][j]])
      elif output[index] is not None:
        output[index] = op1(output[index], x_flat[i])
      else:
        # First slice seen for this segment.
        output[index] = x_flat[i]
    # Segments that received no data are filled with `initial_value`.
    initial_value_slice = np.ones(slice_shape) * initial_value
    output = [o if o is not None else initial_value_slice for o in output]
    if op2 is not None:
      output = [op2(o) for o in output]
    output = [o.reshape(slice_shape) for o in output]
    return np.array(output)

  def _mean_cum_op(self, x, y):
    """Accumulates `y` into a running `(sum, count)` pair.

    `x` is either a previous `(sum, count)` tuple or the first raw slice, in
    which case the resulting count is 2.
    """
    return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2)

  def _mean_reduce_op(self, x):
    """Turns a `(sum, count)` pair into the mean; other values pass through."""
    return x[0] / x[1] if isinstance(x, tuple) else x

  def _sqrt_n_reduce_op(self, x):
    """Turns a `(sum, count)` pair into `sum / sqrt(count)`."""
    return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x
82
83
class SegmentReductionOpTest(SegmentReductionHelper):
  """Tests for the `math_ops.segment_*` reductions."""

  def _assertSegmentSumMatchesReference(self, segment_ids):
    """Compares `segment_sum` on a 4x4 float32 input with the NumPy result."""
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu):
        tf_x, np_x = self._input([4, 4], dtype=dtypes_lib.float32)
        expected = self._segmentReduce(segment_ids, np_x, np.add)
        reduced = math_ops.segment_sum(data=tf_x, segment_ids=segment_ids)
        actual = self.evaluate(reduced)
        self.assertAllClose(expected, actual)

  def _assertSegmentSumFails(self, segment_ids, expected_regex):
    """Expects evaluating `segment_sum` on a 4x4 input to raise an op error."""
    with self.cached_session():
      tf_x, _ = self._input([4, 4])
      reduced = math_ops.segment_sum(data=tf_x, segment_ids=segment_ids)
      with self.assertRaisesOpError(expected_regex):
        self.evaluate(reduced)

  def _assertNegativeSegmentIdRejected(self, segment_ids):
    """Expects the ">= 0" op error for float32 data on both device settings."""
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu):
        tf_x, _ = self._input([4, 4], dtype=dtypes_lib.float32)
        reduced = math_ops.segment_sum(data=tf_x, segment_ids=segment_ids)
        with self.assertRaisesOpError("segment ids must be >= 0"):
          self.evaluate(reduced)

  def testValues(self):
    """Checks every op against the NumPy reference for all dtypes/shapes."""
    # Each item is np_op1, np_op2, tf_op
    real_reductions = [
        (np.add, None, math_ops.segment_sum),
        (self._mean_cum_op, self._mean_reduce_op, math_ops.segment_mean),
        (np.ndarray.__mul__, None, math_ops.segment_prod),
        (np.minimum, None, math_ops.segment_min),
        (np.maximum, None, math_ops.segment_max),
    ]
    # A subset of ops has been enabled for complex numbers
    complex_reductions = [
        (np.add, None, math_ops.segment_sum),
        (np.ndarray.__mul__, None, math_ops.segment_prod),
        (self._mean_cum_op, self._mean_reduce_op, math_ops.segment_mean),
    ]
    complex_dtypes = (dtypes_lib.complex64, dtypes_lib.complex128)
    all_dtypes = [
        dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
        dtypes_lib.int32
    ] + list(complex_dtypes)

    num_rows = 10
    segment_ids = [row // 3 for row in range(num_rows)]
    # Note that the GPU implem has different paths for different inner sizes.
    shapes = [[num_rows, inner_size] for inner_size in (1, 2, 3, 32)]
    for shape, dtype, use_gpu in itertools.product(
        shapes, all_dtypes, [True, False]):
      if dtype in complex_dtypes:
        reductions = complex_reductions
      else:
        reductions = real_reductions
      with self.cached_session(use_gpu=use_gpu):
        tf_x, np_x = self._input(shape, dtype=dtype)
        for np_op1, np_op2, tf_op in reductions:
          if tf_op is math_ops.segment_prod:
            initial_value = 1
          else:
            initial_value = 0
          expected = self._segmentReduce(
              segment_ids, np_x, np_op1, np_op2, initial_value=initial_value)
          reduced = tf_op(data=tf_x, segment_ids=segment_ids)
          actual = self.evaluate(reduced)
          self.assertAllClose(expected, actual)
          # NOTE(mrry): The static shape inference that computes
          # `tf_ans.shape` can only infer that sizes from dimension 1
          # onwards, because the size of dimension 0 is data-dependent
          # and may therefore vary dynamically.
          self.assertAllEqual(expected.shape[1:], actual.shape[1:])

  @test_util.run_deprecated_v1
  def testSegmentIdsShape(self):
    """Segment ids of shape [2, 2] raise ValueError when the op is built."""
    tf_x, _ = self._input([4, 4])
    matrix_ids = constant_op.constant([0, 1, 2, 2], shape=[2, 2])
    with self.assertRaises(ValueError):
      math_ops.segment_sum(data=tf_x, segment_ids=matrix_ids)

  @test_util.run_deprecated_v1
  def testSegmentIdsSize(self):
    """Two segment ids for four data rows raise an op error."""
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu):
        tf_x, _ = self._input([4, 4])
        reduced = math_ops.segment_sum(data=tf_x, segment_ids=[0, 1])
        with self.assertRaisesOpError("segment_ids should be the same size"):
          self.evaluate(reduced)

  @test_util.run_deprecated_v1
  def testSegmentIdsValid(self):
    """Baseline with valid ids for the SegmentIdsInvalid* tests below."""
    expected = [[15, 18, 21, 24], [13, 14, 15, 16]]
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu):
        tf_x, _ = self._input([4, 4], dtype=dtypes_lib.float32)
        reduced = math_ops.segment_sum(data=tf_x, segment_ids=[0, 0, 0, 1])
        actual = reduced.eval()
        self.assertAllEqual(expected, actual)

  def testSegmentIdsGreaterThanZero(self):
    """Ids that start at 1 leave segment 0 without data."""
    self._assertSegmentSumMatchesReference([1, 1, 2, 2])

  def testSegmentIdsHole(self):
    """Ids may skip values; segments 1 and 2 receive no data here."""
    self._assertSegmentSumMatchesReference([0, 0, 3, 3])

  @test_util.run_deprecated_v1
  def testSegmentIdsInvalid1(self):
    """Leading negative ids are reported as out of range."""
    self._assertSegmentSumFails(
        [-1, -1, 0, 0],
        r"Segment id -1 out of range \[0, 1\), possibly because "
        "'segment_ids' input is not sorted.")

  @test_util.run_deprecated_v1
  def testSegmentIdsInvalid2(self):
    """Ids that go back down are reported as not increasing."""
    self._assertSegmentSumFails([0, 1, 0, 1], "segment ids are not increasing")

  @test_util.run_deprecated_v1
  def testSegmentIdsInvalid3(self):
    """A trailing id smaller than its predecessors is reported."""
    self._assertSegmentSumFails(
        [0, 1, 2, 0],
        r"Segment id 1 out of range \[0, 1\), possibly "
        "because 'segment_ids' input is not sorted.")

  @test_util.run_deprecated_v1
  def testSegmentIdsInvalid4(self):
    """A trailing id of -1 is rejected."""
    self._assertNegativeSegmentIdRejected([0, 0, 0, -1])

  @test_util.run_deprecated_v1
  def testSegmentIdsInvalid5(self):
    """A trailing id of -2 is rejected."""
    self._assertNegativeSegmentIdRejected([0, 0, 0, -2])

  @test_util.run_deprecated_v1
  def testGradient(self):
    """Compares the two Jacobians returned by the gradient checker."""
    shape = [4, 4]
    segment_ids = [0, 1, 2, 2]
    reductions = [
        math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min,
        math_ops.segment_max
    ]
    for tf_op in reductions:
      with self.cached_session():
        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
        reduced = tf_op(data=tf_x, segment_ids=segment_ids)
        first_jacobian, second_jacobian = gradient_checker.compute_gradient(
            tf_x,
            shape,
            reduced, [3, 4],
            x_init_value=np_x.astype(np.double),
            delta=1)
      self.assertAllClose(first_jacobian, second_jacobian)

  def testDataInvalid(self):
    """Scalar data is rejected; test case for GitHub issue 40653."""
    expected_errors = (ValueError, errors_impl.InvalidArgumentError)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu):
        with self.assertRaisesRegex(expected_errors,
                                    "must be at least rank 1"):
          reduced = math_ops.segment_mean(
              data=np.uint16(10), segment_ids=np.array([]).astype("int64"))
          self.evaluate(reduced)

  def testInvalidIds(self):
    """A huge segment id is rejected; test case for GitHub issue 46888."""
    expected_errors = (ValueError, errors_impl.InvalidArgumentError)
    reductions = [
        math_ops.segment_max,
        math_ops.segment_min,
        math_ops.segment_mean,
        math_ops.segment_sum,
        math_ops.segment_prod,
    ]
    for tf_op in reductions:
      with self.cached_session(use_gpu=False):
        with self.assertRaises(expected_errors):
          reduced = tf_op(
              data=np.ones((1, 10, 1)), segment_ids=[1676240524292489355])
          self.evaluate(reduced)
282
283
class UnsortedSegmentTest(SegmentReductionHelper):
  """Tests for the `math_ops.unsorted_segment_*` reductions."""

  def __init__(self, methodName='runTest'):
    """Defines the op tables and dtype lists shared by the test methods.

    Args:
      methodName: Name of the test method to run, forwarded to the base class.
    """
    # Each item is np_op1, np_op2, tf_op, initial_value functor.  The functor
    # receives the dtype under test and returns the value expected in
    # segments that receive no data.
    self.ops_list = [(np.add, None,
                      math_ops.unsorted_segment_sum, lambda t: 0),
                     (self._mean_cum_op, self._mean_reduce_op,
                      math_ops.unsorted_segment_mean, lambda t: 0),
                     (self._mean_cum_op, self._sqrt_n_reduce_op,
                      math_ops.unsorted_segment_sqrt_n, lambda t: 0),
                     (np.ndarray.__mul__, None,
                      math_ops.unsorted_segment_prod, lambda t: 1),
                     (np.minimum, None,
                      math_ops.unsorted_segment_min, lambda t: t.max),
                     (np.maximum, None,
                      math_ops.unsorted_segment_max, lambda t: t.min)]

    # A subset of ops has been enabled for complex numbers
    self.complex_ops_list = [(np.add, None,
                              math_ops.unsorted_segment_sum, lambda t: 0),
                             (np.ndarray.__mul__, None,
                              math_ops.unsorted_segment_prod, lambda t: 1)]
    self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
                                  dtypes_lib.float64]
    self.all_dtypes = (self.differentiable_dtypes +
                       [dtypes_lib.bfloat16,
                        dtypes_lib.int64, dtypes_lib.int32,
                        dtypes_lib.complex64, dtypes_lib.complex128])
    super(UnsortedSegmentTest, self).__init__(methodName=methodName)

  def testValues(self):
    """Checks every op and dtype against the NumPy reference."""
    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = 12
    for indices in indices_flat, indices_flat.reshape(5, 2):
      # Note that the GPU implem has different paths for different inner sizes.
      for inner_size in [1, 2, 3, 32]:
        shape = indices.shape + (inner_size,)
        for dtype in self.all_dtypes:
          ops_list = (
              self.complex_ops_list if dtype.is_complex else self.ops_list)
          tf_x, np_x = self._input(shape, dtype=dtype)
          for use_gpu in [True, False]:
            # Forward `use_gpu`; without it both iterations of this loop
            # would request the same session configuration.
            with self.cached_session(use_gpu=use_gpu):
              for np_op1, np_op2, tf_op, init_op in ops_list:
                # sqrt_n doesn't support integers
                if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
                  continue
                # todo(philjd): enable this test once real_div supports bfloat16
                if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and
                    dtype == dtypes_lib.bfloat16):
                  continue
                np_ans = self._segmentReduce(
                    indices,
                    np_x,
                    np_op1,
                    np_op2,
                    num_segments=num_segments,
                    initial_value=init_op(dtype))
                s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments)
                tf_ans = self.evaluate(s)
                if dtype is dtypes_lib.bfloat16:
                  tf_ans = tf_ans.astype(np.float32)
                self.assertAllCloseAccordingToType(np_ans, tf_ans)
                self.assertShapeEqual(np_ans, s)

  def testNumSegmentsTypes(self):
    """Checks that `num_segments` may be an int32 or an int64 constant."""
    dtypes = [dtypes_lib.int32, dtypes_lib.int64]
    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = 12
    for indices in indices_flat, indices_flat.reshape(5, 2):
      shape = indices.shape + (2,)
      for dtype in dtypes:
        with self.cached_session():
          tf_x, np_x = self._input(shape)
          num_segments_constant = constant_op.constant(
              num_segments, dtype=dtype)
          np_ans = self._segmentReduce(
              indices, np_x, np.add, op2=None, num_segments=num_segments)
          s = math_ops.unsorted_segment_sum(
              data=tf_x,
              segment_ids=indices,
              num_segments=num_segments_constant)
          tf_ans = self.evaluate(s)
        self.assertAllClose(np_ans, tf_ans)
        self.assertShapeEqual(np_ans, s)

  @test_util.run_deprecated_v1
  def testGradientsTFGradients(self):
    """Compares the two Jacobians from `gradient_checker` for every op."""
    num_cols = 2
    indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
    num_segments = max(indices_flat) + 3
    for dtype in self.differentiable_dtypes:
      ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
      for indices in indices_flat, indices_flat.reshape(5, 2):
        shape = indices.shape + (num_cols,)
        # test CPU and GPU as tf.gather behaves differently on each device
        for use_gpu in [False, True]:
          with self.cached_session(use_gpu=use_gpu):
            for _, _, tf_op, _ in ops_list:
              tf_x, np_x = self._input(shape, dtype=dtype)
              s = tf_op(tf_x, indices, num_segments)
              jacob_t, jacob_n = gradient_checker.compute_gradient(
                  tf_x,
                  shape,
                  s, [num_segments, num_cols],
                  x_init_value=np_x,
                  delta=1.)
              self.assertAllCloseAccordingToType(jacob_t, jacob_n,
                                                 half_atol=1e-2)

  @test_util.run_in_graph_and_eager_modes
  def testGradientsGradientTape(self):
    """Compares the two Jacobians from `gradient_checker_v2` for every op."""
    num_cols = 2
    indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
    num_segments = max(indices_flat) + 3
    for dtype in self.differentiable_dtypes:
      ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
      for indices in indices_flat, indices_flat.reshape(5, 2):
        shape = indices.shape + (num_cols,)
        # test CPU and GPU as tf.gather behaves differently on each device
        for use_gpu in [test_util.use_gpu, test_util.force_cpu]:
          with use_gpu():
            for _, _, tf_op, _ in ops_list:
              _, np_x = self._input(shape, dtype=dtype)
              # pylint: disable=cell-var-from-loop
              def f(x):
                return tf_op(x, indices, num_segments)
              gradient_tape_jacob_t, jacob_n = (
                  gradient_checker_v2.compute_gradient(
                      f, [np_x], delta=1.))
              # pylint: enable=cell-var-from-loop
              self.assertAllCloseAccordingToType(jacob_n, gradient_tape_jacob_t,
                                                 half_atol=1e-2)

  @test_util.run_deprecated_v1
  def testProdGrad(self):
    """Checks the prod gradient against hand-written partial derivatives."""
    # additional test for the prod gradient to ensure correct handling of zeros
    values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32)
    indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32)
    indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32)
    values_tf = constant_op.constant(values)
    # ground truth partial derivatives
    gradients_indices = np.zeros((9, 3), dtype=np.float32)
    gradients_indices_neg = np.zeros((9, 3), dtype=np.float32)
    # the derivative w.r.t. to the other segments is zero, so here we only
    # explicitly set the grad values for the corresponding segment
    gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9]
    gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3]
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        for ind, grad_gt in [(indices, gradients_indices),
                             (indices_neg, gradients_indices_neg)]:
          s = math_ops.unsorted_segment_prod(values_tf,
                                             constant_op.constant(ind), 3)
          jacob_t, jacob_n = gradient_checker.compute_gradient(
              values_tf, (9,), s, (3,), x_init_value=values, delta=1)
          self.assertAllClose(jacob_t, jacob_n)
          self.assertAllClose(jacob_t, grad_gt)

  @test_util.run_deprecated_v1
  def testGradientMatchesSegmentSum(self):
    """Checks that UnsortedSegmentSum and SegmentSum gradients agree."""
    # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
    # and compare the outputs, which should be identical.
    # NB: for this test to work, indices must be valid for SegmentSum, namely
    # it must be sorted, the indices must be contiguous, and num_segments
    # must be max(indices) + 1.
    indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
    n = len(indices)
    num_cols = 2
    shape = [n, num_cols]
    num_segments = max(indices) + 1
    for dtype in self.differentiable_dtypes:
      with self.cached_session():
        tf_x, np_x = self._input(shape, dtype=dtype)
        # Results from UnsortedSegmentSum
        unsorted_s = math_ops.unsorted_segment_sum(
            data=tf_x, segment_ids=indices, num_segments=num_segments)
        unsorted_jacob_t, unsorted_jacob_n = (
            gradient_checker.compute_gradient(tf_x, shape, unsorted_s,
                                              [num_segments, num_cols],
                                              x_init_value=np_x, delta=1))

        # Results from SegmentSum
        sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
        sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient(
            tf_x,
            shape,
            sorted_s, [num_segments, num_cols],
            x_init_value=np_x,
            delta=1)
      self.assertAllClose(unsorted_jacob_t, sorted_jacob_t)
      self.assertAllClose(unsorted_jacob_n, sorted_jacob_n)

  @test_util.run_deprecated_v1
  def testBadIndices(self):
    """Out-of-range positive ids raise an op error on CPU."""
    # Note: GPU kernel does not return the out-of-range error needed for this
    # test, so this test is marked as cpu-only.
    # Note: With PR #13055 a negative index will be ignored silently.
    with self.session(use_gpu=False):
      for bad in [[2]], [[7]]:
        unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2)
        with self.assertRaisesOpError(
            r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]):
          self.evaluate(unsorted)

  @test_util.run_deprecated_v1
  def testEmptySecondDimension(self):
    """Data of shape (2, 0) reduces to zeros of shape (2, 0)."""
    dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
              np.complex64, np.complex128]
    with self.session():
      for dtype in dtypes:
        for itype in (np.int32, np.int64):
          data = np.zeros((2, 0), dtype=dtype)
          segment_ids = np.array([0, 1], dtype=itype)
          unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
          self.assertAllEqual(unsorted, np.zeros((2, 0), dtype=dtype))

  def testDropNegatives(self):
    """Checks that data with a negative segment id is left out of the sum."""
    # The reference is computed with NumPy from the original ids and its
    # segment 8 is then zeroed.  The op itself receives a copy of the ids in
    # which every 8 is replaced with -1.
    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = 12
    for indices in indices_flat, indices_flat.reshape(5, 2):
      shape = indices.shape + (2,)
      for dtype in self.all_dtypes:
        with self.session():
          tf_x, np_x = self._input(shape, dtype=dtype)
          np_ans = self._segmentReduce(
              indices, np_x, np.add, op2=None, num_segments=num_segments)
          # Replace np_ans[8] with 0 for the value
          np_ans[8] = 0
          # Replace 8 with -1 in a copy of the ids.  `indices` shares memory
          # with `indices_flat`, so changing it in place would leak the -1
          # ids into later dtype iterations and into the reshaped layout.
          dropped_indices = indices.copy()
          dropped_indices[dropped_indices == 8] = -1
          s = math_ops.unsorted_segment_sum(
              data=tf_x, segment_ids=dropped_indices,
              num_segments=num_segments)
          tf_ans = self.evaluate(s)
        self.assertAllClose(np_ans, tf_ans)
        self.assertShapeEqual(np_ans, s)

  @test_util.run_deprecated_v1
  def testAllNegatives(self):
    """When every id is -1 the result is all zeros."""
    with self.session(use_gpu=False):
      data = np.ones((2, 1), dtype=np.float32)
      segment_ids = np.array([-1, -1], dtype=np.int32)
      unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
      self.assertAllClose(unsorted.eval(), np.zeros((2, 1), dtype=np.float32))
530
531
class SparseSegmentReductionHelper(SegmentReductionHelper):
  """NumPy reference helpers for the sparse segment reduction tests."""

  def _sparse_input(self, input_shape, num_indices, dtype=dtypes_lib.int32):
    """Creates dense data together with random row indices into it.

    Args:
      input_shape: Shape of the dense data; `input_shape[0]` bounds the
        generated indices.
      num_indices: Number of random indices to draw.
      dtype: dtype of the dense data.

    Returns:
      A `(tf_indices, np_indices, tf_data, np_data)` tuple.
    """
    tf_data, np_data = super(SparseSegmentReductionHelper, self)._input(
        input_shape, dtype)
    row_count = input_shape[0]
    np_indices = np.random.randint(0, row_count, num_indices)
    np_indices = np_indices.astype(np.int32)
    tf_indices = constant_op.constant(np_indices, dtype=dtypes_lib.int32)
    return (tf_indices, np_indices, tf_data, np_data)

  def _sparseSegmentReduce(self,
                           x,
                           indices,
                           segment_indices,
                           op1,
                           op2=None,
                           num_segments=None):
    """Selects `x[indices]` and reduces the selected rows by segment."""
    selected_rows = x[indices]
    return self._segmentReduce(
        segment_indices, selected_rows, op1, op2, num_segments=num_segments)

  def _sparseSegmentReduceGrad(self, ygrad, indices, segment_ids, output_dim0,
                               mode):
    """Computes the reference input gradient for a sparse segment reduction.

    Args:
      ygrad: 2-D NumPy array, indexed by segment id along its first axis.
      indices: Row of the result that each entry contributes to.
      segment_ids: Segment id of each entry, paired with `indices`.
      output_dim0: Number of rows of the returned gradient.
      mode: One of "sum", "mean" or "sqrtn".

    Returns:
      NumPy array of shape `[output_dim0, ygrad.shape[1]]`.
    """
    assert mode in ("sum", "mean", "sqrtn")
    is_sum = mode == "sum"
    if not is_sum:
      # Count how many entries fall into every segment.
      counts = np.zeros(ygrad.shape[0], ygrad.dtype)
      for seg in segment_ids:
        counts[seg] += 1
      if mode == "mean":
        weights = 1. / counts
      else:
        weights = 1. / np.sqrt(counts)
    xgrad = np.zeros([output_dim0, ygrad.shape[1]], ygrad.dtype)
    for seg, row in zip(segment_ids, indices):
      if is_sum:
        xgrad[row] += ygrad[seg]
      else:
        xgrad[row] += ygrad[seg] * weights[seg]
    return xgrad
565
566
567class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
568
569  def testValues(self):
570    dtypes = [
571        dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
572        dtypes_lib.int32
573    ]
574
575    index_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
576    segment_ids_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
577
578    mean_dtypes = [dtypes_lib.float32, dtypes_lib.float64]
579
580    # Each item is np_op1, np_op2, tf_op
581    ops_list = [(np.add, None, math_ops.sparse_segment_sum),
582                (self._mean_cum_op, self._mean_reduce_op,
583                 math_ops.sparse_segment_mean)]
584
585    n = 400
586    # Note that the GPU implem has different paths for different inner sizes.
587    for inner_size in [1, 2, 3, 32]:
588      shape = [n, inner_size]
589      segment_indices = []
590      for i in range(20):
591        for _ in range(i + 1):
592          segment_indices.append(i)
593      num_indices = len(segment_indices)
594      for dtype in dtypes:
595        for index_dtype in index_dtypes:
596          for segment_ids_dtype in segment_ids_dtypes:
597            with self.cached_session():
598              tf_indices, np_indices, tf_x, np_x = self._sparse_input(
599                  shape, num_indices, dtype=dtype)
600              for np_op1, np_op2, tf_op in ops_list:
601                if (tf_op == math_ops.sparse_segment_mean and
602                    dtype not in mean_dtypes):
603                  continue
604                np_ans = self._sparseSegmentReduce(np_x, np_indices,
605                                                   segment_indices, np_op1,
606                                                   np_op2)
607                s = tf_op(
608                    data=tf_x,
609                    indices=math_ops.cast(tf_indices, index_dtype),
610                    segment_ids=math_ops.cast(segment_indices,
611                                              segment_ids_dtype))
612                tf_ans = self.evaluate(s)
613                self.assertAllClose(np_ans, tf_ans)
614                # NOTE(mrry): The static shape inference that computes
615                # `tf_ans.shape` can only infer that sizes from dimension 1
616                # onwards, because the size of dimension 0 is data-dependent
617                # and may therefore vary dynamically.
618                self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
619
620  def testSegmentIdsHole(self):
621    tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
622    ops_list = [(np.add, None, math_ops.sparse_segment_sum), (
623        self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)]
624    segment_indices = [0, 2, 2, 2]
625    tf_indices = [8, 3, 0, 9]
626    with self.session():
627      for np_op1, np_op2, tf_op in ops_list:
628        np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices,
629                                           np_op1, np_op2)
630        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
631        tf_ans = self.evaluate(s)
632        self.assertAllClose(np_ans, tf_ans)
633
634  def testWithNumSegments(self):
635    tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
636    ops_list = [(np.add, None, math_ops.sparse_segment_sum_with_num_segments),
637                (self._mean_cum_op, self._mean_reduce_op,
638                 math_ops.sparse_segment_mean_with_num_segments)]
639    segment_indices = [0, 2, 2, 2]
640    tf_indices = [8, 3, 0, 9]
641    num_segments = 5
642    with self.session():
643      for np_op1, np_op2, tf_op in ops_list:
644        np_ans = self._sparseSegmentReduce(
645            np_x,
646            tf_indices,
647            segment_indices,
648            np_op1,
649            np_op2,
650            num_segments=num_segments)
651        s = tf_op(
652            data=tf_x,
653            indices=tf_indices,
654            segment_ids=segment_indices,
655            num_segments=num_segments)
656        tf_ans = self.evaluate(s)
657        self.assertAllClose(np_ans, tf_ans)
658
659  def testWithEmptySegments(self):
660    tf_x = constant_op.constant([], shape=[0, 4], dtype=dtypes_lib.float32)
661    ops_list = [
662        math_ops.sparse_segment_sum_with_num_segments,
663        math_ops.sparse_segment_mean_with_num_segments
664    ]
665    segment_indices = []
666    tf_indices = []
667    num_segments = 5
668    with self.session():
669      for tf_op in ops_list:
670        s = tf_op(
671            data=tf_x,
672            indices=tf_indices,
673            segment_ids=segment_indices,
674            num_segments=num_segments)
675        tf_ans = self.evaluate(s)
676        self.assertAllClose(np.zeros([5, 4]), tf_ans)
677
678  @test_util.run_in_graph_and_eager_modes
679  def testSegmentScalarIdiRaisesInvalidArgumentError(self):
680    """Test for github #46897."""
681    ops_list = [
682        math_ops.sparse_segment_sum,
683        math_ops.sparse_segment_mean,
684        math_ops.sparse_segment_sqrt_n,
685    ]
686    for op in ops_list:
687      with self.assertRaisesRegex(
688          (ValueError, errors_impl.InvalidArgumentError),
689          "Shape must be at least rank 1"):
690        op(data=1.0, indices=[0], segment_ids=[3])
691
692  def testSegmentIdsGreaterThanZero(self):
693    tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
694    ops_list = [(np.add, None, math_ops.sparse_segment_sum), (
695        self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)]
696    segment_indices = [1, 2, 2, 2]
697    tf_indices = [8, 3, 0, 9]
698    with self.session():
699      for np_op1, np_op2, tf_op in ops_list:
700        np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices,
701                                           np_op1, np_op2)
702        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
703        tf_ans = self.evaluate(s)
704        self.assertAllClose(np_ans, tf_ans)
705
706  def testValid(self):
707    # Baseline for the test*Invalid* methods below.
708    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
709    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
710    segment_indices = [0, 1, 2, 2]
711    tf_indices = [8, 3, 0, 9]
712    with self.session():
713      for tf_op in ops_list:
714        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
715        self.evaluate(s)
716
717  @test_util.run_deprecated_v1
718  def testIndicesInvalid1(self):
719    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
720    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
721    segment_indices = [0, 1, 2, 2]
722    tf_indices = [8, -1, 0, 9]
723    with self.session(use_gpu=False):
724      for tf_op in ops_list:
725        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
726        with self.assertRaisesOpError(
727            r"indices\[1\] == -1 out of range \[0, 10\)"):
728          self.evaluate(s)
729
730  @test_util.run_deprecated_v1
731  def testIndicesInvalid2(self):
732    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
733    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
734    segment_indices = [0, 1, 2, 2]
735    tf_indices = [8, 3, 0, 10]
736    with self.session(use_gpu=False):
737      for tf_op in ops_list:
738        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
739        with self.assertRaisesOpError(
740            r"indices\[3\] == 10 out of range \[0, 10\)"):
741          self.evaluate(s)
742
743  @test_util.run_deprecated_v1
744  def testSegmentsInvalid2(self):
745    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
746    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
747    segment_indices = [0, 1, 0, 1]
748    tf_indices = [8, 3, 0, 9]
749    with self.session(use_gpu=False):
750      for tf_op in ops_list:
751        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
752        with self.assertRaisesOpError("segment ids are not increasing"):
753          self.evaluate(s)
754
755  @test_util.run_deprecated_v1
756  def testSegmentsInvalid3(self):
757    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
758    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
759    segment_indices = [0, 1, 2, 0]
760    tf_indices = [8, 3, 0, 9]
761    with self.session(use_gpu=False):
762      for tf_op in ops_list:
763        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
764        with self.assertRaisesOpError(
765            r"Segment id 1 out of range \[0, 1\), possibly because "
766            "'segment_ids' input is not sorted"):
767          self.evaluate(s)
768
769  @test_util.run_deprecated_v1
770  def testSegmentsInvalid4(self):
771    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
772    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
773    segment_indices = [-1, 0, 1, 1]
774    tf_indices = [8, 3, 0, 9]
775    with self.session(use_gpu=False):
776      for tf_op in ops_list:
777        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
778        with self.assertRaisesOpError(
779            r"Segment id -1 out of range \[0, 2\), possibly because "
780            "'segment_ids' input is not sorted"):
781          self.evaluate(s)
782
783  @test_util.run_deprecated_v1
784  def testSegmentsInvalid6(self):
785    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
786    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
787    segment_indices = [0, 0, 0, -1]
788    tf_indices = [8, 3, 0, 9]
789    with self.session(use_gpu=False):
790      for tf_op in ops_list:
791        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
792        with self.assertRaisesOpError("segment ids must be >= 0"):
793          self.evaluate(s)
794
795  @test_util.run_deprecated_v1
796  def testSegmentsInvalid7(self):
797    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
798    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
799    segment_indices = [0, 0, 0, -2]
800    tf_indices = [8, 3, 0, 9]
801    with self.session(use_gpu=False):
802      for tf_op in ops_list:
803        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
804        with self.assertRaisesOpError("segment ids must be >= 0"):
805          self.evaluate(s)
806
807  @test_util.run_deprecated_v1
808  def testSegmentsInvalid8(self):
809    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
810    ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
811    segment_indices = [2**62 - 1]
812    tf_indices = [2**62 - 1]
813    with self.session(use_gpu=False):
814      for tf_op in ops_list:
815        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
816        with self.assertRaisesOpError(
817            "Encountered overflow when multiplying"):
818          self.evaluate(s)
819
820  def testSegmentWithNumSegmentsValid(self):
821    # Baseline for the test*WithNumSegmentsInvalid* methods below.
822    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
823    ops_list = [
824        math_ops.sparse_segment_sum_with_num_segments,
825        math_ops.sparse_segment_mean_with_num_segments,
826    ]
827    num_segments = 5
828    segment_indices = [0, 1, 3, 3]
829    tf_indices = [8, 3, 0, 9]
830    with self.session():
831      for tf_op in ops_list:
832        s = tf_op(
833            data=tf_x,
834            indices=tf_indices,
835            segment_ids=segment_indices,
836            num_segments=num_segments)
837        self.evaluate(s)
838
839  @test_util.run_deprecated_v1
840  def testSegmentWithNumSegmentsInvalid1(self):
841    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
842    ops_list = [
843        math_ops.sparse_segment_sum_with_num_segments,
844        math_ops.sparse_segment_mean_with_num_segments,
845    ]
846    num_segments = 5
847    segment_indices = [0, 1, 3, 5]
848    tf_indices = [8, 3, 0, 9]
849    with self.session(use_gpu=False):
850      for tf_op in ops_list:
851        s = tf_op(
852            data=tf_x,
853            indices=tf_indices,
854            segment_ids=segment_indices,
855            num_segments=num_segments)
856        with self.assertRaisesOpError("segment ids must be < num_segments"):
857          self.evaluate(s)
858
859  @test_util.run_deprecated_v1
860  def testSegmentWithNumSegmentsInvalid2(self):
861    tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
862    ops_list = [
863        math_ops.sparse_segment_sum_with_num_segments,
864        math_ops.sparse_segment_mean_with_num_segments,
865    ]
866    num_segments = -2
867    segment_indices = [0, 1, 3, 3]
868    tf_indices = [8, 3, 0, 9]
869    with self.session(use_gpu=False):
870      for tf_op in ops_list:
871        with self.assertRaisesRegex(
872            ValueError, "Cannot specify a negative value for num_segments"):
873          tf_op(
874              data=tf_x,
875              indices=tf_indices,
876              segment_ids=segment_indices,
877              num_segments=num_segments)
878
879  @test_util.run_deprecated_v1
880  def testGradient(self):
881    shape = [10, 4]
882
883    segment_indices = [0, 1, 2, 2]
884    num_indices = len(segment_indices)
885    for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
886      with self.cached_session():
887        tf_indices, _, tf_x, np_x = self._sparse_input(
888            shape, num_indices, dtype=dtypes_lib.float64)
889        s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
890        jacob_t, jacob_n = gradient_checker.compute_gradient(
891            tf_x,
892            shape,
893            s, [3, 4],
894            x_init_value=np_x.astype(np.double),
895            delta=1)
896      self.assertAllClose(jacob_t, jacob_n)
897
898  @test_util.run_deprecated_v1
899  def testGradientWithEmptySegmentsAtEnd(self):
900    shape = [10, 4]
901
902    num_segments = 5
903    segment_indices = [0, 1, 2, 2]
904    num_indices = len(segment_indices)
905    for tf_op in [
906        math_ops.sparse_segment_sum_with_num_segments,
907        math_ops.sparse_segment_mean_with_num_segments,
908    ]:
909      with self.cached_session():
910        tf_indices, _, tf_x, np_x = self._sparse_input(
911            shape, num_indices, dtype=dtypes_lib.float64)
912        s = tf_op(
913            data=tf_x,
914            indices=tf_indices,
915            segment_ids=segment_indices,
916            num_segments=num_segments)
917        jacob_t, jacob_n = gradient_checker.compute_gradient(
918            tf_x,
919            shape,
920            s, [5, 4],
921            x_init_value=np_x.astype(np.double),
922            delta=1)
923      self.assertAllClose(jacob_t, jacob_n)
924
925  def testGradientExplicit(self):
926    # Note that the GPU implem has different paths for different inner sizes.
927    for inner_size in (1, 2, 3, 32):
928      with self.session():
929        tf_ygrad, np_ygrad = self._input([3, inner_size],
930                                         dtype=dtypes_lib.float32)
931        segment_ids = [0, 1, 2, 2, 2]
932        indices = [8, 3, 0, 9, 3]
933        output_dim0 = 10
934        ops_list = [
935            (math_ops.sparse_segment_sum_grad, "sum"),
936            (math_ops.sparse_segment_mean_grad, "mean"),
937            (math_ops.sparse_segment_sqrt_n_grad, "sqrtn"),
938        ]
939        for tf_op, mode in ops_list:
940          np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices,
941                                                   segment_ids, output_dim0,
942                                                   mode)
943          tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
944          self.assertAllClose(tf_xgrad, np_xgrad)
945
946  def testGradientExplicitSingleOutput(self):
947    # The GPU implem has a special case when there is a single output.
948    for inner_size in (1, 2, 3, 32):
949      with self.session():
950        tf_ygrad, np_ygrad = self._input([3, inner_size],
951                                         dtype=dtypes_lib.float32)
952        segment_ids = [0, 1, 2, 2, 2]
953        indices = [0, 0, 0, 0, 0]
954        output_dim0 = 1
955        ops_list = [
956            (math_ops.sparse_segment_sum_grad, "sum"),
957            (math_ops.sparse_segment_mean_grad, "mean"),
958            (math_ops.sparse_segment_sqrt_n_grad, "sqrtn"),
959        ]
960        for tf_op, mode in ops_list:
961          np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices,
962                                                   segment_ids, output_dim0,
963                                                   mode)
964          tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
965          self.assertAllClose(tf_xgrad, np_xgrad)
966
967  def testGradientValid(self):
968    # Baseline for the testGradient*Invalid* methods below.
969    tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
970    ops_list = [
971        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
972        math_ops.sparse_segment_sqrt_n_grad
973    ]
974    segment_indices = [0, 1, 2, 2]
975    tf_indices = [8, 3, 0, 9]
976    with self.session(use_gpu=False):
977      for tf_op in ops_list:
978        s = tf_op(tf_x, tf_indices, segment_indices, 10)
979        self.evaluate(s)
980
981  @test_util.run_deprecated_v1
982  def testGradientIndicesInvalid1(self):
983    tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
984    ops_list = [
985        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
986        math_ops.sparse_segment_sqrt_n_grad
987    ]
988    segment_indices = [0, 1, 2, 2]
989    tf_indices = [8, 3, 0, 10]
990    with self.session(use_gpu=False):
991      for tf_op in ops_list:
992        s = tf_op(tf_x, tf_indices, segment_indices, 10)
993        with self.assertRaisesOpError(r"Index 10 out of range \[0, 10\)"):
994          self.evaluate(s)
995
996  @test_util.run_deprecated_v1
997  def testGradientIndicesInvalid2(self):
998    tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
999    ops_list = [
1000        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
1001        math_ops.sparse_segment_sqrt_n_grad
1002    ]
1003    segment_indices = [0, 1, 2, 2]
1004    tf_indices = [8, 3, -1, 9]
1005    with self.session(use_gpu=False):
1006      for tf_op in ops_list:
1007        s = tf_op(tf_x, tf_indices, segment_indices, 10)
1008        with self.assertRaisesOpError(r"Index -1 out of range \[0, 10\)"):
1009          self.evaluate(s)
1010
1011  @test_util.run_deprecated_v1
1012  def testGradientSegmentsInvalid1(self):
1013    tf_x, _ = self._input(
1014        [3, 4], dtype=dtypes_lib.float32)  # expecting 3 segments
1015    ops_list = [
1016        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
1017        math_ops.sparse_segment_sqrt_n_grad
1018    ]
1019    segment_indices = [0, 1, 1, 4]  # 5 segments
1020    tf_indices = [8, 3, 0, 9]
1021    with self.session(use_gpu=False):
1022      for tf_op in ops_list:
1023        s = tf_op(tf_x, tf_indices, segment_indices, 10)
1024        with self.assertRaisesOpError("Invalid number of segments"):
1025          self.evaluate(s)
1026
1027  @test_util.run_deprecated_v1
1028  def testGradientSegmentsInvalid2(self):
1029    tf_x, _ = self._input([1, 4], dtype=dtypes_lib.float32)
1030    ops_list = [
1031        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
1032        math_ops.sparse_segment_sqrt_n_grad
1033    ]
1034    segment_indices = [0, 1, 2, 0]
1035    tf_indices = [8, 3, 0, 9]
1036    with self.session(use_gpu=False):
1037      for tf_op in ops_list:
1038        s = tf_op(tf_x, tf_indices, segment_indices, 10)
1039        with self.assertRaisesOpError(r"Segment id 1 out of range \[0, 1\)"):
1040          self.evaluate(s)
1041
1042  @test_util.run_deprecated_v1
1043  def testGradientSegmentsInvalid3(self):
1044    tf_x, _ = self._input([2, 4], dtype=dtypes_lib.float32)
1045    ops_list = [
1046        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
1047        math_ops.sparse_segment_sqrt_n_grad
1048    ]
1049    segment_indices = [-1, 0, 1, 1]
1050    tf_indices = [8, 3, 0, 9]
1051    with self.session(use_gpu=False):
1052      for tf_op in ops_list:
1053        s = tf_op(tf_x, tf_indices, segment_indices, 10)
1054        with self.assertRaisesOpError(r"Segment id -1 out of range \[0, 2\)"):
1055          self.evaluate(s)
1056
1057  @test_util.run_deprecated_v1
1058  def testGradientSegmentsInvalid4(self):
1059    tf_x, _ = self._input([0, 4], dtype=dtypes_lib.float32)
1060    ops_list = [
1061        math_ops.sparse_segment_sum_grad, math_ops.sparse_segment_mean_grad,
1062        math_ops.sparse_segment_sqrt_n_grad
1063    ]
1064    segment_indices = [0, 1, 2, -1]
1065    tf_indices = [8, 3, 0, 9]
1066    with self.session(use_gpu=False):
1067      for tf_op in ops_list:
1068        s = tf_op(tf_x, tf_indices, segment_indices, 10)
1069        with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"):
1070          self.evaluate(s)
1071
1072
class SegmentReductionOpBenchmark(test.Benchmark):
  """Benchmarks `segment_sum` and `unsorted_segment_sum` on a CUDA GPU.

  Each benchmark method sweeps the cartesian product of the `*_options` class
  attributes (collected in `options`) and reports one result per combination.
  """
  outer_dim_options = [2**x for x in range(9, 14, 2)]
  ratio_options = [2**x for x in range(1, 6, 2)]
  inner_dim_options = [2**x for x in range(9, 14, 2)]
  # randomly generated sizes with less alignments
  inner_dim_options += [
      1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
  ]
  dtype_options = [np.float32, np.float64]
  options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options)
  # Each functor receives (values variable, segment-id variable, numpy segment
  # ids) and returns a `(name, op)` pair.
  # pylint: disable=g-long-lambda
  op_functors = [lambda vc, vs, seg_ids:
                 ("sorted", math_ops.segment_sum(vc, vs)),
                 lambda vc, vs, seg_ids:
                 ("unsorted",
                  math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
  # pylint: enable=g-long-lambda
  repeat = 10  # Passed to run_op_benchmark as `min_iters`.

  def _npTypeToStr(self, t):
    """Returns a short label for numpy type `t`, used in benchmark names.

    Args:
      t: A numpy scalar type such as `np.float32`.

    Returns:
      "fp32" or "fp64" for the two float types; for any other type the numpy
      dtype name, so the benchmark name never silently contains "None".
    """
    if t == np.float32:
      return "fp32"
    if t == np.float64:
      return "fp64"
    return np.dtype(t).name

  def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
    """Builds one benchmark configuration in the current graph and times it.

    Args:
      op_functor: One of `op_functors`; builds the op and returns (name, op).
      outer_dim: Number of rows of the random input.
      ratio: `outer_dim / ratio` (truncated) bounds the random segment ids.
      inner_dim: Number of columns of the random input.
      dtype: Numpy type the input values are cast to.

    Returns:
      A `(name, wall_time)` tuple for the benchmarked op.
    """
    output_outer_dim = int(outer_dim / ratio)
    const = np.random.randint(5, size=(outer_dim, inner_dim))
    # Sorted so the same ids are valid input for the sorted segment_sum op.
    seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
    vs = variables.Variable(seg_ids.astype(np.int32))
    with ops.device("/gpu:0"):
      vc = variables.Variable(const.astype(dtype))
    name, op = op_functor(vc, vs, seg_ids)
    with session.Session() as sess:
      # Initialize through the session opened here rather than relying on an
      # implicit default-session lookup.
      sess.run(variables.global_variables_initializer())
      r = self.run_op_benchmark(
          sess,
          op,
          min_iters=self.repeat,
          name="_".join(
              map(str,
                  [name, outer_dim, ratio, inner_dim,
                   self._npTypeToStr(dtype)])))
    return name, r["wall_time"]

  def _runAllConfigs(self, op_functor):
    """Runs `op_functor` for every option combination if a CUDA GPU exists."""
    if not test.is_gpu_available(cuda_only=True):
      return
    for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
      # A fresh graph per configuration keeps the variables from accumulating.
      with ops.Graph().as_default():
        self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)

  def benchmarkSegmentSumGPU(self):
    """Benchmarks the sorted `segment_sum` functor."""
    self._runAllConfigs(self.op_functors[0])

  def benchmarkUnsortedSegmentSumGPU(self):
    """Benchmarks the `unsorted_segment_sum` functor."""
    self._runAllConfigs(self.op_functors[1])
1133
1134
# Run the tests and benchmarks in this file when it is executed as a script.
if __name__ == "__main__":
  test.main()
1137