# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of T `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of T `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape.  Product of
      the elements must equal the length of each tensor.

  Returns:
    list of T `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: T `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of T `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order.  The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of T `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of T `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: T `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of T `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order.  The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)
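

# Illustrative sketch, not part of the library API (the helper name is
# hypothetical): it shows how the split helpers above relate to each other
# on a length-10 tensor divided into 3 pieces.
def _split_helpers_example():
  """Sketch: padded vs. ragged split of a length-10 tensor into 3 pieces."""
  flat, shape = _flatten_tensors([array_ops.zeros([2, 5])])
  # _padded_split pads to length 12: three pieces of length 4, pad_len == 2.
  padded_parts, pad_len = _padded_split(flat[0], 3)
  # _ragged_split leaves the division uneven: pieces of length 3, 3 and 4.
  ragged_parts = _ragged_split(flat[0], 3)
  # Stripping the padding from the rejoined pieces recovers the original
  # 10 elements, which can then be reshaped back to the original (2, 5) shape.
  restored = _strip_padding([array_ops.concat(padded_parts, 0)], pad_len)
  restored = _reshape_tensors(restored, shape)
  return padded_parts, ragged_parts, restored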


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device processes one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
      preceding device in the permutation for that subchunk.  The
      device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
      local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)
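

# Worked example (illustrative, not part of the library; the helper name is
# hypothetical): with 2 workers, 2 GPUs per worker (gpu_perm=[0, 1]) and
# 2 subchunks, subchunk 0 uses ring order [0, 1, 2, 3] and subchunk 1 uses
# the rotated order [1, 0, 3, 2].  The asserts document the expected output.
def _ring_permutations_example():
  pred_by_s_d, rank_by_s_d = _ring_permutations(2, 2, [0, 1])
  # Predecessor of each device in its subchunk's ring.
  assert pred_by_s_d == [[3, 0, 1, 2], [1, 2, 3, 0]]
  # Rank (position) of each device in its subchunk's ring.
  assert rank_by_s_d == [[0, 1, 2, 3], [1, 0, 3, 2]]
  return pred_by_s_d, rank_by_s_d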


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of T `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have same
      size.

  Returns:
    a list of T `tf.Tensor` identical reductions (by red_op) of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of T `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of T `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len
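

# Illustrative usage sketch (hypothetical helper, assumed device layout): a
# single worker with 4 GPUs, one same-shaped tensor already placed on each
# GPU, summed with a ring all-reduce that uses 2 subchunks per tick.
def _ring_all_reduce_usage_example(per_gpu_tensors):
  """Sketch: sum-reduce one tensor per GPU on a single 4-GPU worker."""
  return build_ring_all_reduce(per_gpu_tensors, num_workers=1,
                               num_subchunks=2, gpu_perm=[0, 1, 2, 3],
                               red_op=math_ops.add)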


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over T `tf.Tensor`.
    chunks_by_dev: list of lists of T `tf.Tensor`.

  Returns:
    new list of lists of T `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of T `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
  power of 2.  TODO(tucker): relax this restriction.  Also, the
  number of elements in each tensor must be divisible by 2^h where h
  is the number of hops in each phase.  This will also be relaxed in
  the future with edge-case specific logic.

  Args:
    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d + span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks
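

# Illustrative usage sketch (hypothetical helper): with 4 GPUs each holding a
# length-8 tensor, the gather phase above runs lg(4) = 2 hops, leaving each
# device with a fully reduced shard of 8 / 4 = 2 elements; the scatter phase
# below then doubles each shard twice to rebuild the full reduced tensor.
def _recursive_hd_usage_example(per_gpu_tensors):
  """Sketch: halving-doubling all-reduce across a power-of-2 device count."""
  return build_recursive_hd_all_reduce(per_gpu_tensors, math_ops.add)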


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends an n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  is better in the other case.

  Args:
    input_tensors: list of T `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-array elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
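

# Illustrative usage sketch (hypothetical helper, assumed device names): two
# gather shards hosted on the first two GPUs, so each device sends half of
# its tensor to each shard; math_ops.add_n performs the d-way reduction.
def _shuffle_all_reduce_usage_example(per_gpu_tensors):
  """Sketch: shuffle all-reduce with two gather shards."""
  gather_devices = ["/job:worker/task:0/device:GPU:0",
                    "/job:worker/task:0/device:GPU:1"]
  return build_shuffle_all_reduce(per_gpu_tensors, gather_devices,
                                  math_ops.add_n)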


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of T `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-array elementwise reduction Op, applied to each shard's
      list of fragments.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards: list of T `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of T `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of T `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))
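

# Worked example (illustrative, not part of the library; device names and the
# helper name are hypothetical): two GPUs on each of two worker tasks are
# grouped into one inner list per task, with the values following along.
def _split_by_task_example():
  devices = ["/job:worker/replica:0/task:0/device:GPU:0",
             "/job:worker/replica:0/task:0/device:GPU:1",
             "/job:worker/replica:0/task:1/device:GPU:0",
             "/job:worker/replica:0/task:1/device:GPU:1"]
  values = ["t00", "t01", "t10", "t11"]  # stand-ins for per-device tensors
  per_task_devices, per_task_values = _split_by_task(devices, values)
  assert per_task_values == [["t00", "t01"], ["t10", "t11"]]
  return per_task_devices, per_task_values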


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.  Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduced values.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
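

# Illustrative usage sketch (hypothetical helper): NCCL all-reduce among the
# GPUs of each worker, then a single-subdivision ring all-reduce across the
# workers' first GPUs, using addition as the reduction everywhere.
def _nccl_then_ring_usage_example(all_worker_gpu_tensors):
  """Sketch: hybrid NCCL-within-worker / ring-across-worker sum."""
  return build_nccl_then_ring(all_worker_gpu_tensors, subdiv=1,
                              red_op=math_ops.add)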


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: an n-array elementwise reduction Op, used for the per-worker
      gather shards.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)
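

# Illustrative usage sketch (hypothetical helper, assumed device names): a
# shuffle reduction onto one gather device per worker, followed by a ring
# all-reduce across workers.  math_ops.add_n reduces each worker's shard
# fragments and math_ops.add is the binary op used on the cross-worker ring.
def _shuffle_then_ring_usage_example(all_worker_gpu_tensors, num_workers):
  """Sketch: hybrid shuffle-within-worker / ring-across-worker sum."""
  gather_devices = ["/job:worker/task:%d/device:CPU:0" % w
                    for w in range(num_workers)]
  return build_shuffle_then_ring(all_worker_gpu_tensors, gather_devices,
                                 subdiv=1, red_n_op=math_ops.add_n,
                                 red_op=math_ops.add)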