1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow collective Ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.ops import gen_collective_ops 21 22 23def all_reduce(t, 24 group_size, 25 group_key, 26 instance_key, 27 merge_op='Add', 28 final_op='Id', 29 subdiv_offsets=(0,), 30 communication_hint='auto', 31 timeout=0): 32 """Reduces tensors collectively, across devices. 33 34 Args: 35 t: the tensor to be reduced. 36 group_size: the total number of tensors to be collectively reduced. 37 Each must reside on a different device. Should be a positive integer. 38 group_key: an integer identifying the group of devices. 39 instance_key: an integer identifying the participating group of Ops. 40 merge_op: string naming the binary Op to be applied to compute each 41 partial reduction. 42 final_op: string naming the unary Op to be applied to each fully 43 reduced value. Can be 'Id' for no operation. 44 subdiv_offsets: a list of integer offsets into the tensor at which each 45 independent subdivision should begin. Use [0] if no subdivision should 46 be done. 47 communication_hint: preferred collective communication. The implementation 48 may fall back to another mechanism. Options include `auto`, `ring`, and 49 `nccl`. 50 timeout: a float. If set to a non zero, set a completion timeout to detect 51 staleness. If the timer goes off, a DeadlineExceededError is raised. The 52 timeout value in seconds. This feature is experimental. 53 54 Returns: 55 An Op implementing the distributed reduction. 56 57 Raises: 58 ValueError: if any of the input parameter constraints are not met. 59 """ 60 if group_size < 1: 61 raise ValueError('Parameter group_size to all_reduce must be at least 1.') 62 return gen_collective_ops.collective_reduce( 63 t, 64 group_size=group_size, 65 group_key=group_key, 66 instance_key=instance_key, 67 merge_op=merge_op, 68 final_op=final_op, 69 subdiv_offsets=subdiv_offsets, 70 communication_hint=communication_hint.lower(), 71 timeout_seconds=timeout) 72 73 74def all_reduce_v2(t, 75 group_size, 76 group_key, 77 instance_key, 78 merge_op='Add', 79 final_op='Id', 80 communication_hint='auto', 81 timeout=0, 82 ordering_token=None, 83 max_subdivs_per_device=-1): 84 """Reduces tensors collectively, across devices. 85 86 Args: 87 t: the tensor to be reduced. 88 group_size: an int32 tensor. The total number of tensors to be collectively 89 reduced. Each must reside on a different device. Should be a positive 90 integer. 91 group_key: an int32 tensor identifying the group of devices. 92 instance_key: an int32 tensor identifying the participating group of Ops. 93 merge_op: string naming the binary Op to be applied to compute each partial 94 reduction. 95 final_op: string naming the unary Op to be applied to each fully reduced 96 value. Can be 'Id' for no operation. 97 communication_hint: preferred collective communication. The implementation 98 may fall back to another mechanism. Options include `auto`, `ring`, and 99 `nccl`. 100 timeout: a float. If set to a non zero, set a completion timeout to detect 101 staleness. If the timer goes off, a DeadlineExceededError is raised. The 102 timeout value in seconds. This feature is experimental. 103 ordering_token: an optional resource tensor to pass to the op as inputs. 104 They aren't used by the kernel but allow AutoControlDependency to order 105 the collectives with control dependencies. 106 max_subdivs_per_device: int specifying the maximum number of subdivisions a 107 tensor on a device can be divided into. The runtime uses this contraint to 108 parallelize processing of each per-device tensor. Setting to -1 disables 109 subdivision and reverts to previous behavior of not sub-dividing tensor. 110 Setting to 0 uses sytem defaults. 111 112 Returns: 113 An Op implementing the distributed reduction. 114 """ 115 if ordering_token is not None: 116 ordering_token = [ordering_token] 117 return gen_collective_ops.collective_reduce_v2( 118 t, 119 group_size=group_size, 120 group_key=group_key, 121 instance_key=instance_key, 122 merge_op=merge_op, 123 final_op=final_op, 124 communication_hint=communication_hint.lower(), 125 timeout_seconds=timeout, 126 ordering_token=ordering_token or [], 127 max_subdivs_per_device=max_subdivs_per_device) 128 129 130def all_gather(t, 131 group_size, 132 group_key, 133 instance_key, 134 communication_hint='auto', 135 timeout=0): 136 """Accumulates tensors collectively, across devices, along first dimension. 137 138 Args: 139 t: the tensor to participate in the accumulation. 140 group_size: the total number of tensors to be collectively accumulated. 141 Each must reside on a different device. Should be a positive integer. 142 group_key: an integer identifying the group of devices. 143 instance_key: an integer identifying the participating group of Ops. 144 communication_hint: preferred collective communication. The implementation 145 may fall back to another mechanism. Options include `auto`, `ring`, and 146 `nccl`. 147 timeout: a float. If set to a non zero, set a completion timeout to detect 148 staleness. If the timer goes off, a DeadlineExceededError is raised. The 149 timeout value in seconds. This feature is experimental. 150 151 Returns: 152 An Op implementing the distributed operation. 153 154 Raises: 155 ValueError: if any of the input parameter constraints are not met. 156 """ 157 if group_size < 1: 158 raise ValueError('Parameter group_size to all_gather must be at least 1.') 159 return gen_collective_ops.collective_gather( 160 t, 161 shape=[0], 162 group_size=group_size, 163 group_key=group_key, 164 instance_key=instance_key, 165 communication_hint=communication_hint.lower(), 166 timeout_seconds=timeout) 167 168 169def all_gather_v2(t, 170 group_size, 171 group_key, 172 instance_key, 173 communication_hint='auto', 174 timeout=0, 175 ordering_token=None): 176 """Accumulates tensors collectively, across devices, along first dimension. 177 178 Args: 179 t: the tensor to participate in the accumulation. 180 group_size: an int32 tensor, the total number of tensors to be collectively 181 accumulated. Each must reside on a different device. Should be a positive 182 integer. 183 group_key: an int32 tensor identifying the group of devices. 184 instance_key: an int32 tensor identifying the participating group of Ops. 185 communication_hint: preferred collective communication. The implementation 186 may fall back to another mechanism. Options include `auto`, `ring`, and 187 `nccl`. 188 timeout: a float. If set to a non zero, set a completion timeout to detect 189 staleness. If the timer goes off, a DeadlineExceededError is raised. The 190 timeout value in seconds. This feature is experimental. 191 ordering_token: an optional resource tensor to pass to the op as inputs. 192 They aren't used by the kernel but allow AutoControlDependency to order 193 the collectives with control dependencies. 194 195 Returns: 196 An Op implementing the distributed operation. 197 """ 198 if ordering_token is not None: 199 ordering_token = [ordering_token] 200 return gen_collective_ops.collective_gather_v2( 201 t, 202 group_size=group_size, 203 group_key=group_key, 204 instance_key=instance_key, 205 communication_hint=communication_hint.lower(), 206 timeout_seconds=timeout, 207 ordering_token=ordering_token or []) 208 209 210def broadcast_send(t, 211 shape, 212 dtype, 213 group_size, 214 group_key, 215 instance_key, 216 communication_hint='auto', 217 timeout=0): 218 """Broadcasts one tensor to a group of others, across devices. 219 220 Args: 221 t: the tensor to be sent. 222 shape: the shape of the tensor being sent, which must agree with t. 223 dtype: the type of the tensor being sent, which must agree with t. 224 group_size: one plus the number of receiving tensors, i.e. the total 225 number of devices participating. Each tensor must reside on a 226 different device. 227 group_key: an integer identifying the group of devices. 228 instance_key: an integer identifying the participating group of Ops. 229 communication_hint: preferred collective communication. The implementation 230 may fall back to another mechanism. Options include `auto`, `ring`, and 231 `nccl`. 232 timeout: If set to a non zero, set a completion timeout to detect staleness. 233 If the timer goes off, a DeadlineExceededError is raised. 234 The timeout value in seconds. This feature is experimental. 235 236 Returns: 237 An Op implementing the distributed broadcast send. 238 239 Raises: 240 ValueError: if any of the input parameter constraints are not met. 241 242 Note that the shape and dtype arguments appear redundant since they 243 should be obtainable from t. The are two reasons for including 244 them. First, the shape and type of tensors passed via broadcast must 245 be known ahead of time in their most specific form so that the receive 246 side can allocate memory for the operation and shape/type inference can 247 carry forward from there. Including the same declarations on the 248 send side clarifies a commitment already made. Secondly, having nearly 249 identical use syntax for send and receive sides may simplify tool-driven 250 generation of broadcast. 251 """ 252 if group_size <= 1: 253 raise ValueError( 254 'Parameter group_size to broadcast_send must be at least 2.') 255 if t.shape != shape: 256 raise ValueError( 257 'Shape of broadcast_send tensor not equal to declared shape') 258 if t.dtype != dtype: 259 raise ValueError( 260 'Type of broadcast_send tensor not equal to declared type') 261 return gen_collective_ops.collective_bcast_send( 262 t, 263 shape=shape, 264 group_size=group_size, 265 group_key=group_key, 266 instance_key=instance_key, 267 communication_hint=communication_hint.lower(), 268 timeout_seconds=timeout) 269 270 271def broadcast_send_v2(t, 272 group_size, 273 group_key, 274 instance_key, 275 communication_hint='auto', 276 timeout=0): 277 """Broadcasts one tensor to a group of others, across devices. 278 279 Args: 280 t: the tensor to be sent. 281 group_size: an int32 tensor. One plus the number of receiving tensors, i.e. 282 the total number of devices participating. Each tensor must reside on a 283 different device. 284 group_key: an int32 tensor identifying the group of devices. 285 instance_key: an int32 tensor identifying the participating group of Ops. 286 communication_hint: preferred collective communication. The implementation 287 may fall back to another mechanism. Options include `auto`, `ring`, and 288 `nccl`. 289 timeout: If set to a non zero, set a completion timeout to detect staleness. 290 If the timer goes off, a DeadlineExceededError is raised. 291 The timeout value in seconds. This feature is experimental. 292 293 Returns: 294 An Op implementing the distributed broadcast send. 295 """ 296 return gen_collective_ops.collective_bcast_send_v2( 297 t, 298 group_size=group_size, 299 group_key=group_key, 300 instance_key=instance_key, 301 communication_hint=communication_hint.lower(), 302 timeout_seconds=timeout) 303 304 305def broadcast_recv(shape, 306 dtype, 307 group_size, 308 group_key, 309 instance_key, 310 communication_hint='auto', 311 timeout=0): 312 """Receives a broadcasts tensor, across devices. 313 314 Args: 315 shape: Shape of the tensor to be received. 316 dtype: Type of the tensor to be received. 317 group_size: one plus the number of receiving tensors, i.e. the total 318 number of devices participating. Each tensor must reside on a 319 different device. 320 group_key: an integer identifying the group of devices. 321 instance_key: an integer identifying the participating group of Ops. 322 communication_hint: preferred collective communication. The implementation 323 may fall back to another mechanism. Options include `auto`, `ring`, and 324 `nccl`. 325 timeout: If set to a non zero, set a completion timeout to detect staleness. 326 If the timer goes off, a DeadlineExceededError is raised. 327 The timeout value in seconds. This feature is experimental. 328 329 Returns: 330 An Op implementing the broadcast receive. 331 332 Raises: 333 ValueError: if any of the input parameter constraints are not met. 334 """ 335 if group_size <= 1: 336 raise ValueError( 337 'Parameter group_size to broadcast_send must be at least 2.') 338 return gen_collective_ops.collective_bcast_recv( 339 shape=shape, 340 T=dtype, 341 group_size=group_size, 342 group_key=group_key, 343 instance_key=instance_key, 344 communication_hint=communication_hint.lower(), 345 timeout_seconds=timeout) 346 347 348def broadcast_recv_v2(shape, 349 dtype, 350 group_size, 351 group_key, 352 instance_key, 353 communication_hint='auto', 354 timeout=0): 355 """Receives a broadcasts tensor, across devices. 356 357 Args: 358 shape: an int tensor. Shape of the tensor to be received. 359 dtype: Type of the tensor to be received. 360 group_size: an int32 tensor. One plus the number of receiving tensors, i.e. 361 the total number of devices participating. Each tensor must reside on a 362 different device. 363 group_key: an int32 tensor identifying the group of devices. 364 instance_key: an int32 tensor identifying the participating group of Ops. 365 communication_hint: preferred collective communication. The implementation 366 may fall back to another mechanism. Options include `auto`, `ring`, and 367 `nccl`. 368 timeout: If set to a non zero, set a completion timeout to detect staleness. 369 If the timer goes off, a DeadlineExceededError is raised. 370 The timeout value in seconds. This feature is experimental. 371 372 Returns: 373 An Op implementing the broadcast receive. 374 """ 375 return gen_collective_ops.collective_bcast_recv_v2( 376 T=dtype, 377 group_size=group_size, 378 group_key=group_key, 379 instance_key=instance_key, 380 shape=shape, 381 communication_hint=communication_hint.lower(), 382 timeout_seconds=timeout) 383