# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, ops.IndexedSlices,
       ps_values.AggregatingVariable, six.string_types,
       tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations cannot be empty")


def reduce_non_distributed_value(reduce_op,
                                 value,
                                 destinations,
                                 num_replicas_in_graph,
                                 canonicalize_devices=True):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas. We call this as part of assign functions for
    # MirroredVariables and summing up identical values across replicas is not
    # clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(
        value, destinations, canonicalize_devices=canonicalize_devices)


def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected an object that is not a tuple or "
                     "list." % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  elif hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))
  else:
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have a device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor into a PerReplica object in the input list."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  """Validates that `value_destination_pairs` is valid."""
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations, canonicalize_devices=True):
  """Returns a tuple of device strings that `destinations` maps to."""
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  if canonicalize_devices:
    if isinstance(destinations, six.string_types):
      return (device_util.resolve(destinations),)
    return (device_util.resolve(destinations.device),)

  # Let the placer canonicalize and resolve destination devices.
  if isinstance(destinations, six.string_types):
    return (device_util.canonicalize_without_job_and_task(destinations),)
  return (device_util.canonicalize_without_job_and_task(destinations.device),)


def _devices_match(left, right, canonicalize_devices=True):
  return left is right or set(get_devices_from(
      left, canonicalize_devices)) == set(
          get_devices_from(right, canonicalize_devices))


def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
  if not all(
      _devices_match(v, d, canonicalize_devices)
      for v, d in value_destination_pairs):
    return False
  if not all(
      _devices_match(v, value_destination_pairs[0][0], canonicalize_devices)
      for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value,
                     destinations,
                     always_mirrored=False,
                     canonicalize_devices=True):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations, canonicalize_devices)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced


def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenate all values in the DistributedValues input and return."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      gathered = array_ops.concat(all_values, axis)
  return gathered


@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
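
  For illustration, a typical way to plug one of these implementations into a
  strategy looks roughly like the following (the device list is only an
  example; the exact setup depends on your hardware):

  ```
    strategy = tf.distribute.MirroredStrategy(
      devices=["/gpu:0", "/gpu:1"],
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```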
263  """
264
265  def __init__(self):
266    self._canonicalize_devices = True
267    pass
268
269  @property
270  def _num_between_graph_workers(self):
271    # Returns 1 by default, the value may be overridden by sub classes.
272    return 1
273
  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
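
    For illustration, an all-reduce can be expressed by passing the same value
    as both `per_replica_value` and `destinations` (here `per_replica` stands
    for a `tf.distribute.DistributedValues` built elsewhere):

    ```
      reduced = cross_device_ops.reduce(
          tf.distribute.ReduceOp.SUM, per_replica, destinations=per_replica)
    ```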
301    """
302    if options is None:
303      options = collective_util.Options()
304    if not isinstance(per_replica_value, value_lib.DistributedValues):
305      per_replica_value = _make_tensor_into_per_replica(per_replica_value)
306
307    validate_destinations(destinations)
308
309    # Shortcut if `per_replica_value` only contains one value.
310    if self._num_between_graph_workers == 1 and len(
311        per_replica_value.values) == 1 and _devices_match(
312            per_replica_value, destinations, self._canonicalize_devices):
313      with ops.device(per_replica_value.values[0].device):
314        v = array_ops.identity(per_replica_value.values[0])
315      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
316
317    if options is None:
318      options = collective_util.Options()
319    return self.reduce_implementation(reduce_op, per_replica_value,
320                                      destinations, options)
321
  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, ops.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of the `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather_implementation method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
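
    For illustration, a batched all-reduce over two values might look like the
    following (`per_replica_0` and `per_replica_1` are assumed to be
    `tf.distribute.DistributedValues` built elsewhere):

    ```
      reduced_list = cross_device_ops.batch_reduce(
          tf.distribute.ReduceOp.MEAN,
          [(per_replica_0, per_replica_0), (per_replica_1, per_replica_1)])
    ```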
421    """
422    if options is None:
423      options = collective_util.Options()
424    # TODO(yuefengz): if destinations are different, split into several
425    # `_batch_reduce` invocations.
426    if not _validate_value_destination_pairs(value_destination_pairs):
427      # If the first element of each pair is a tensor, we try to turn it into a
428      # PerReplica object.
429      value_destination_pairs = _normalize_value_destination_pairs(
430          value_destination_pairs)
431
432    for _, d in value_destination_pairs:
433      validate_destinations(d)
434
435    # Shortcut all PerReplica objects only contain one value.
436    if self._num_between_graph_workers == 1 and _all_devices_match(
437        value_destination_pairs, self._canonicalize_devices) and len(
438            value_destination_pairs[0][0].values) == 1:
439      return [
440          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
441          for v, _ in value_destination_pairs
442      ]
443
444    if options is None:
445      options = collective_util.Options()
446    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
447                                            options)
448
  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "reduce_implementation method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(
        tensor,
        destinations,
        always_mirrored=True,
        canonicalize_devices=self._canonicalize_devices)

  # ========================== Collective APIs ================================
  #
  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
  # cross-replica context, collective APIs are to be called in replica
  # context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors or `IndexedSlices`. The
    implementation should generally batch the all-reduces when possible.
    `options` can be set to hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: Value to be reduced. A tensor or a nested structure of tensors or
        `IndexedSlices`.
      replica_id: An integer indicating the id of the replica where this
        all_reduce is called. This is the local replica id that ranges from
        0 to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
      the reduced values. The structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")


@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
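
  You can also choose the intermediate device explicitly; for example, to
  accumulate on the host CPU (illustrative only):

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice(
          reduce_to_device="/device:CPU:0"))
  ```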
596  """
597
598  def __init__(self, reduce_to_device=None, accumulation_fn=None):
599    """Initializes with a device to reduce to and a way to accumulate.
600
601    Args:
602      reduce_to_device: the intermediate device to reduce to. If None, reduce
603        to the first device in `destinations` of the `reduce` method.
604      accumulation_fn: a function that does accumulation.  If None,
605        `tf.math.add_n` is used.
606    """
607    self.reduce_to_device = reduce_to_device
608    self.accumulation_fn = accumulation_fn or math_ops.add_n
609    super(ReductionToOneDevice, self).__init__()
610
611  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
612                            options):
613    del options  # Unused.
614    if check_destinations(destinations):
615      devices = get_devices_from(destinations, self._canonicalize_devices)
616    else:
617      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
618    reduce_to_device = self.reduce_to_device or devices[0]
619    logging.log_first_n(
620        logging.INFO,
621        "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
622    reduced = _simple_reduce(per_replica_value, reduce_to_device,
623                             self.accumulation_fn, reduce_op)
624    return self.broadcast(reduced, destinations)
625
626  def _gather_implementation(self, per_replica_value, destinations, axis,
627                             options):
628    del options  # Unused.
629    if check_destinations(destinations):
630      devices = get_devices_from(destinations, self._canonicalize_devices)
631    else:
632      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
633    reduce_to_device = self.reduce_to_device or devices[0]
634    logging.log_first_n(
635        logging.INFO,
636        "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10)
637    gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
638    return self.broadcast(gathered, destinations)
639
640  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
641                                  options):
642    return [
643        self.reduce_implementation(
644            reduce_op, t, destinations=v, options=options)
645        for t, v in value_destination_pairs
646    ]
647
648
def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a list of PerReplica objects.

  Returns:
    a list of lists, where each sublist contains, for its corresponding device,
      the components of the PerReplica objects, each paired with None.
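
    For example, given two PerReplica values `a` and `b`, each with components
    on gpu0 and gpu1, the result is
    [[(a_gpu0, None), (b_gpu0, None)], [(a_gpu1, None), (b_gpu1, None)]].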
666  """
667  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
668  grouped = [[] for _ in range(len(destinations))]
669  for per_replica_value in per_replica_values:
670    # pylint: disable=protected-access
671    for i, v in enumerate(per_replica_value.values):
672      assert per_replica_value._devices == destinations
673      grouped[i].append((v, None))
674  return grouped
675
676
def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes remove static shapes. So if
        # all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
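        # For example, if total_grad_size is 10 and num_splits is 3, the split
        # sizes are [3, 3, 4]: the last pack absorbs the remainder.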
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        # We should remove the need for variables in
        # aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads


def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced


class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
                    "supported. Falling back to gather on one device and "
                    "then broadcast. We're working on a more efficient "
                    "implementation.")
    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
                                          options)


# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps


AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")


@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
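
  Packing small tensors can reduce launch overhead; a sketch (the value 2 is
  only an example and should be tuned for your workload):

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce(num_packs=2))
  ```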
975  """
976
977  def __init__(self, num_packs=1):
978    """Initializes the object.
979
980    Args:
981      num_packs: a non-negative integer. The number of packs to split values
982        into. If zero, no packing will be done.
983
984    Raises:
985      ValueError: if `num_packs` is negative.
986    """
987    if num_packs < 0:
988      raise ValueError(
989          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
990              num_packs))
991    super(NcclAllReduce, self).__init__(
992        all_reduce_alg="nccl", num_packs=num_packs)
993
994
@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This is a reduction designed for the Nvidia DGX-1, and it assumes GPUs are
  connected as they are on a DGX-1 machine. If you have a different GPU
  interconnect, it is likely to be slower than
  `tf.distribute.ReductionToOneDevice`.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `HierarchicalCopyAllReduce` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "HierarchicalCopy requires num_packs >= 0, but {} is specified"
          .format(num_packs))
    super(HierarchicalCopyAllReduce, self).__init__(
        all_reduce_alg="hierarchical_copy",
        num_packs=num_packs)


# TODO(crccw): remove after migrating all callers.
CollectiveCommunication = collective_util.CommunicationImplementation
CommunicationImplementation = collective_util.CommunicationImplementation


# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self,
               devices,
               group_size,
               collective_keys=None,
               canonicalize_devices=True):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKeys object.
      canonicalize_devices: Whether to canonicalize devices for workers or not.
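
    For illustration, a single worker with two local GPUs might construct this
    as follows (the device names and group size are only an example):

    ```
      collective_ops = CollectiveAllReduce(
          devices=["/gpu:0", "/gpu:1"], group_size=2)
    ```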
1064    """
1065    if group_size % len(devices) > 0:
1066      raise ValueError("group_size must be divisible by the number of devices.")
1067
1068    self._group_size = group_size
1069    self._collective_keys = (collective_keys or
1070                             cross_device_utils.CollectiveKeys())
1071    # This lock guards all collective launches, i.e. calls to
1072    # cross_device_utils.build_collectve_*.
1073    #
1074    # In a multi threaded eager program we need to ensure different groups of
1075    # collectives don't interleave each other, otherwise there could be
1076    # deadlocks. E.g. if two user threads both are launching collectives:
1077    #   user-thread-0  device0                 device1
1078    #   user-thread-1          device0 device1
1079    # In eager mode, we use one thread per device to launch collective ops, so
1080    # the above launch sequences end up with the following queues:
1081    #   device-0  collective-0  collective-1
1082    #   device-1  collective-1  collective-0
1083    # This deadlocks since neither collective is able to finish.
1084    self._lock = threading.Lock()
1085
1086    if canonicalize_devices:
1087      self._devices = tuple(device_util.canonicalize(d) for d in devices)
1088    else:
1089      self._devices = tuple(
1090          device_util.canonicalize_without_job_and_task(d) for d in devices)
1091    group_key = self._collective_keys.get_group_key(self._devices)
1092    self._launchers = []
1093    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
1094    # This is because of the lack of mechanism to order NCCL operations
1095    # deterministically.
1096    self._limited_nccl = False
1097    for device in self._devices:
1098      launcher = cross_device_utils.CollectiveReplicaLauncher(
1099          group_key, group_size, self._collective_keys, device)
1100      self._launchers.append(launcher)
1101      if not launcher.can_order_nccl():
1102        self._limited_nccl = True
1103
1104    self._pool = multiprocessing.pool.ThreadPool(len(self._devices))
1105
1106    super(CollectiveAllReduce, self).__init__()
1107    self._canonicalize_devices = canonicalize_devices
1108
  @property
  def _num_between_graph_workers(self):
    # Currently we only support equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce."""
    # TODO(b/122840926): reuse this method in _batch_all_reduce.
    flat_values = nest.flatten(value)

    implementation = options.implementation.value
    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
    # NCCL launches are always ordered.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(flat_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    launcher = self._launchers[replica_id]
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(flat_values))
    dense_results = []
    sparse_results = []

    if dense_values:
      # Reverse the lists so that there's a better chance that values follow
      # the order in which they are calculated (e.g. when they're gradients), so
      # as to overlap calculation with communication. However, this may not be
      # optimal for cases like gradients of complicated non-sequential models.
      #
      # Note that we reverse the list before packing so that the first pack
      # won't be too small, since it's more likely for first few packs to have
      # long queuing time due to concurrent intense computation.
      #
      # TODO(b/147393503): explore solutions for optimal ordering.
      dense_values.reverse()
      packs = cross_device_utils.group_by_size(dense_values,
                                               options.bytes_per_pack)

      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
            "group_size = %d, implementation = %s, num_packs = %d",
            len(dense_values), len(self._launchers), self._group_size,
            implementation, len(packs))

      dense_results = launcher.batch_all_reduce(packs, implementation,
                                                options.timeout_seconds)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(dense_results):
          with ops.device(self._devices[replica_id]):
            dense_results[i] = v / self._group_size
      dense_results.reverse()

    if sparse_values:
      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce IndexedSlices: %d all_reduces, num_devices "
            "= %d, group_size = %d, implementation = %s", len(sparse_values),
            len(self._launchers), self._group_size, implementation)

      for indexed_slice in sparse_values:
        sparse_results.append(
            launcher.all_reduce_indexed_slices(indexed_slice, implementation,
                                               options.timeout_seconds))

      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(sparse_results):
          with ops.device(self._devices[replica_id]):
            sparse_results[i] = ops.IndexedSlices(
                values=sparse_results[i].values / self._group_size,
                indices=sparse_results[i].indices,
                dense_shape=sparse_results[i].dense_shape)

    flat_results = cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
    return nest.pack_sequence_as(value, flat_results)

  def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
                                     options):
    """All reduce a list of per_replica_value."""
    values_by_device = [[] for _ in self._devices]
    num_devices = len(self._devices)
    for per_replica in per_replica_values:
      for i in range(num_devices):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():

      def thread_fn(device_id):
        with context.eager_mode():
          return self._all_reduce(reduce_op, values_by_device[device_id],
                                  device_id, options)

      with self._lock:
        outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(num_devices):
          outputs_by_device.append(
              self._all_reduce(reduce_op, values_by_device[i], i, options))

    result = []
    for values in zip(*outputs_by_device):
      result.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    return result

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._all_reduce_per_replica_values(reduce_op,
                                                      [per_replica_value],
                                                      options)[0]
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs,
                                           self._canonicalize_devices)
    if all_devices_match:
      return self._all_reduce_per_replica_values(
          reduce_op, [v[0] for v in value_destination_pairs], options)
    else:
      logging.log_first_n(
          logging.WARN, "Efficient batch_reduce is not supported if "
          "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # Fall back to the primary component if no component lives on `d`.
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gathers multiple per-replica values."""
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s" %
        (batch_size, len(self._devices), self._group_size, implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, collective_keys,
                               self._canonicalize_devices)


def select_cross_device_ops(devices, session_config=None):
  """Finds the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    A subclass of `CrossDeviceOps`.
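
  For illustration, picking an implementation for two local GPUs might look
  like the following (device names are only an example):

  ```
    cross_device_ops = select_cross_device_ops(["/gpu:0", "/gpu:1"])
  ```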
1362  """
1363  requested_devices = set(device_util.canonicalize(d) for d in devices)
1364  if ops.executing_eagerly_outside_functions():
1365    logical_gpus = context.context().list_logical_devices(device_type="GPU")
1366    physical_gpus = context.context().list_physical_devices(device_type="GPU")
1367    if len(logical_gpus) != len(physical_gpus):
1368      logging.warning("NCCL is not supported when using virtual GPUs, falling"
1369                      "back to reduction to one device")
1370      return ReductionToOneDevice()
1371
1372    machine_devices = context.context().list_logical_devices()
1373  else:
1374    machine_devices = device_lib.list_local_devices(
1375        session_config=session_config)
1376  using_devices = set()
1377  for d in machine_devices:
1378    if device_util.canonicalize(d.name) in requested_devices:
1379      using_devices.add(d.name)
1380
1381  if len(using_devices) != len(requested_devices):
1382    logging.warning(
1383        "Some requested devices in `tf.distribute.Strategy` are not visible "
1384        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))
1385
1386  if any("gpu" not in d.lower() for d in requested_devices):
1387    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
1388                    "not using nccl allreduce.")
1389    return ReductionToOneDevice()
1390
1391  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
1392    return NcclAllReduce(num_packs=1)
1393  else:
1394    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
1395    return ReductionToOneDevice()
1396