# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Collective Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


class CollectiveOpTest(test.TestCase):
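  """Tests for collective ops (all-reduce, broadcast, all-gather) on CPU devices."""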

  def setUp(self):
    context._reset_context()  # pylint: disable=protected-access
    super(CollectiveOpTest, self).setUp()

  def _testCollectiveReduce(self,
                            inputs,
                            expected,
                            set_graph_key,
                            communication_hint='auto',
                            fp16=False,
                            instance_key=1,
                            merge_op='Add',
                            final_op='Div',
                            timeout=0,
                            reported_group_size=None):
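    """Runs an all-reduce with one CPU device per entry of `inputs`.

    Every device's reduced output is compared against `expected` within an
    fp16/fp32 tolerance.
    """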
    group_key = 1
    group_size = len(inputs)
    if reported_group_size is None:
      reported_group_size = group_size
    device_type = 'CPU'
    config = config_pb2.ConfigProto(device_count={device_type: group_size})
    devices = ['/{}:{}'.format(device_type, i) for i in range(group_size)]

    with self.session(config=config) as sess:
      colred = []
      for i in range(group_size):
        with ops.device(devices[i]):
          tensor = constant_op.constant(inputs[i], dtype=(
              dtypes.float16 if fp16 else dtypes.float32))
          colred.append(
              collective_ops.all_reduce(
                  tensor,
                  reported_group_size,
                  group_key,
                  instance_key,
                  merge_op,
                  final_op,
                  communication_hint=communication_hint,
                  timeout=timeout))
      run_options = config_pb2.RunOptions()
      if set_graph_key:
        run_options.experimental.collective_graph_key = 1
      results = sess.run(colred, options=run_options)
    tolerance = 1e-3 if fp16 else 1e-5
    for i in range(group_size):
      logging.info('i {} result {} expected {}'.format(i, results[i], expected))
      self.assertAllClose(results[i], expected, rtol=tolerance, atol=tolerance)

  def _testMultipleConcurrentCollectiveReduce(self, t0, t1, expected):
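    """Launches two concurrent all-reduce instances on a two-device group.

    `t0` is placed on CPU:0 and `t1` on CPU:1; every result is checked
    against `expected`.
    """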
    group_key = 1
    group_size = 2
    num_instances = 2
    all_reduces = []
    config = config_pb2.ConfigProto(device_count={'CPU': group_size})
    config.experimental.collective_deterministic_sequential_execution = True
    with self.session(config=config) as sess:
      for cpu in range(group_size):
        with ops.device('/CPU:%d' % cpu):
          in_tensor = constant_op.constant(t0 if cpu == 0 else t1)
          for instance in range(num_instances):
            all_reduces.append(collective_ops.all_reduce(
                in_tensor, group_size, group_key, instance, 'Add', 'Div'))
      results = sess.run(all_reduces)
    for i in range(group_size * num_instances):
      self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5)

  def testCollectiveReduce(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                  [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
          expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
          set_graph_key=True)

  def testCollectiveAutoGraphKey(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                  [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
          expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
          set_graph_key=False)

  def testFp16Reduce(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                  [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
          expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
          set_graph_key=True,
          fp16=True)

  def testCollectiveMultipleConcurrentReduce(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testMultipleConcurrentCollectiveReduce(
          [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
          [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
          [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])

  def testCollectiveTimeoutV1(self):
    timeout = 4.5
    kwargs = dict(
        inputs=[[i + j + 0.1 for i in range(8)] for j in range(3)],
        expected=[1 + i + 0.1 for i in range(8)],
        set_graph_key=True,
        timeout=timeout)

    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(**kwargs)

    start_time = time.time()
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          errors.DeadlineExceededError,
          'Collective has timed out waiting for other workers'):
        self._testCollectiveReduce(
            reported_group_size=len(kwargs['inputs']) + 1, **kwargs)
    elapsed = time.time() - start_time
    self.assertAllGreaterEqual(elapsed, timeout)

  def testNcclHintFallbackToRingReduce(self):
    """Tests that setting `communication_hint=nccl` works on non-GPU builds."""
    if kernels.get_registered_kernels_for_op('NcclAllReduce'):
      self.skipTest('Run only on non-GPU environments')
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                  [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
          expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
          set_graph_key=False,
          communication_hint='nccl')

  def _testWhile(self, num_vars, num_iterations, key_base):
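    """Runs `num_vars` all-reduces inside a while loop on two CPU devices.

    Both devices start from the same powers-of-two values, so every
    'Add'/'Id' all-reduce doubles each value; after `num_iterations`
    iterations each value equals its starting value times 2**num_iterations.
    """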
    group_size = 2
    group_key = 1
    instances = [(key_base + i) for i in range(num_vars)]
    devices = ['CPU:{}'.format(i) for i in range(group_size)]

    config = config_pb2.ConfigProto(device_count={'CPU': group_size})
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce')

    with self.session(config=config) as sess:
      loop_vars = []
      for device in devices:
        with ops.device(device):
          loop_vars.append(
              [variables.VariableV1((1 << i) * 1.) for i in range(num_vars)])
      # This variable controls the number of iterations.
      loop_vars.append(variables.VariableV1(0.))
      def loop_body(dev0_tensors, dev1_tensors, loop_tensor):
        return_ops = []
        for i in range(len(devices)):
          device = devices[i]
          device_tensors = dev0_tensors if i == 0 else dev1_tensors
          with ops.device(device):
            device_collectives = []
            for j in range(num_vars):
              # NOTE(ayushd): we need the `cast` here to ensure that the input
              # to `all_reduce` has an explicit device string.  We don't use
              # `identity` because `cast` is more resilient to getting optimized
              # away by various optimization passes.
              input_tensor = math_ops.cast(device_tensors[j], dtypes.float16)
              collective_op = collective_ops.all_reduce(
                  input_tensor, group_size, group_key, instances[j],
                  'Add', 'Id')
              output_tensor = math_ops.cast(collective_op, dtypes.float32)
              device_collectives.append(output_tensor)
            return_ops.append(device_collectives)
        return_ops.append(math_ops.add(loop_tensor, 1.))
        return return_ops
      # Run until the loop counter (the last loop variable) reaches num_iterations.
      loop_cond = lambda d0, d1, i: math_ops.less(i, num_iterations)
      sess.run(variables.global_variables_initializer())
      results = sess.run(control_flow_ops.while_loop(loop_cond, loop_body,
                                                     loop_vars))
      self.assertEqual(results[:-1], [
          [((1 << (num_iterations + v)) * 1.) for v in range(num_vars)]
          for _ in range(group_size)])

  def testSimpleWhile(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testWhile(num_vars=1, num_iterations=4, key_base=20)

  def testWhileMultipleAllReduce(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testWhile(num_vars=2, num_iterations=4, key_base=20)

  def testWhileWithScopedAllocator(self):
    group_size = 2
    group_key = 1
    instance_key0 = 1
    instance_key1 = 2

    config = config_pb2.ConfigProto(device_count={'CPU': group_size})
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce')

    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      with self.session(config=config) as sess:
        run_ops = []
        for i in range(group_size):
          with ops.device('CPU:%d' % i):
            constant = constant_op.constant(0.)
            cond = lambda i: math_ops.less(i, 10.)
            body = lambda i: math_ops.add(i, 1.)
            input0 = control_flow_ops.while_loop(cond, body, [constant])
            input1 = math_ops.add(constant, 5)
            colred0 = collective_ops.all_reduce(input0, group_size, group_key,
                                                instance_key0, 'Add', 'Id')
            colred1 = collective_ops.all_reduce(input1, group_size, group_key,
                                                instance_key1, 'Add', 'Id')
            run_ops.append(math_ops.add_n([colred0, colred1]))
        results = sess.run(run_ops)
      self.assertEqual(results, [30., 30.])

  def testCollectiveReduceScalar(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(inputs=[0.1, 0.3], expected=0.2,
                                 set_graph_key=True)

  def testCollectiveReduceMaximum(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]],
          expected=[10., 20., 30., 40., 50.],
          set_graph_key=True,
          instance_key=30,
          merge_op='Max',
          final_op='Id')

  def testCollectiveReduceMinimum(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveReduce(
          inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]],
          expected=[1., 2., 3., 4., 5.],
          set_graph_key=True,
          instance_key=40,
          merge_op='Min',
          final_op='Id')

  def _testCollectiveBroadcast(self, in_val):
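    """Broadcasts `in_val` from CPU:0 to CPU:1 and checks both outputs."""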
    group_key = 1
    instance_key = 1
    with self.session(
        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
      with ops.device('/CPU:0'):
        in0 = constant_op.constant(in_val)
        out0 = collective_ops.broadcast_send(in0, in0.shape, in0.dtype,
                                             2, group_key, instance_key)
      with ops.device('/CPU:1'):
        c1 = constant_op.constant(in_val)
        out1 = collective_ops.broadcast_recv(c1.shape, c1.dtype,
                                             2, group_key, instance_key)
      run_options = config_pb2.RunOptions()
      run_options.experimental.collective_graph_key = 1
      results = sess.run([out0, out1], options=run_options)
    self.assertAllClose(results[0], in_val, rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[1], in_val, rtol=1e-5, atol=1e-5)

  def testCollectiveBroadcast(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1])

  def testCollectiveBroadcastBool(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveBroadcast([True, False])

  def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
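    """All-gathers `t0` (CPU:0) and `t1` (CPU:1) and checks both results against `expected`."""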
    group_key = 1
    instance_key = 1
    with self.session(
        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
      with ops.device('/CPU:0'):
        in0 = constant_op.constant(t0)
        c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
      with ops.device('/CPU:1'):
        in1 = constant_op.constant(t1)
        c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
      run_options = config_pb2.RunOptions()
      if set_graph_key:
        run_options.experimental.collective_graph_key = 1
      results = sess.run([c0, c1], options=run_options)
    self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)

  def testCollectiveGather(self):
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      self._testCollectiveGather([0, 1, 2, 3, 4, 5, 6, 7],
                                 [10, 11, 12, 13, 14, 15, 16, 17],
                                 [0, 1, 2, 3, 4, 5, 6, 7,
                                  10, 11, 12, 13, 14, 15, 16, 17],
                                 True)
      self._testCollectiveGather([[0, 1, 2, 3], [4, 5, 6, 7]],
                                 [[10, 11, 12, 13], [14, 15, 16, 17]],
                                 [[0, 1, 2, 3], [4, 5, 6, 7],
                                  [10, 11, 12, 13], [14, 15, 16, 17]],
                                 True)
      self._testCollectiveGather([[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
                                 [[[10, 11], [12, 13]], [[14, 15], [16, 17]]],
                                 [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
                                  [[10, 11], [12, 13]], [[14, 15], [16, 17]]],
                                 True)

  def testCollectiveGatherShapeMismatch(self):
    group_key = 1
    instance_key = 1
    t0 = [1, 2, 3, 4]
    t1 = [5, 6, 7, 8]
    t2 = [9, 10]
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      with self.session(
          config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
        with ops.device('/CPU:0'):
          in0 = constant_op.constant(t0)
          c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
        with ops.device('/CPU:1'):
          in1 = constant_op.constant(t1)
          in2 = constant_op.constant(t2)
          c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
          c2 = collective_ops.all_gather(in2, 2, group_key, instance_key)
        run_options = config_pb2.RunOptions()
        run_options.experimental.collective_graph_key = 1
        sess.run([c0, c1], options=run_options)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'Shape mismatch'):
          sess.run([c0, c2], options=run_options)

  def testCollectiveGatherShapeMismatchAcrossDevices(self):
    group_key = 1
    instance_key = 1
    t0 = [1, 2, 3, 4]
    t1 = [5, 6]
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      with self.session(
          config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
        with ops.device('/CPU:0'):
          in0 = constant_op.constant(t0)
          c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
        with ops.device('/CPU:1'):
          in1 = constant_op.constant(t1)
          c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
        run_options = config_pb2.RunOptions()
        run_options.experimental.collective_graph_key = 1
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'Shape mismatch'):
          sess.run([c0, c1], options=run_options)

  def testCollectiveGatherPolymorphicShape(self):
    t0 = [0, 1, 2, 3, 4, 5, 6, 7]
    t1 = [10, 11, 12, 13, 14, 15, 16, 17]
    group_size = 2
    group_key = 1
    instance_key = 123
    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      with self.session(
          config=config_pb2.ConfigProto(
              device_count={'CPU': group_size})) as sess:
        with ops.device('/CPU:0'):
          in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
          c0 = collective_ops.all_gather(in0, group_size, group_key,
                                         instance_key)
        with ops.device('/CPU:1'):
          in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
          c1 = collective_ops.all_gather(in1, group_size, group_key,
                                         instance_key)

        results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1})
        results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]})

    expected_output = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]
    self.assertAllClose(results[0], expected_output, rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[1], expected_output, rtol=1e-5, atol=1e-5)

    expected_output_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]
    self.assertAllClose(results_[0], expected_output_, rtol=1e-5, atol=1e-5)
    self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)

  @test_util.run_v2_only
  @test_util.disable_tfrt(
      'b/177270918: TFRT has a deadlock when executing collective ops.')
  def testCollectiveGroupSizeMismatch(self):
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()

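    # The two devices deliberately report different group sizes (2 vs. 3) for
    # the same group key, so the collective is expected to fail.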
    @def_function.function
    def run_all_reduce():
      group_key = 10
      instance_key = 20
      t0 = [1, 2, 3, 4]
      t1 = [5, 6, 7, 8]
      with ops.device('/CPU:0'):
        in0 = constant_op.constant(t0)
        c0 = collective_ops.all_reduce(
            in0, group_size=2, group_key=group_key, instance_key=instance_key,
            merge_op='Add', final_op='Id')
      with ops.device('/CPU:1'):
        in1 = constant_op.constant(t1)
        c1 = collective_ops.all_reduce(
            in1, group_size=3, group_key=group_key, instance_key=instance_key,
            merge_op='Add', final_op='Id')
      return c0, c1

    with self.assertRaisesRegex(errors.InternalError,
                                'but that group has size'):
      run_all_reduce()

  @test_util.run_v2_only
  def testCollectiveTensorsHaveNoDeviceSpecified(self):
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()

    group_size = 2
    group_key = 1
    instance_key = 1

    @def_function.function
    def fn(all_args):
      results = []
      # The inputs have no devices set. This is expected to be a trace-time
      # check only.
      self.assertEqual(all_args[0].device, '')
      self.assertEqual(all_args[1].device, '')

      with ops.device('/CPU:0'):
        results.append(
            collective_ops.all_reduce(all_args[0], group_size, group_key,
                                      instance_key, 'Add', 'Div'))
      with ops.device('/CPU:1'):
        results.append(
            collective_ops.all_reduce(all_args[1], group_size, group_key,
                                      instance_key, 'Add', 'Div'))

      return results

    with ops.device('/CPU:0'):
      in0 = constant_op.constant(1)
    with ops.device('/CPU:1'):
      in1 = constant_op.constant(3)
    result = fn([in0, in1])
    self.assertAllClose(result, [2, 2])

  def testConstantWithScopedAllocator(self):
    group_size = 2
    group_key = 1
    instance_key1 = 1
    instance_key2 = 2

    graph_options = config_pb2.GraphOptions(
        optimizer_options=config_pb2.OptimizerOptions(do_constant_folding=True))
    cfg = config_pb2.ConfigProto(device_count={'CPU': group_size},
                                 graph_options=graph_options)
    rewrite_options = cfg.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce')

    # Tests that execute collectives need to be enclosed in a graph or tf.function
    with ops.Graph().as_default():
      with self.session(config=cfg) as sess:
        run_ops = []
        for i in range(group_size):
          with ops.device('CPU:%d' % i):
            constant = constant_op.constant(i + 1.)
            input_tensor1 = array_ops.identity(constant)
            input_tensor2 = array_ops.identity(constant)
            reduced_tensor1 = collective_ops.all_reduce(
                input_tensor1, group_size, group_key, instance_key1, 'Add',
                'Id')
            reduced_tensor2 = collective_ops.all_reduce(
                input_tensor2, group_size, group_key, instance_key2, 'Add',
                'Id')
            run_ops.append(array_ops.identity(reduced_tensor1))
            run_ops.append(array_ops.identity(reduced_tensor2))
        results = sess.run(run_ops)
    self.assertEqual(results, [3., 3., 3., 3.])


if __name__ == '__main__':
  test.main()