# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow collective Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import gen_collective_ops


def all_reduce(t,
               group_size,
               group_key,
               instance_key,
               merge_op='Add',
               final_op='Id',
               subdiv_offsets=(0,),
               communication_hint='auto',
               timeout=0):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: the total number of tensors to be collectively reduced.
      Each must reside on a different device. Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each
      partial reduction.
    final_op: string naming the unary Op to be applied to each fully
      reduced value. Can be 'Id' for no operation.
    subdiv_offsets: a list of integer offsets into the tensor at which each
      independent subdivision should begin. Use [0] if no subdivision should
      be done.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the distributed reduction.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_reduce must be at least 1.')
  return gen_collective_ops.collective_reduce(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      subdiv_offsets=subdiv_offsets,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


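# Illustrative sketch only, not part of this module: how all_reduce might be
# invoked from per-device code inside a single graph. The device names, keys
# and group size below are assumptions for the example (in practice
# tf.distribute strategies manage these values), and `ops` is assumed to be
# tensorflow.python.framework.ops.
#
#   # All participants must agree on group_size, group_key and instance_key,
#   # and each input tensor must live on a distinct device.
#   with ops.device('/device:CPU:0'):
#     reduced_0 = all_reduce(t_0, group_size=2, group_key=1, instance_key=1,
#                            merge_op='Add', final_op='Id')
#   with ops.device('/device:CPU:1'):
#     reduced_1 = all_reduce(t_1, group_size=2, group_key=1, instance_key=1,
#                            merge_op='Add', final_op='Id')
#   # Each reduced_i holds the elementwise sum of t_0 and t_1.

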
def all_reduce_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  merge_op='Add',
                  final_op='Id',
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: an int32 tensor. The total number of tensors to be collectively
      reduced. Each must reside on a different device. Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each partial
      reduction.
    final_op: string naming the unary Op to be applied to each fully reduced
      value. Can be 'Id' for no operation.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed reduction.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_reduce_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])


def all_gather(t,
               group_size,
               group_key,
               instance_key,
               communication_hint='auto',
               timeout=0):
  """Accumulates tensors collectively, across devices, along the first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device. Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_gather must be at least 1.')
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


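# Illustrative sketch only, not part of this module: all_gather concatenates
# the participants' tensors along the first dimension, so gathering a [2, 4]
# tensor from each of two devices yields a [4, 4] result on every participant.
# Device names and keys below are assumptions for the example; `ops` is assumed
# to be tensorflow.python.framework.ops.
#
#   with ops.device('/device:CPU:0'):
#     gathered_0 = all_gather(t_0, group_size=2, group_key=2, instance_key=2)
#   with ops.device('/device:CPU:1'):
#     gathered_1 = all_gather(t_1, group_size=2, group_key=2, instance_key=2)

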
def all_gather_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None):
  """Accumulates tensors collectively, across devices, along the first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: an int32 tensor, the total number of tensors to be collectively
      accumulated. Each must reside on a different device. Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed operation.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_gather_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])


def broadcast_send(t,
                   shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    shape: the shape of the tensor being sent, which must agree with t.
    dtype: the type of the tensor being sent, which must agree with t.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating. Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.

  Raises:
    ValueError: if any of the input parameter constraints are not met.

  Note that the shape and dtype arguments appear redundant since they
  should be obtainable from t. There are two reasons for including
  them. First, the shape and type of tensors passed via broadcast must
  be known ahead of time in their most specific form so that the receive
  side can allocate memory for the operation and shape/type inference can
  carry forward from there. Including the same declarations on the
  send side clarifies a commitment already made. Secondly, having nearly
  identical use syntax for send and receive sides may simplify tool-driven
  generation of broadcast.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_send must be at least 2.')
  if t.shape != shape:
    raise ValueError(
        'Shape of broadcast_send tensor not equal to declared shape')
  if t.dtype != dtype:
    raise ValueError(
        'Type of broadcast_send tensor not equal to declared type')
  return gen_collective_ops.collective_bcast_send(
      t,
      shape=shape,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


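# Illustrative sketch only, not part of this module: broadcast_send on the
# source device is paired with broadcast_recv (defined below) on every other
# participating device, and all participants must agree on shape, dtype,
# group_size, group_key and instance_key. Device names and keys below are
# assumptions for the example; `ops` is assumed to be
# tensorflow.python.framework.ops.
#
#   with ops.device('/device:CPU:0'):
#     sent = broadcast_send(t_0, t_0.shape, t_0.dtype,
#                           group_size=2, group_key=3, instance_key=3)
#   with ops.device('/device:CPU:1'):
#     received = broadcast_recv(t_0.shape, t_0.dtype,
#                               group_size=2, group_key=3, instance_key=3)

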
def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv(shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating. Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_recv must be at least 2.')
  return gen_collective_ops.collective_bcast_recv(
      shape=shape,
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor. Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout is set
      to detect staleness; if the timer expires, a DeadlineExceededError is
      raised. The value is in seconds. This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
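

# Illustrative sketch only, not part of this module: the *_v2 ops take
# group_size, group_key and instance_key as int32 tensors, so those values can
# be supplied at runtime, e.g. from inside a tf.function. Everything below is
# an assumption for the example (including the availability of the public `tf`
# namespace), not a prescribed usage pattern.
#
#   @tf.function
#   def reduce_fn(t, group_size, group_key, instance_key):
#     # ordering_token, if supplied, is a resource tensor that lets automatic
#     # control dependencies order the collectives; it is omitted here.
#     return all_reduce_v2(t, group_size, group_key, instance_key,
#                          merge_op='Add', final_op='Id')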