# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


_module_lock = threading.Lock()
_shared_name_counter = 0


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)
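

# Illustrative usage sketch, not part of the original module: it shows one way
# `all_sum` might be called on a host with at least two GPUs. The device
# strings '/gpu:0' and '/gpu:1', the example values, and the helper name
# `_example_all_sum` are assumptions for illustration only.
def _example_all_sum():
  """Hypothetical sketch: element-wise all-reduce sum across two GPUs."""
  import tensorflow as tf  # Local import; assumed available when called.

  with tf.device('/gpu:0'):
    a = tf.constant([1.0, 2.0])
  with tf.device('/gpu:1'):
    b = tf.constant([3.0, 4.0])
  # Both returned tensors must be evaluated together, or the collective hangs;
  # here each result should be [4.0, 6.0], one copy per input device.
  return all_sum([a, b])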


@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except sum.')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)
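

# Illustrative sketch, not part of the original module: differentiating through
# `all_sum` inside a `tf.function`, where the gradient registered above runs a
# second NcclAllReduce (sum) of the incoming gradients on each device. The
# device strings, values, and the helper name `_example_all_sum_grad` are
# assumptions for illustration only.
def _example_all_sum_grad():
  """Hypothetical sketch: gradients of a sum-of-all-outputs loss."""
  import tensorflow as tf  # Local import; assumed available when called.

  @tf.function
  def _grads():
    with tf.device('/gpu:0'):
      a = tf.constant([1.0, 2.0])
    with tf.device('/gpu:1'):
      b = tf.constant([3.0, 4.0])
    with tf.GradientTape() as tape:
      tape.watch([a, b])
      summed = all_sum([a, b])
      loss = tf.add_n([tf.reduce_sum(t) for t in summed])
    # Each output's gradient is all-ones, so the all-reduced gradient with
    # respect to `a` and `b` should be [2.0, 2.0] on their respective devices.
    return tape.gradient(loss, [a, b])

  return _grads()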


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)


def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)
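

# Illustrative sketch, not part of the original module: `all_prod`, `all_min`
# and `all_max` follow the same calling convention as `all_sum` above. The
# device strings, values, and the helper name are assumptions for illustration.
def _example_all_min_max():
  """Hypothetical sketch: element-wise min and max across two GPUs."""
  import tensorflow as tf  # Local import; assumed available when called.

  with tf.device('/gpu:0'):
    a = tf.constant([1.0, 5.0])
  with tf.device('/gpu:1'):
    b = tf.constant([3.0, 4.0])
  mins = all_min([a, b])   # each result tensor should be [1.0, 4.0]
  maxes = all_max([a, b])  # each result tensor should be [3.0, 5.0]
  return mins, maxes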


def reduce_sum(tensors):
  """Returns a tensor with the reduce sum across `tensors`.

  The computation is done with a reduce operation, so only one tensor is
  returned.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    A tensor containing the sum of the input tensors.

  Raises:
    ValueError: If `tensors` is empty, if an input has no device assignment,
      or if none of the inputs is assigned to the current device.
  """
  return _apply_reduce('sum', tensors)
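

# Illustrative sketch, not part of the original module: `reduce_sum` places its
# single result on the current device, so it is called under a device scope
# that matches one of the inputs. The device strings, values, and the helper
# name are assumptions for illustration only.
def _example_reduce_sum():
  """Hypothetical sketch: reduce two GPU-resident tensors onto /gpu:0."""
  import tensorflow as tf  # Local import; assumed available when called.

  with tf.device('/gpu:0'):
    a = tf.constant([1.0, 2.0])
  with tf.device('/gpu:1'):
    b = tf.constant([3.0, 4.0])
  with tf.device('/gpu:0'):
    # One input (`a`) already lives on /gpu:0, so the device check passes;
    # the result should be [4.0, 6.0] on /gpu:0.
    return reduce_sum([a, b])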


@ops.RegisterGradient('NcclReduce')
def _reduce_sum_grad(op, grad):
  """The gradients for the inputs of `reduce_sum`.

  Args:
    op: The `reduce_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `reduce_sum` op.

  Returns:
    The gradients with respect to the inputs of the `reduce_sum` op.

  Raises:
    LookupError: If the `reduction` attribute of `op` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclReduce except sum.')
  _check_device(grad, expected=op.device)

  with ops.device(op.device):
    result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)

  return [result] * len(op.inputs)


def broadcast(tensor):
  """Returns a tensor that can be efficiently transferred to other devices.

  Args:
    tensor: The tensor to send; must be assigned to a GPU device.

  Returns:
    A tensor with the value of `tensor`, which can be used as input to ops on
    other GPU devices.
  """
  _check_device(tensor)

  with ops.device(tensor.device):
    return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)
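

# Illustrative sketch, not part of the original module: broadcast a tensor from
# /gpu:0 and use it in an op placed on /gpu:1, built inside a `tf.function`
# graph. The device strings, values, and the helper name are assumptions for
# illustration only.
def _example_broadcast():
  """Hypothetical sketch: send a tensor from /gpu:0 to an op on /gpu:1."""
  import tensorflow as tf  # Local import; assumed available when called.

  @tf.function
  def _consume():
    with tf.device('/gpu:0'):
      src = tf.constant([1.0, 2.0])
    sent = broadcast(src)  # The NcclBroadcast op is placed on src's device.
    with tf.device('/gpu:1'):
      return sent * 2.0    # Consumed on another GPU; should be [2.0, 4.0].

  return _consume()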


@ops.RegisterGradient('NcclBroadcast')
def _broadcast_grad(op, accumulated_grad):
  """The gradients for the input of `broadcast`.

  Args:
    op: The `broadcast` `Operation` that we are differentiating.
    accumulated_grad: Accumulated gradients with respect to the output of the
      `broadcast` op.

  Returns:
    Gradients with respect to the input of `broadcast`.
  """
  # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
  grads = [t for t in accumulated_grad.op.inputs]
  for t in grads:
    _check_device(t)

  with ops.device(op.device):
    return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')


def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()


def _apply_reduce(reduction, tensors):
  """Helper function for reduce_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to reduce operations')

  for t in tensors:
    _check_device(t)
  result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
  try:
    next(t for t in tensors if t.device == result.device)
  except StopIteration:
    raise ValueError('One input tensor must be assigned to current device')
  return result


def _get_shared_name():
  global _shared_name_counter

  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val


def _check_device(tensor, expected=None):
  if not device.canonical_name(tensor.device):
    raise ValueError('Device assignment required for nccl collective ops')
  if expected and expected != tensor.device:
    raise ValueError('Expected device %s, got %s' % (expected, tensor.device))