# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape. Product of
      the elements must equal the length of each tensor.

  Returns:
    list of `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0
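

# Worked example (added for clarity; not part of the original code): splitting
# a 1D tensor of length 10 into pieces=3 takes the final branch above:
#   chunk_size = 1 + 10 // 3 = 4
#   last_chunk_size = 10 - 2 * 4 = 2
#   pad_len = 4 - 2 = 2
# so the tensor is split into pieces of length [4, 4, 2], the last piece is
# zero-padded to length 4, and the caller receives three length-4 parts plus
# pad_len=2 for _strip_padding to remove later.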


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device processes one chunk per tick, i.e. sending
  one chunk and receiving one chunk. The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick. Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker. Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
      preceding device in the permutation for that subchunk. The
      device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
      local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)
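

# Worked example (added for clarity; not part of the original code): for
# _ring_permutations(num_workers=2, num_subchunks=2, gpu_perm=[0, 1]) there
# are 4 devices (GPU i at worker j has index i + j * 2) and
# rotation_interval = 1, so the per-subchunk ring orders are
#   subchunk 0: [0, 1, 2, 3]
#   subchunk 1: [1, 0, 3, 2]
# which yields
#   pred_by_s_d = [[3, 0, 1, 2], [1, 2, 3, 0]]
#   rank_by_s_d = [[0, 1, 2, 3], [1, 0, 3, 2]]
# e.g. for subchunk 1, device 0 has rank 1 and receives from device 1.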


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker. All workers must have the same number of
      GPUs with the same rank ordering. If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have the same size.

  Returns:
    a list of `tf.Tensor` identical sum-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
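

# Illustrative sketch (not part of the original module): a minimal
# single-worker invocation of build_ring_all_reduce over two GPUs. The device
# strings, tensor values, and the choice of math_ops.add as the reduction are
# assumptions made only for this example.
def _example_build_ring_all_reduce():
  from tensorflow.python.framework import constant_op
  input_tensors = []
  for i in range(2):
    with ops.device("/device:GPU:%d" % i):
      input_tensors.append(constant_op.constant([1.0, 2.0, 3.0, 4.0]))
  # One worker, one subchunk per chunk, GPU ring order [0, 1].
  return build_ring_all_reduce(input_tensors, 1, 1, [0, 1], math_ops.add)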


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len
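

# Worked example (added for clarity; not part of the original code): with
# num_devices=2 and num_subchunks=1 there are 2 chunks and a single tick.
# Using the permutation from _ring_permutations(1, 1, [0, 1])
# (rank_by_s_d=[[0, 1]], pred_by_s_d=[[1, 0]]):
#   device 0 has rank 0, seg_index = (0 + 2 - 2) % 2 = 0, so it reduces
#     chunk 0 with the copy held by its predecessor, device 1;
#   device 1 has rank 1, seg_index = (1 + 2 - 2) % 2 = 1, so it reduces
#     chunk 1 with the copy held by device 0.
# After that tick each device holds exactly one fully reduced chunk, which
# _build_ring_scatter below then circulates around the ring.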


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over `tf.Tensor`.
    chunks_by_dev: list of lists of `tf.Tensor`.

  Returns:
    new list of lists of `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2005).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round. During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values. During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
  power of 2. TODO(tucker): relax this restriction. Also, the
  number of elements in each tensor must be divisible by 2^h where h
  is the number of hops in each phase. This will also be relaxed in
  the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
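

# Illustrative sketch (not part of the original module): recursive
# halving-doubling over four GPUs on one worker. Four devices (a power of 2)
# and a tensor length divisible by 4 satisfy the restrictions noted in the
# docstring above; the device strings and values are assumptions made only
# for this example.
def _example_build_recursive_hd_all_reduce():
  from tensorflow.python.framework import constant_op
  input_tensors = []
  for i in range(4):
    with ops.device("/device:GPU:%d" % i):
      # Length-8 tensors: divisible by 2 for each of the two hops.
      input_tensors.append(constant_op.constant([1.0] * 8))
  return build_recursive_hd_all_reduce(input_tensors, math_ops.add)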


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d + span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers. Suppose tensor length is n, there are d devices
  and g gather shards. Each device sends an n/g-length sub-tensor to
  each gather shard. The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device. The
  devices then join the g fully reduced fragments they receive from
  the shards. The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction. The first is better where
  reduction Op time is low compared to transmission time, the second
  is better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
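

# Illustrative sketch (not part of the original module): shuffle all-reduce
# over two GPUs, using the same two GPUs as gather shards and the n-ary
# math_ops.add_n as the reduction. Device strings and values are assumptions
# made only for this example.
def _example_build_shuffle_all_reduce():
  from tensorflow.python.framework import constant_op
  gather_devices = ["/device:GPU:0", "/device:GPU:1"]
  input_tensors = []
  for dev in gather_devices:
    with ops.device(dev):
      input_tensors.append(constant_op.constant([1.0, 2.0, 3.0, 4.0]))
  return build_shuffle_all_reduce(input_tensors, gather_devices,
                                  math_ops.add_n)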


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: the n-ary reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards: list of `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task. per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be the same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))
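

# Worked example (added for clarity; not part of the original code): given
#   devices = ["/job:worker/replica:0/task:0/device:GPU:0",
#              "/job:worker/replica:0/task:0/device:GPU:1",
#              "/job:worker/replica:0/task:1/device:GPU:0"]
# _split_by_task groups by (job, replica, task), so the first two devices
# (task 0) form one group and the third (task 1) forms another, and
# per_task_values partitions the value list the same way.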


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator. Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduced values.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used. Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
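

# Illustrative note (added for clarity; not part of the original code): a
# typical call shape, assuming one input tensor per GPU with device names
# spanning multiple worker tasks, would be
#   build_nccl_then_ring(per_gpu_tensors, 1, math_ops.add)
# NCCL reduces within each worker, then a ring all-reduce over one value per
# worker (gpu_perm=[0], so subdiv must be 1) runs across workers, and NCCL
# broadcast propagates the result back to every GPU.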


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)
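

# Illustrative note (added for clarity; not part of the original code): a
# possible call, assuming one tensor per GPU across several workers and one
# gather device per worker, would be
#   build_shuffle_then_ring(per_gpu_tensors, gather_devices, 1,
#                           math_ops.add_n, math_ops.add)
# where math_ops.add_n is the n-ary reduction applied within each worker and
# math_ops.add is the binary reduction used by the ring stage across workers.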