# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""

import copy
import threading
from typing import Callable, List, Optional, Union

from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core

INSTANCE_KEY_START_NUMBER = 100


def aggregate_gradients_using_nccl(replica_grads):
  """Aggregate gradients using nccl allreduce."""
  agg_all_g_and_v = []
  for single_g_and_v in zip(*replica_grads):
    single_grads = [g for g, _ in single_g_and_v]
    agg_grads = nccl_ops.all_sum(single_grads)
    agg_all_g_and_v.append(
        [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

  agg_all_g_and_v = list(zip(*agg_all_g_and_v))

  return agg_all_g_and_v


def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  Args:
    avail_devices: available GPU devices.
    replica_grads: List of lists of (gradient, variable) tuples. The outer
      list is over replicas. The inner list is over individual gradients.

  Returns:
    The list of (aggregated_gradient, variable), where the gradient has been
      summed across all replicas and the variable is chosen from the first
      replica.
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
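    # This mirrors the first group: the second group's gradients are summed on
    # that group's own main device, so that only the two partial sums need to
    # be exchanged between the groups afterwards.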
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  agg_grads = list(zip(*agg_grads))

  return agg_grads


def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the average gradient for a shared variable across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, mean is taken, else sum of gradients is taken.
    check_inf_nan: check grads for nans and infs.

  Returns:
    The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
      gradient has been averaged across all replicas. The variable is chosen
      from the first replica. has_nan_or_inf indicates whether the gradients
      contain any NaN or Inf values.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = array_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    has_nan_or_inf = array_ops.logical_not(
        array_ops.reduce_all(array_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None


# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage two different keys for collectives:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that run on the same set of devices must use the same group
  key.

  *Instance key*: an integer key to identify the set of tensors that are
  counterparts of each other on different devices in a device group and need
  to be all-reduced.

  This class is thread safe.
  """

  def __init__(self, group_key_start=1):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
    """
    self._group_key = group_key_start
    self._instance_key_table = {}
    self._lock = threading.Lock()

  def get_group_key(self, devices):
    """Returns a new group key.

    The caller should store and reuse the same group key for the same set of
    devices. Calling this method always returns a new group key.

    Args:
      devices: a list of canonical device strings in a collective group.

    Returns:
      a new group key.
    """
    with self._lock:
      new_key = self._group_key
      self._group_key += 1
      self._instance_key_table[new_key] = {}
      for device in devices:
        self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
      return new_key

  def get_instance_key(self, group_key, device):
    """Returns a new instance key for use in defining a collective op.

    You should call this once for each collective op of a collective instance.

    Args:
      group_key: the group key returned by get_group_key(). You should not
        assign the group key yourself.
      device: a canonical device string. It should be the device this
        collective op is on.

    Returns:
      a new instance key.

    Raises:
      ValueError: when the group key is invalid or the device is not in the
        group.
    """
    with self._lock:
      group = self._instance_key_table.get(group_key, None)
      if group is None:
        raise ValueError(f'Group {group_key} is not found.')
      if device not in group:
        raise ValueError(
            f'Device {device} is not present in group {group_key}')
      v = group[device]
      group[device] += 1
      return v

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveKeys needs to support deep copy as well.
    copied = CollectiveKeys()
    copied._group_key = self._group_key
    copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
    return copied


class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica."""

  _prefer_unique_instance_key = True
  _prefer_ordering_token = True

  def __init__(self, group_key: int, group_size: int,
               collective_keys: CollectiveKeys, device: str,
               options: collective_util.Options):
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    self._options = options
    if self._use_ordering_token():
      with ops.init_scope(), ops.device(device):
        self._ordering_token = resource_variable_ops.ResourceVariable(0.)
    else:
      self._ordering_token = None

  def _control_input(self, control_input: Union[core.TensorLike,
                                                ops.Operation]):
    if control_input is not None and not self._use_ordering_token():
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def _use_unique_instance_key(self):
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_unique_instance_key

  def _use_ordering_token(self):
    # We rely on auto control dependencies to insert control edges between
    # NCCL calls, but in tf1 graph mode auto control dependencies are not
    # used.
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_ordering_token

  def _next_instance_key(self):
    """Returns the next instance key."""
    if self._use_unique_instance_key():
      # Assigning instance keys at function building time has issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder
      # as the instance key and feed it at function call time. In this way we
      # also don't reuse instance keys, which allows for per-instance
      # cancellation.
      graph = ops.get_default_graph()
      # Control flow ops don't work with capture_call_time_value, so we put
      # the capture in the function graph of that control flow op.
      while getattr(graph, 'is_control_flow_graph', False):
        graph = graph.outer_graph
      if not context.executing_eagerly() and graph.building_function:
        with graph.as_default():
          # Capture self._next_instance_key so that when building a function
          # that calls another tf.function, the instance key assignment is
          # further delayed until we actually call the function in eager.
          # Note that capture_call_time_value doesn't automatically propagate
          # the deferred capture to the outer function.
          return graph.capture_call_time_value(
              self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
      else:
        instance_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device('CPU:0'):
          return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
    else:
      return self._collective_keys.get_instance_key(self._group_key,
                                                    self._device)

  def _get_ordering_token(self):
    if self._use_ordering_token():
      return self._ordering_token.handle

  def can_order_nccl(self):
    """Whether this launcher can order NCCL operations."""
    return self._use_ordering_token()

  def all_reduce(
      self,
      input_tensor: core.TensorLike,
      control_input: Optional[Union[core.TensorLike, ops.Operation]] = None,
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    options = self._options.merge(options)
    ordering_token = self._get_ordering_token()
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=options.implementation.value,
          timeout=options.timeout_seconds,
          ordering_token=ordering_token)

  def _all_gather(self, input_tensor: core.TensorLike,
                  options: Optional[collective_util.Options]) -> core.Tensor:
    """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    options = self._options.merge(options)
    ordering_token = self._get_ordering_token()
    with ops.device(self._device):
      return collective_ops.all_gather_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=options.implementation.value,
          timeout=options.timeout_seconds,
          ordering_token=ordering_token)

  def batch_all_reduce(
      self,
      input_tensor_packs: List[List[core.TensorLike]],
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """Batch all-reduce dense tensors.

    This takes a list of batches of tensors. Using multiple batches has the
    benefit that it doesn't need to wait for all inputs to be ready to start
    the all-reduce.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      A flat list of reduced tensors.
    """
    options = self._options.merge(options)
    outputs = []
    for pack in input_tensor_packs:
      if context.executing_eagerly():
        # We don't batch in eager as it sometimes makes the performance worse
        # due to the concat/split ops.
        for input_tensor in pack:
          outputs.append(self.all_reduce(input_tensor, None, options))
      else:
        # TODO(b/169168846): inserts a parallel all_gather to verify packings
        # are the same on each replica.
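        # In graph mode, each pack is flattened and concatenated into a single
        # buffer so the whole pack is reduced with one collective launch; the
        # reduced buffer is then split and reshaped back into the original
        # per-tensor shapes below.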
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          if (options.implementation
              == collective_util.CommunicationImplementation.NCCL and outputs):
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input, options)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))

    return outputs

  def all_gather(
      self,
      input_tensor: core.TensorLike,
      axis: core.TensorLike,
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all
        replicas, and dimensions other than `axis` need to be the same as
        well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError('all_gather is not supported in eager mode.')

    with ops.device(self._device), \
         ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. given an input_tensor with shape [2,2,5,1] and axis=3 to gather,
      # we use perm_pre=[3, 0, 1, 2] to transpose it to [1,2,2,5], which
      # brings the gather dimension to the front; afterwards we use
      # perm_after=[1,2,3,0] to place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = self._all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          options)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = self._all_gather(padded_input_tensor, options)
      # 4. Unpad
      split_tensors = []
      for i in range(self._group_size):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      return array_ops.transpose(out_tensor_t, perm=perm_after)

  def all_reduce_indexed_slices(
      self,
      input_slices: indexed_slices.IndexedSlices,
      options: Optional[collective_util.Options] = None
  ) -> indexed_slices.IndexedSlices:
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          'all_reduce_indexed_slices is not supported in eager mode.')

    # Current CollectiveAllGather implementations require input IndexedSlices
    # to have consistent length across the board, so we handle the reduction
    # of IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise pad IndexedSlices to be the same length across all
    #      participants and apply all_gather.
    options = self._options.merge(options)
    with ops.device(self._device):

      def all_gather_indexed_slices(
          all_gather_fn: Callable[
              [core.TensorLike, Optional[collective_util.Options]],
              core.Tensor]
      ) -> indexed_slices.IndexedSlices:
        """Use all_gather_fn to aggregate `IndexedSlices`."""
        all_values = all_gather_fn(input_slices.values, options)
        # Add control dependency to order the all-gather.
        if (options.implementation ==
            collective_util.CommunicationImplementation.NCCL):
          control = [all_values]
        else:
          control = []
        with ops.control_dependencies(control):
          all_indices = all_gather_fn(input_slices.indices, options)
        return indexed_slices.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = self._all_gather(length, options)

      def all_gather_with_padding(
          input_tensor: core.TensorLike,
          options: Optional[collective_util.Options]) -> core.Tensor:
        """all_gather tensors of different sizes using padding."""
        max_length = math_ops.reduce_max(all_lengths)
        padded_tensor = _pad_util(input_tensor, max_length)
        all_padded_tensors = self._all_gather(padded_tensor, options)
        split_tensors = []
        for i in range(self._group_size):
          start_pos = i * max_length
          split_tensors.append(all_padded_tensors[start_pos:start_pos +
                                                  all_lengths[i]])
        return array_ops.concat(split_tensors, 0)

      return control_flow_ops.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)),
          lambda: all_gather_indexed_slices(self._all_gather),
          lambda: all_gather_indexed_slices(all_gather_with_padding))


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, indexed_slices.IndexedSlices) for v in values):
    return backprop.aggregate_indexed_slices_gradients(values)
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  if isinstance(value, indexed_slices.IndexedSlices):
    value = backprop.flatten_nested_indexed_slices(value)
    return indexed_slices.IndexedSlices(value.values / n, value.indices,
                                        value.dense_shape)
  else:
    return value / n


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copies a tensor or IndexedSlices to a device."""
  with ops.device(device):
    if isinstance(value, indexed_slices.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      if value.dense_shape is not None:
        copied_shape = array_ops.identity(value.dense_shape)
      else:
        copied_shape = None
      result = indexed_slices.IndexedSlices(copied_values, copied_indices,
                                            copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  if isinstance(value, indexed_slices.IndexedSlices):
    return True
  if isinstance(value, value_lib.DistributedValues):
    return all(
        isinstance(v, indexed_slices.IndexedSlices) for v in value.values)
  return False


def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists: a list of dense values, a list of their indices in `values`,
      a list of sparse values, and a list of their indices in `values`. For
      example, given `values = [t0, s1, t2]` where `s1` is an `IndexedSlices`
      and `t0`, `t2` are dense tensors, the result is
      `([t0, t2], [0, 2], [s1], [1])`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  length = 0
  for values_and_indices in values_and_indices_list:
    length += len(values_and_indices[0])

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        assert result[i] is None
        result[i] = v
  return result


def group_by_size(input_tensors, bytes_per_pack):
  """Groups `input_tensors` into chunks of `bytes_per_pack`.

  The method preserves the original order of `input_tensors`. The grouping is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only groups values with known shape.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs of Tensor. All values are grouped into one pack if
      `bytes_per_pack` is zero or any of the values has an unknown shape.
  """
  if bytes_per_pack == 0:
    return [input_tensors]
  packs = []
  last_pack_size = 0
  for value in input_tensors:
    num_elements = value.shape.num_elements()
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [input_tensors]
    size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. I.e. we err on the side of having
    # few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs


def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat(
      [
          paddings_axis,
          array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
      ],
      axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor