# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


_module_lock = threading.Lock()
_shared_name_counter = 0


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)


@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except sum.')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)

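# Example usage of the all-reduce wrappers above. This is a minimal sketch,
# not part of the module API: it assumes a machine with at least two GPUs
# visible to TensorFlow, TF 1.x-style graph execution, and placeholder device
# strings. Because the collective hangs if only some outputs are evaluated,
# every returned tensor is fetched in a single `Session.run` call.
#
#   import tensorflow as tf
#
#   with tf.device('/device:GPU:0'):
#     t0 = tf.constant([1., 2.])
#   with tf.device('/device:GPU:1'):
#     t1 = tf.constant([3., 4.])
#
#   summed = all_sum([t0, t1])  # one output per input, on the same devices
#
#   with tf.Session() as sess:
#     out0, out1 = sess.run(summed)  # both evaluate to [4., 6.]
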

def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)


def reduce_sum(tensors):
  """Returns a tensor with the reduce sum across `tensors`.

  The computation is done with a reduce operation, so only one tensor is
  returned.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    A tensor containing the sum of the input tensors.

  Raises:
    ValueError: If any input tensor has no assigned device, or if none of the
      input tensors is assigned to the current device.
  """
  return _apply_reduce('sum', tensors)


@ops.RegisterGradient('NcclReduce')
def _reduce_sum_grad(op, grad):
  """The gradients for input `Operation` of `reduce_sum`.

  Args:
    op: The `sum send` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `reduce_sum` op.

  Returns:
    The gradient with respect to the input of the `reduce_sum` op.

  Raises:
    LookupError: If the reduction attribute of op is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclReduce except sum.')
  _check_device(grad, expected=op.device)

  with ops.device(op.device):
    result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)

  return [result] * len(op.inputs)


def broadcast(tensor):
  """Returns a tensor that can be efficiently transferred to other devices.

  Args:
    tensor: The tensor to send; must be assigned to a GPU device.

  Returns:
    A tensor with the value of `tensor`, which can be used as input to
    ops on other GPU devices.
  """
  _check_device(tensor)

  with ops.device(tensor.device):
    return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)

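# Example usage of `reduce_sum` and `broadcast`. A minimal sketch under the
# same assumptions as the `all_sum` example above (two visible GPUs, TF 1.x
# graph mode); the device strings and tensor values are placeholders.
#
#   with tf.device('/device:GPU:0'):
#     t0 = tf.constant([1., 2.])
#   with tf.device('/device:GPU:1'):
#     t1 = tf.constant([3., 4.])
#
#   # `reduce_sum` returns a single tensor; one of the inputs must live on
#   # the device where the reduce op itself is placed.
#   with tf.device('/device:GPU:0'):
#     total = reduce_sum([t0, t1])  # [4., 6.] on GPU:0
#
#   # `broadcast` marks `t0` as a NCCL send; per the docstring above, the
#   # result can be consumed directly by ops on other GPU devices.
#   sent = broadcast(t0)
#   with tf.device('/device:GPU:1'):
#     received = tf.identity(sent)
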

@ops.RegisterGradient('NcclBroadcast')
def _broadcast_grad(op, accumulated_grad):
  """The gradients for input `Operation` of `broadcast`.

  Args:
    op: The `broadcast send` `Operation` that we are differentiating.
    accumulated_grad: Accumulated gradients with respect to the output of the
      `broadcast` op.

  Returns:
    Gradients with respect to the input of `broadcast`.
  """
  # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
  grads = [t for t in accumulated_grad.op.inputs]
  for t in grads:
    _check_device(t)

  with ops.device(op.device):
    return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')


def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()


def _apply_reduce(reduction, tensors):
  """Helper function for reduce_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to reduce operations')

  for t in tensors:
    _check_device(t)
  result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
  try:
    next(t for t in tensors if t.device == result.device)
  except StopIteration:
    raise ValueError('One input tensor must be assigned to current device')
  return result


def _get_shared_name():
  global _shared_name_counter

  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val


def _check_device(tensor, expected=None):
  if not device.canonical_name(tensor.device):
    raise ValueError('Device assignment required for nccl collective ops')
  if expected and expected != tensor.device:
    raise ValueError('Expected device %s, got %s' % (expected, tensor.device))
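
# Differentiation sketch. Assumes TF 1.x graph mode (`tf.gradients`) and two
# visible GPUs; device strings are placeholders. Because `_all_sum_grad` is
# registered for 'NcclAllReduce' above, backpropagating through `all_sum`
# emits another sum all-reduce whose shared name is the forward op's shared
# name plus '_grad', so the backward pass is itself a NCCL collective:
#
#   with tf.device('/device:GPU:0'):
#     x0 = tf.constant([1., 2.])
#   with tf.device('/device:GPU:1'):
#     x1 = tf.constant([3., 4.])
#
#   y0, y1 = all_sum([x0, x1])
#   loss = y0[0] + y1[0]
#   dx0, dx1 = tf.gradients(loss, [x0, x1])  # computed via NcclAllReduce ops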