# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CrossDeviceOps."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _get_devices(devices):
  if isinstance(devices, (tuple, list)):
    return tuple(device_util.resolve(d) for d in devices)
  elif isinstance(devices, value_lib.DistributedValues):
    return devices.devices
  return (device_util.resolve(devices),)


def _make_per_replica(values, devices, regroup=False):
  devices = _get_devices(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup() called on a PerReplica, which strips
  # the PerReplica wrapper when it contains only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = []
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index.append(placed_v)
  return value_lib.PerReplica(value_lib.ReplicaDeviceMap(devices), index)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def _fake_mirrored(value, devices):
73  """Create a faked Mirrored object for testing.
74
75  All components of the returned Mirrored have the same objects, which is not
76  true in reality.
77  """
78  devices = _get_devices(devices)
79  return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices),
80                            [value] * len(devices))
81
82
83def _make_indexed_slices(values, indices, dense_shape, device):
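  """Create an IndexedSlices with constant components placed on `device`."""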
  with ops.device(device):
    tensor = ops.IndexedSlices(
        values=constant_op.constant(values),
        indices=constant_op.constant(indices),
        dense_shape=constant_op.constant(dense_shape))
  return tensor


def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  values = [_make_indexed_slices(values, indices, dense_shape, d)
            for d in devices]
  return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), values)


_cpu_device = "/device:CPU:0"


class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
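  """Base class with common assertions for CrossDeviceOps tests."""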

  def _assert_indexed_slices_equal(self, left, right):
    self.assertIsInstance(left, ops.IndexedSlices)
    self.assertIsInstance(right, ops.IndexedSlices)
    self.assertEqual(device_util.resolve(left.device),
                     device_util.resolve(right.device))
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  def _assert_values_equal(self, left, right):
    if isinstance(left, list):
      for l, r in zip(left, right):
        self._assert_values_equal(l, r)
    else:
      self.assertEqual(type(left), type(right))
      self.assertEqual(set(left.devices), set(right.devices))
      if isinstance(left.values[0], ops.IndexedSlices):
        for d in left.devices:
          self._assert_indexed_slices_equal(left.get(d), right.get(d))
      elif context.executing_eagerly():
        self.assertEqual([v.numpy() for v in left.values],
                         list(right.values))
      else:
        with self.cached_session() as sess:
          self.assertEqual(
              sess.run(list(left.values)), list(right.values))

  def _testReductionAndBroadcast(self, cross_device_ops, distribution):
    devices = distribution.extended.worker_devices

    values = [constant_op.constant(float(d)) for d in range(len(devices))]
    per_replica = _make_per_replica(values, devices)
    mean = (len(devices) - 1.) / 2.

    values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
    per_replica_2 = _make_per_replica(values_2, devices)
    mean_2 = mean + 1.

    destination_mirrored = _fake_mirrored(1., devices)
    destination_different = _fake_mirrored(1., _cpu_device)
    destination_str = _cpu_device

    all_destinations = [
        destination_mirrored, destination_different, destination_str,
    ]

    # test reduce()
    for destinations in all_destinations:
      self._assert_values_equal(
          cross_device_ops.reduce(
              reduce_util.ReduceOp.MEAN,
              per_replica,
              destinations=destinations),
          _fake_mirrored(mean, destinations))
      self._assert_values_equal(
          cross_device_ops.reduce(
              reduce_util.ReduceOp.MEAN,
              per_replica_2,
              destinations=destinations),
          _fake_mirrored(mean_2, destinations))
      self._assert_values_equal(
          cross_device_ops.reduce(
              reduce_util.ReduceOp.SUM, per_replica,
              destinations=destinations),
          _fake_mirrored(mean * len(devices), destinations))
      self._assert_values_equal(
          cross_device_ops.reduce(
              reduce_util.ReduceOp.SUM,
              per_replica_2,
              destinations=destinations),
          _fake_mirrored(mean_2 * len(devices), destinations))

    # test batch_reduce()
    for d1, d2 in itertools.product(all_destinations, all_destinations):
      self._assert_values_equal(
          cross_device_ops.batch_reduce(
              reduce_util.ReduceOp.MEAN,
              [(per_replica, d1), (per_replica_2, d2)]),
          [
              _fake_mirrored(mean, d1),
              _fake_mirrored(mean_2, d2)
          ])
      self._assert_values_equal(
          cross_device_ops.batch_reduce(
              reduce_util.ReduceOp.SUM,
              [(per_replica, d1), (per_replica_2, d2)]),
          [
              _fake_mirrored(mean * len(devices), d1),
              _fake_mirrored(mean_2 * len(devices), d2)
          ])

    # test broadcast()
    for destinations in all_destinations:
      self._assert_values_equal(
          cross_device_ops.broadcast(constant_op.constant(1.), destinations),
          _fake_mirrored(1., destinations))

class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
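  """Tests cross-device ops with devices on a single worker."""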
  # TODO(yuefengz): decouple the num_gpus check from the distribution in the
  # combinations module so that we can pass in devices instead of a
  # distribution strategy.
  reduction_to_one_combinations = combinations.combine(
      cross_device_ops=[
          combinations.NamedObject(
              "DefaultReductionToOneDevice",
              cross_device_ops_lib.ReductionToOneDevice()),
          combinations.NamedObject(
              "ReductionToCPUDeviceCrossDeviceOps",
              cross_device_ops_lib.ReductionToOneDevice(
                  reduce_to_device=_cpu_device)),
          combinations.NamedObject(
              "AccumulateNCrossDeviceOp",
              cross_device_ops_lib.ReductionToOneDevice(
                  accumulation_fn=math_ops.accumulate_n)),
      ],
      distribution=[
          combinations.one_device_strategy,
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])
  allreduce_combinations = combinations.combine(
      cross_device_ops=[
          combinations.NamedObject(
              "AllReduce",
              cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)),
          combinations.NamedObject(
              "AllReduceNoGradientRepacking",
              cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)),
          combinations.NamedObject("NcclAllReduce",
                                   cross_device_ops_lib.NcclAllReduce()),
          combinations.NamedObject(
              "HierarchicalCopy",
              cross_device_ops_lib.HierarchicalCopyAllReduce(8)),
          combinations.NamedObject(
              "HierarchicalCopyAggregateSmallTensors",
              cross_device_ops_lib.AllReduceCrossDeviceOps(
                  "hierarchical_copy", 0, 100, 10))
      ],
      distribution=[
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])

  @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
  def testReductionAndBroadcast(self, cross_device_ops, distribution):
    with distribution.scope():
      self._testReductionAndBroadcast(cross_device_ops, distribution)

  def testChooseAlgorithm(self):
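    # Each adjacency list gives the peer GPUs each GPU is directly linked to.
    # This first topology is DGX-1-like, so hierarchical_copy is expected.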
    device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                    [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
    result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
    self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
    self.assertEqual(result._num_packs, 8)

    # if there are only 4 devices
    device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
    result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
    self.assertEqual(result._all_reduce_alg, "nccl")
    self.assertEqual(result._num_packs, 1)

    # if the device links contain each device itself
    device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
                    [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                    [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
    result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
    self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
    self.assertEqual(result._num_packs, 8)

    # if the links are not DGX-1-like
    device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                    [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
    result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
    self.assertEqual(result._all_reduce_alg, "nccl")
    self.assertEqual(result._num_packs, 1)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testSimpleReduceWithIndexedSlices(self):
    devices = ["/cpu:0", "/gpu:0"]
    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
    per_replica = value_lib.PerReplica(
        value_lib.ReplicaDeviceMap(devices), (t0, t1))
    result = cross_device_ops_lib._simple_reduce(
        per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices with and without duplicate indices.
    total_with_dups = _make_indexed_slices(
        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
    total_without_dups = _make_indexed_slices(
        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
    self._assert_indexed_slices_equal(total_with_dups, result)
    self._assert_indexed_slices_equal(total_without_dups, result)

  @combinations.generate(
      combinations.combine(
          cross_device_ops_instance=[
              combinations.NamedObject(
                  "ReductionToOneDevice",
                  cross_device_ops_lib.ReductionToOneDevice()),
              combinations.NamedObject(
                  "AllReduceCrossDeviceOps",
                  cross_device_ops_lib.AllReduceCrossDeviceOps())
          ],
          reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN],
          batch_reduce=[True, False],
          mode=["graph", "eager"],
          required_gpus=1))
  def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op,
                                 batch_reduce):
    devices = ["/cpu:0", "/gpu:0"]
    dense_shape = [5, 2]
    t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
    t1 = _make_indexed_slices(
        [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
    per_replica = value_lib.PerReplica(
        value_lib.ReplicaDeviceMap(devices), (t0, t1))

    if batch_reduce:
      result = cross_device_ops_instance.batch_reduce(
          reduce_op, [(per_replica, per_replica)])
    else:
      result = cross_device_ops_instance.reduce(
          reduce_op, per_replica, per_replica)

    total_indices_with_dups = [1, 1, 3]
    total_indices_without_dups = [1, 3]

    if reduce_op == reduce_util.ReduceOp.SUM:
      total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
      total_values_without_dups = [[4., 6.], [5., 6.]]
    else:
      assert reduce_op == reduce_util.ReduceOp.MEAN
      total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
      total_values_without_dups = [[2., 3.], [2.5, 3.]]

    total_mirrored_with_dups = _make_mirrored_indexed_slices(
        devices, total_values_with_dups, total_indices_with_dups, dense_shape)
    total_mirrored_without_dups = _make_mirrored_indexed_slices(
        devices, total_values_without_dups, total_indices_without_dups,
        dense_shape)

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices and the version with duplicate indices summed up.
    if batch_reduce:
      total_mirrored_with_dups = [total_mirrored_with_dups]
      total_mirrored_without_dups = [total_mirrored_without_dups]

    self._assert_values_equal(total_mirrored_with_dups, result)
    self._assert_values_equal(total_mirrored_without_dups, result)


class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
                                    CrossDeviceOpsTestBase):
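  """Tests cross-device ops with devices spread across multiple workers."""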

  worker_devices = [
      "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
  ]
  multi_worker_allreduce_combinations = combinations.combine(
      cross_device_ops=[
          combinations.NamedObject(
              "MultiWorkerAllReduce",
              cross_device_ops_lib.MultiWorkerAllReduce(
                  worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)),
          combinations.NamedObject(
              "MultiWorkerAllReducePack",
              cross_device_ops_lib.MultiWorkerAllReduce(
                  worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)),
          combinations.NamedObject(
              "MultiWorkerAllReduceAggregation",
              cross_device_ops_lib.MultiWorkerAllReduce(
                  worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)),
          combinations.NamedObject(
              "MultiWorkerAllReduceMultipleSpecs",
              cross_device_ops_lib.MultiWorkerAllReduce(
                  worker_devices, 2, [("pscpu/pscpu", 2, 100),
                                      ("xring", 2, -1)], 0, 0, 0)),
      ],
      distribution=[
          combinations.NamedDistribution(
              "MirroredCPU",
              lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0),
              required_gpus=0),
          combinations.NamedDistribution(
              "Mirrored1GPU",
              lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1),
              required_gpus=1),
          combinations.NamedDistribution(
              "Mirrored2GPUs",
              lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2),
              required_gpus=2),
          # pylint: disable=g-long-lambda
          combinations.NamedDistribution(
              "CoreMirroredCPU",
              lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]),
              required_gpus=0),
          combinations.NamedDistribution(
              "CoreMirrored1GPU",
              lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]),
              required_gpus=1),
          combinations.NamedDistribution(
              "CoreMirrored2GPUs",
              lambda: mirrored_strategy.CoreMirroredStrategy(
                  ["/device:GPU:0", "/device:GPU:1"]),
              required_gpus=2),
      ],
      mode=["graph"])

  @combinations.generate(multi_worker_allreduce_combinations)
  def testReductionAndBroadcast(self, cross_device_ops, distribution):
    distribution.configure(cluster_spec={
        "worker":
            ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
    })
    with distribution.scope():
      self._testReductionAndBroadcast(cross_device_ops, distribution)


NUM_WORKERS = 3


class MultiWorkerCollectiveAllReduceTest(
    multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
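  """Tests CollectiveAllReduce in multi-worker and local configurations."""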

  collective_key_base = 100000

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 3 workers."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=NUM_WORKERS, num_ps=0)

  def setUp(self):
    super(MultiWorkerCollectiveAllReduceTest, self).setUp()
    # Reusing collective keys is not well supported, so give each test a
    # different collective key base.
    MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000

  def _get_test_objects(self,
                        task_type,
                        task_id,
                        num_gpus=0,
                        use_strategy_object=False,
                        local_mode=False):
    collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=10 * num_gpus +
        MultiWorkerCollectiveAllReduceTest.collective_key_base,
        instance_key_start=num_gpus * 100 +
        MultiWorkerCollectiveAllReduceTest.collective_key_base,
        instance_key_with_id_start=num_gpus * 10000 +
        MultiWorkerCollectiveAllReduceTest.collective_key_base)
    if local_mode:
      if num_gpus:
        devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
      else:
        devices = ["/device:CPU:0"]

      if use_strategy_object:
        # We still use the contrib CollectiveAllReduceStrategy because its
        # constructor lets us specify num_gpus.
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            num_gpus_per_worker=num_gpus)
        strategy.extended._collective_keys = collective_keys
        strategy.extended._cross_device_ops._collective_keys = collective_keys
        return strategy, devices, ""
      else:
        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
            1, num_gpus, collective_keys=collective_keys)
        return collective_all_reduce_ops, devices, ""
    else:
      if num_gpus:
        devices = [
            "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i)
            for i in range(num_gpus)
        ]
      else:
        devices = ["/job:%s/task:%d" % (task_type, task_id)]

      if use_strategy_object:
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            num_gpus_per_worker=num_gpus)
        strategy.configure(
            cluster_spec=self._cluster_spec,
            task_type=task_type,
            task_id=task_id)
        strategy.extended._collective_keys = collective_keys
        strategy.extended._cross_device_ops._collective_keys = collective_keys
        return (strategy, devices,
                "grpc://" + self._cluster_spec[task_type][task_id])
      else:
        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
            NUM_WORKERS, num_gpus, collective_keys=collective_keys)
        return (collective_all_reduce_ops, devices,
                "grpc://" + self._cluster_spec[task_type][task_id])

  def _assert_values_equal(self, left, right, sess):
    if isinstance(left, list):
      for l, r in zip(left, right):
        self._assert_values_equal(l, r, sess)
    else:
      self.assertEqual(type(left), type(right))
      self.assertEqual(set(left.devices), set(right.devices))

      run_options = config_pb2.RunOptions()
      run_options.experimental.collective_graph_key = 6

      left_values = np.array(
          sess.run(list(left.values), options=run_options)).flatten()
      right_values = np.array(list(right.values)).flatten()
      self.assertEqual(len(left_values), len(right_values))
      for l, r in zip(left_values, right_values):
        self.assertEqual(l, r)

  def _test_reduction(self,
                      task_type,
                      task_id,
                      num_gpus,
                      use_strategy_object=False,
                      local_mode=False):
    collective_all_reduce, devices, master_target = self._get_test_objects(
        task_type,
        task_id,
        num_gpus,
        use_strategy_object=use_strategy_object,
        local_mode=local_mode)
    if local_mode:
      num_workers = 1
      worker_device = None
    else:
      num_workers = len(self._cluster_spec.get("chief", [])) + len(
          self._cluster_spec.get("worker", []))
      worker_device = "/job:%s/task:%d" % (task_type, task_id)

    def _reduce(test_object, reduce_op, per_replica, destinations):
      if use_strategy_object:
        with test_object.scope():
          # Mimic the behavior of a distribution strategy, which usually
          # strips the wrapper if there is only one value.
          if len(per_replica.values) == 1:
            per_replica = per_replica.values[0]
          return test_object.extended.reduce_to(reduce_op, per_replica,
                                                destinations)
      else:
        return test_object.reduce(reduce_op, per_replica, destinations)

    def _batch_reduce(test_object, reduce_op, value_destination_pairs):
      if use_strategy_object:
        with test_object.scope():
          return test_object.extended.batch_reduce_to(reduce_op,
                                                      value_destination_pairs)
      else:
        return test_object.batch_reduce(reduce_op, value_destination_pairs)

    with ops.Graph().as_default(), \
         ops.device(worker_device), \
         self.cached_session(target=master_target) as sess:
      # Collective ops don't support scalar tensors, so we construct 1-D
      # tensors instead.
      values = [constant_op.constant([float(d)]) for d in range(len(devices))]
      per_replica = _make_per_replica(values, devices)
      mean = np.array([(len(devices) - 1.) / 2.])

      values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
      per_replica_2 = _make_per_replica(values_2, devices)
      mean_2 = np.array([mean[0] + 1.])

      destination_mirrored = _fake_mirrored(1., devices)
      destination_different = _fake_mirrored(1., _cpu_device)
      destination_str = _cpu_device

      all_destinations = [
          destination_different, destination_mirrored, destination_str
      ]

      # test reduce()
      for destinations in all_destinations:
        self._assert_values_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.MEAN,
                per_replica,
                destinations=destinations), _fake_mirrored(mean, destinations),
            sess)
        self._assert_values_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.MEAN,
                per_replica_2,
                destinations=destinations), _fake_mirrored(
                    mean_2, destinations), sess)
        self._assert_values_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.SUM,
                per_replica,
                destinations=destinations),
            _fake_mirrored(mean * len(devices) * num_workers, destinations),
            sess)
        self._assert_values_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.SUM,
                per_replica_2,
                destinations=destinations),
            _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
            sess)

      # test batch_reduce()
      for d1, d2 in itertools.product(all_destinations, all_destinations):
        self._assert_values_equal(
            _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.MEAN,
                          [(per_replica, d1), (per_replica_2, d2)]),
            [_fake_mirrored(mean, d1),
             _fake_mirrored(mean_2, d2)], sess)
        self._assert_values_equal(
            _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.SUM,
                          [(per_replica, d1), (per_replica_2, d2)]),
            [
                _fake_mirrored(mean * len(devices) * num_workers, d1),
                _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
            ], sess)

    return True

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          num_gpus=[0, 1, 2],
          required_gpus=1,
          use_strategy_object=[True, False]))
  def testReductionDistributed(self, num_gpus, use_strategy_object):
    if context.num_gpus() < num_gpus:
      return
    self._run_between_graph_clients(
        self._test_reduction,
        self._cluster_spec,
        num_gpus,
        use_strategy_object=use_strategy_object)

  # Collective ops don't support a strategy with only one device.
  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          num_gpus=[2],
          required_gpus=2,
          use_strategy_object=[True, False]))
  def testReductionLocal(self, num_gpus, use_strategy_object):
    if context.num_gpus() < num_gpus:
      return
    self._test_reduction(
        None,
        None,
        num_gpus,
        use_strategy_object=use_strategy_object,
        local_mode=True)


if __name__ == "__main__":
  test.main()