1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for different algorithms of reduction and broadcasting.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import enum 23import six 24 25from tensorflow.python.client import device_lib 26from tensorflow.python.distribute import cross_device_utils 27from tensorflow.python.distribute import device_util 28from tensorflow.python.distribute import reduce_util 29from tensorflow.python.distribute import values as value_lib 30from tensorflow.python.eager import context 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.util.tf_export import tf_export 37from tensorflow.tools.docs import doc_controls 38 39 40def check_destinations(destinations): 41 """Checks whether `destinations` is not empty. 42 43 Args: 44 destinations: a `DistributedValues`, variable, or string object. 45 46 Returns: 47 Boolean which is True if `destinations` is not empty. 48 """ 49 # Calling bool() on a ResourceVariable is not allowed. 50 if isinstance(destinations, resource_variable_ops.ResourceVariable): 51 return bool(destinations.device) 52 return bool(destinations) 53 54 55def validate_destinations(destinations): 56 if not isinstance(destinations, 57 (value_lib.DistributedValues, 58 resource_variable_ops.ResourceVariable, 59 value_lib.AggregatingVariable, 60 six.string_types, 61 value_lib.TPUMirroredVariable, 62 # LogicalDeviceSpec is only used internally, e.g. as a 63 # broadcast destination, never supplied by a user. 64 value_lib.LogicalDeviceSpec)): 65 raise ValueError("destinations must be one of a `DistributedValues` object," 66 " a tf.Variable object, or a device string.") 67 68 if not check_destinations(destinations): 69 raise ValueError("destinations can not be empty") 70 71 72def reduce_non_distributed_value(reduce_op, device_map, value, destinations): 73 """Reduce a non-DistributedValue `value` to `destinations`.""" 74 if isinstance(value, value_lib.DistributedValues): 75 raise ValueError("You are passing a `DistributedValue` to " 76 "`reduce_non_distributed_value`, which is not allowed.") 77 78 # If the same value is present on all replicas then the PerReplica value will 79 # be a single value. We also handle the case when `value` is a single value 80 # and equal to 0. 81 if value == 0: 82 return 0 83 # If there is only a single value and the reduce op is MEAN, 84 # that value should be on all destinations. 85 if reduce_op == reduce_util.ReduceOp.MEAN: 86 return value 87 88 validate_destinations(destinations) 89 # We do not support a reduce op of SUM if the value is the same across 90 # all replicas. We call this as part of assign functions for MirroredVariables 91 # and summing up identical values across replicas is not clearly defined. 92 if device_map.num_replicas_in_graph != 1: 93 raise ValueError("A non-DistributedValues value %s cannot be reduced with " 94 "the given reduce op %s." % (value, reduce_op)) 95 return simple_broadcast(value, destinations) 96 97 98def _make_tensor_into_per_replica(input_tensor): 99 """Converts a single tensor into a PerReplica object.""" 100 if isinstance(input_tensor, (tuple, list)): 101 raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, " 102 "got %r but expected a object that is not a tuple or list." 103 % (input_tensor,)) 104 if isinstance(input_tensor, value_lib.PerReplica): 105 return input_tensor 106 107 try: 108 device = input_tensor.device 109 except AttributeError: 110 raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object " 111 "because it doesn't have device set.") 112 113 device_map = value_lib.SingleDeviceMap(device) 114 return value_lib.PerReplica(device_map, (input_tensor,)) 115 116 117def _normalize_value_destination_pairs(value_destination_pairs): 118 """Converts each tensor into a PerReplica object in the input list.""" 119 result = [] 120 121 value_destination_pairs = list(value_destination_pairs) 122 123 if not isinstance(value_destination_pairs, (list, tuple)): 124 raise ValueError("`value_destination_pairs` should be a list or tuple") 125 for pair in value_destination_pairs: 126 if not isinstance(pair, tuple): 127 raise ValueError( 128 "Each element of `value_destination_pairs` should be a tuple.") 129 if len(pair) != 2: 130 raise ValueError("Each element of `value_destination_pairs` should be a " 131 "tuple of size 2.") 132 133 per_replica = _make_tensor_into_per_replica(pair[0]) 134 result.append((per_replica, pair[1])) 135 return result 136 137 138def _validate_value_destination_pairs(value_destination_pairs): 139 # TODO(yuefengz): raise exceptions instead of returning False. 140 # pylint: disable=g-missing-docstring 141 if not value_destination_pairs: return False 142 if not isinstance(value_destination_pairs, (list, tuple)): return False 143 if not all(isinstance(pair, tuple) for pair in value_destination_pairs): 144 return False 145 if not all(isinstance(v[0], value_lib.PerReplica) 146 for v in value_destination_pairs): 147 return False 148 return True 149 150 151# TODO(yuefengz): consider calling this function in the caller of 152# CrossDeviceOps. 153def get_devices_from(destinations): 154 if isinstance(destinations, value_lib.DistributedValues): 155 return destinations.devices 156 elif isinstance(destinations, value_lib.LogicalDeviceSpec): 157 return destinations.device_map.logical_to_actual_devices( 158 destinations.logical_device) 159 elif isinstance(destinations, six.string_types): 160 return (device_util.resolve(destinations),) 161 return (destinations.device,) 162 163 164def get_device_map_from(destinations): 165 if isinstance(destinations, (value_lib.DistributedValues, 166 value_lib.LogicalDeviceSpec)): 167 return destinations.device_map, destinations.logical_device 168 if isinstance(destinations, six.string_types): 169 device = device_util.resolve(destinations) 170 else: 171 device = destinations.device 172 return value_lib.SingleDeviceMap(device), 0 173 174 175def _devices_match(left, right): 176 return set(get_devices_from(left)) == set(get_devices_from(right)) 177 178 179def _all_devices_match(value_destination_pairs): 180 if not all(_devices_match(v, d) for v, d in value_destination_pairs): 181 return False 182 if not all(_devices_match(v, value_destination_pairs[0][0]) 183 for v, _ in value_destination_pairs[1:]): 184 return False 185 return True 186 187 188def simple_broadcast(value, destinations, always_mirrored=False): 189 """Broadcast `value` to `destinations` using simple copies.""" 190 device_map, logical_device = get_device_map_from(destinations) 191 devices = device_map.logical_to_actual_devices(logical_device) 192 if len(devices) == 1 and not always_mirrored: 193 return cross_device_utils.copy_tensor_or_indexed_slices_to_device( 194 value, devices[0]) 195 else: 196 value_updates = [] 197 for d in devices: 198 value_updates.append( 199 cross_device_utils.copy_tensor_or_indexed_slices_to_device( 200 value, d)) 201 return value_lib.Mirrored(device_map, value_updates, logical_device) 202 203 204def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, 205 reduce_op): 206 # pylint: disable=g-missing-docstring 207 all_values = per_replica_value.values 208 if not all_values: 209 raise ValueError("`per_replica_value` must be non-empty") 210 count = len(all_values) 211 212 with ops.device(reduce_to_device): 213 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 214 reduced = cross_device_utils.aggregate_tensors_or_indexed_slices( 215 all_values, accumulation_fn) 216 if reduce_op == reduce_util.ReduceOp.MEAN: 217 reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices( 218 reduced, count) 219 elif reduce_op != reduce_util.ReduceOp.SUM: 220 raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.") 221 return reduced 222 223 224@tf_export("distribute.CrossDeviceOps") 225class CrossDeviceOps(object): 226 """Base class for cross-device reduction and broadcasting algorithms.""" 227 228 def __init__(self): 229 pass 230 231 def reduce(self, reduce_op, per_replica_value, destinations): 232 """Reduce `per_replica_value` to `destinations`. 233 234 It runs the reduction operation defined by `reduce_op` and put the 235 result on `destinations`. 236 237 Args: 238 reduce_op: Indicates how per_replica_value will be reduced. Accepted 239 values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 240 per_replica_value: a PerReplica object or a tensor with device set. 241 destinations: the reduction destinations. 242 243 Returns: 244 a Mirrored object. 245 246 Raises: 247 ValueError: if per_replica_value can't be converted to a PerReplica 248 object. 249 """ 250 if not isinstance(per_replica_value, value_lib.PerReplica): 251 per_replica_value = _make_tensor_into_per_replica(per_replica_value) 252 253 validate_destinations(destinations) 254 return self.reduce_implementation(reduce_op, per_replica_value, 255 destinations) 256 257 def batch_reduce(self, reduce_op, value_destination_pairs): 258 """Reduce PerReplica objects in a batch. 259 260 Reduce each first element in `value_destination_pairs` to each second 261 element which indicates the destinations. 262 263 Args: 264 reduce_op: Indicates how per_replica_value will be reduced. Accepted 265 values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 266 value_destination_pairs: a list or a tuple of tuples of PerReplica objects 267 (or tensors with device set if there is one device) and destinations. 268 269 Returns: 270 a list of Mirrored objects. 271 272 Raises: 273 ValueError: if `value_destination_pairs` is not a list or a tuple of 274 tuples of PerReplica objects and destinations 275 """ 276 # TODO(yuefengz): if destinations are different, split into several 277 # `_batch_reduce` invocations. 278 if not _validate_value_destination_pairs(value_destination_pairs): 279 # If the first element of each pair is a tensor, we try to turn it into a 280 # PerReplica object. 281 value_destination_pairs = _normalize_value_destination_pairs( 282 value_destination_pairs) 283 284 for _, d in value_destination_pairs: 285 validate_destinations(d) 286 287 return self.batch_reduce_implementation(reduce_op, value_destination_pairs) 288 289 def broadcast(self, tensor, destinations): 290 """Broadcast the `tensor` to destinations. 291 292 Args: 293 tensor: the tensor to broadcast. 294 destinations: the broadcast destinations. 295 296 Returns: 297 a Mirrored object. 298 """ 299 validate_destinations(destinations) 300 return self.broadcast_implementation(tensor, destinations) 301 302 @doc_controls.for_subclass_implementers 303 def reduce_implementation(self, reduce_op, per_replica_value, destinations): 304 """The implementation of reduce of `per_replica_value` to `destinations`. 305 306 It runs the reduction operation defined by `reduce_op` and put the 307 result on `destinations`. 308 309 Args: 310 reduce_op: Indicates how per_replica_value will be reduced. Accepted 311 values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 312 per_replica_value: a PerReplica object or a tensor with device set. 313 destinations: the reduction destinations. 314 315 Returns: 316 a Mirrored object. 317 318 Raises: 319 ValueError: if per_replica_value can't be converted to a PerReplica 320 object. 321 """ 322 raise NotImplementedError( 323 "_reduce method must be implemented in descendants.") 324 325 @doc_controls.for_subclass_implementers 326 def batch_reduce_implementation(self, reduce_op, value_destination_pairs): 327 """Implementation of reduce PerReplica objects in a batch. 328 329 Reduce each first element in `value_destination_pairs` to each second 330 element which indicates the destinations. 331 332 Args: 333 reduce_op: Indicates how per_replica_value will be reduced. Accepted 334 values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 335 value_destination_pairs: a list or a tuple of tuples of PerReplica objects 336 (or tensors with device set if there is one device) and destinations. 337 338 Returns: 339 a list of Mirrored objects. 340 341 Raises: 342 ValueError: if `value_destination_pairs` is not a list or a tuple of 343 tuples of PerReplica objects and destinations 344 """ 345 raise NotImplementedError( 346 "_batch_reduce method must be implemented in descendants.") 347 348 @doc_controls.for_subclass_implementers 349 def broadcast_implementation(self, tensor, destinations): 350 """Implementation of broadcast the `tensor` to destinations. 351 352 Args: 353 tensor: the tensor to broadcast. 354 destinations: the broadcast destinations. 355 356 Returns: 357 a Mirrored object. 358 """ 359 return simple_broadcast(tensor, destinations, always_mirrored=True) 360 361 362@tf_export("distribute.ReductionToOneDevice") 363class ReductionToOneDevice(CrossDeviceOps): 364 """Always do reduction to one device first and then do broadcasting. 365 366 Batch reduction is done by reduction on each element one by one. 367 """ 368 369 def __init__(self, reduce_to_device=None, accumulation_fn=None): 370 """Constructor. 371 372 Args: 373 reduce_to_device: the intermediate device to reduce to. If None, reduce 374 to the first device in `destinations` of the reduce() method. 375 accumulation_fn: a function that does accumulation. If None, then 376 `tf.math.add_n` is used. 377 """ 378 self.reduce_to_device = reduce_to_device 379 self.accumulation_fn = accumulation_fn or math_ops.add_n 380 super(ReductionToOneDevice, self).__init__() 381 382 def reduce_implementation(self, reduce_op, per_replica_value, destinations): 383 if check_destinations(destinations): 384 devices = get_devices_from(destinations) 385 else: 386 devices = get_devices_from(per_replica_value) 387 reduce_to_device = self.reduce_to_device or devices[0] 388 logging.log_first_n( 389 logging.INFO, 390 "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10) 391 reduced = _simple_reduce(per_replica_value, reduce_to_device, 392 self.accumulation_fn, reduce_op) 393 return self.broadcast(reduced, destinations) 394 395 def batch_reduce_implementation(self, reduce_op, value_destination_pairs): 396 return [ 397 self.reduce_implementation(reduce_op, t, destinations=v) 398 for t, v in value_destination_pairs 399 ] 400 401 402def _group_value_by_device(per_replica_values): 403 """Group values into sublists by their devices. 404 405 This grouping is needed to call the all-reduce library because it expects a 406 list of the following form: 407 [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], 408 [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], 409 [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], 410 ... 411 ] 412 413 Args: 414 per_replica_values: a list of PerReplica obejcts. 415 416 Returns: 417 a list of lists, each sublist has components for its corresponding device of 418 PerReplica objects, paired with a None. 419 """ 420 destinations = per_replica_values[0].devices 421 grouped = [[] for _ in range(len(destinations))] 422 for per_replica_value in per_replica_values: 423 # pylint: disable=protected-access 424 for i, v in enumerate(per_replica_value.values): 425 assert per_replica_value.devices == destinations 426 grouped[i].append((v, None)) 427 return grouped 428 429 430def _ungroup_and_make_mirrored(grouped_reduced, 431 destinations, 432 reduce_op, 433 num_between_graph_workers=1): 434 """Ungroup results from all-reduce and make Mirrored objects. 435 436 Each all-reduce result will be divided by the number of destinations before 437 Mirrored objects are created if reduce_op is "mean". 438 439 Args: 440 grouped_reduced: a list of lists, each sublist has components for each 441 device, paired with a None. It is the result from 442 cross_device_utils.aggregate_gradients_using*. 443 destinations: a value to colocate the result with. 444 reduce_op: Indicates how values will be aggregated. Accepted values 445 are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 446 num_between_graph_workers: number of workers in the between-graph 447 replication. 448 449 Returns: 450 a list of Mirrored objects. 451 """ 452 device_map, logical_device = get_device_map_from(destinations) 453 num_replicas = device_map.num_replicas_in_graph * num_between_graph_workers 454 index = [[] for _ in range(len(grouped_reduced[0]))] 455 for per_replica_reduced in grouped_reduced: 456 for i, (v, _) in enumerate(per_replica_reduced): 457 if reduce_op == reduce_util.ReduceOp.MEAN: 458 index[i].append(v / num_replicas) 459 else: 460 index[i].append(v) 461 return [value_lib.Mirrored(device_map, v, logical_device) for v in index] 462 463 464class _ConcatAndSplitPacker(object): 465 """Concatenate and split tensors for reduction.""" 466 467 def __init__(self, num_packs=1): 468 """Initialize the _ConcatAndSplitPacker object. 469 470 Args: 471 num_packs: specifies the number of split packs that will be 472 formed. 473 474 Raises: 475 ValueError: if num_packs is not greater than 0. 476 """ 477 if num_packs <= 0: 478 raise ValueError("num_packs must be greater than zero.") 479 self.num_packs = num_packs 480 481 def pack(self, grouped_grads_and_vars): 482 """Pack tensors.""" 483 self.grouped_grads_and_vars = grouped_grads_and_vars 484 self.all_device_shapes = [] 485 self.all_device_sizes = [] 486 487 device_grad_packs = [] 488 for device_grads_and_vars in grouped_grads_and_vars: 489 with ops.colocate_with(device_grads_and_vars[0][0]): 490 # Flatten all the grads. 491 flat_grads = [ 492 array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars 493 ] 494 # Remember the original shape of all the grads. 495 device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars] 496 # Remember the original sizes of all the grads. 497 device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars] 498 # Concat all the flat grads into a big flat tensor. 499 concat_grads = array_ops.concat(flat_grads, 0) 500 501 # Split the big tensor into num_splits packs. In cases where the 502 # total size is not divisible num_splits, the last pack gets 503 # more elements. 504 # TODO(zhengxq): it is also possible to optimize away all the concat 505 # as well. 506 num_splits = self.num_packs 507 508 # The array_ops.size function will sometimes remove static shapes. So if 509 # all gradient shapes are defined, we use another method to get the 510 # total size. 511 # TODO(yuefengz): move this logic to array_ops.size. 512 if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars): 513 total_grad_size = sum( 514 [g.shape.num_elements() for g, _ in device_grads_and_vars]) 515 else: 516 total_grad_size = array_ops.size(concat_grads) 517 518 split_size = total_grad_size // num_splits 519 split_size_last = total_grad_size - split_size * (num_splits - 1) 520 split_sizes = [split_size] * (num_splits - 1) + [split_size_last] 521 grad_packs = array_ops.split(concat_grads, split_sizes) 522 523 # Ready to aggregate the repacked gradients, with fake variables. 524 # TODO(zhengxq): It is hacky to have to use fake variables. 525 # We should remove the need for variables in 526 # aggregate_gradients_using*. 527 device_grad_packs.append(zip(grad_packs, [None] * num_splits)) 528 self.all_device_shapes.append(device_shapes) 529 self.all_device_sizes.append(device_sizes) 530 531 return device_grad_packs 532 533 def unpack(self, summed_device_grad_packs): 534 """Reverse the pack.""" 535 aggregated_device_grads = [] 536 for (summed_device_grad_packs, 537 device_grads_and_vars, device_shapes, device_sizes) in zip( 538 summed_device_grad_packs, self.grouped_grads_and_vars, 539 self.all_device_shapes, self.all_device_sizes): 540 # pylint: enable=line-too-long 541 # Reverse the packing operations in the previous steps. Form the 542 # summed gradients back into their original shapes. 543 with ops.colocate_with(summed_device_grad_packs[0][0]): 544 # Form a list of the summed grad packs. 545 device_grad_packs = [g for g, _ in summed_device_grad_packs] 546 547 # Concat them back into a big flat tensor. 548 device_grads_concat = array_ops.concat(device_grad_packs, 0) 549 550 # Split the tensors back into their original sizes. 551 grads_with_sizes = array_ops.split(device_grads_concat, device_sizes) 552 553 # Reshape the tensors back into their original shapes. 554 grads_with_shapes = [ 555 array_ops.reshape(grad, shape) 556 for shape, grad in zip(device_shapes, grads_with_sizes) 557 ] 558 559 # Form the list with the original list of variables. 560 summed_device_grads = [ 561 (g, v) for g, (_, v) in zip(grads_with_shapes, 562 device_grads_and_vars) 563 ] 564 aggregated_device_grads.append(summed_device_grads) 565 return aggregated_device_grads 566 567 568class _AggregateSmallTensorPacker(object): 569 """Concatenate small gradient tensors together for reduction.""" 570 571 def __init__(self, 572 agg_small_grads_max_bytes=1048576, 573 agg_small_grads_max_group=16): 574 """Initialize the _AggregateSmallTensorPacker object. 575 576 Args: 577 agg_small_grads_max_bytes: largest tensor eligible for aggregation, 578 in number of bytes. 579 agg_small_grads_max_group: largest permitted aggregation of small 580 tensors. 581 582 Raises: 583 ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group` 584 is not greater than 0. 585 """ 586 if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0: 587 raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group" 588 " should both be greater than zero.") 589 self.agg_small_grads_max_bytes = agg_small_grads_max_bytes 590 self.agg_small_grads_max_group = agg_small_grads_max_group 591 592 def pack(self, grouped_grads_and_vars): 593 """Aggregate small tensors.""" 594 if (self.agg_small_grads_max_bytes > 0 and 595 self.agg_small_grads_max_group > 0): 596 device_grads, self.packing = cross_device_utils.pack_small_tensors( 597 grouped_grads_and_vars, 598 max_bytes=self.agg_small_grads_max_bytes, 599 max_group=self.agg_small_grads_max_group) 600 return device_grads 601 602 def unpack(self, summed_device_grad_packs): 603 """Reverse the aggregation process.""" 604 return cross_device_utils.unpack_small_tensors(summed_device_grad_packs, 605 self.packing) 606 607 608def _pack_tensors(device_grads, 609 num_packs=0, 610 agg_small_grads_max_bytes=0, 611 agg_small_grads_max_group=0): 612 """Pack tensors if specified.""" 613 if num_packs > 0: 614 tensor_packer = _ConcatAndSplitPacker(num_packs) 615 device_grad_packs = tensor_packer.pack(device_grads) 616 elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: 617 tensor_packer = _AggregateSmallTensorPacker(agg_small_grads_max_bytes, 618 agg_small_grads_max_group) 619 device_grad_packs = tensor_packer.pack(device_grads) 620 else: 621 tensor_packer = None 622 device_grad_packs = device_grads 623 return device_grad_packs, tensor_packer 624 625 626def _unpack_tensors(reduced, tensor_packer=None): 627 """Unpack tensors if they are packed before all-reduce.""" 628 if tensor_packer: 629 return tensor_packer.unpack(reduced) 630 return reduced 631 632 633class AllReduceCrossDeviceOps(CrossDeviceOps): 634 """Reduction using all-reduce.""" 635 636 def __init__(self, 637 all_reduce_alg="nccl", 638 num_packs=1, 639 agg_small_grads_max_bytes=0, 640 agg_small_grads_max_group=10): 641 """All-reduce implementation of CrossDeviceOps. 642 643 Before performing all-reduce, tensors will be repacked or aggregated for 644 more efficient cross-device transportation: 645 1) If `num_packs` is non-zero, pack values into 646 `num_packs` splits. 647 2) Otherwise, if `agg_small_grads_max_bytes` > 0 and 648 `agg_small_grads_max_group` > 0, aggregate values smaller than 649 `agg_small_grads_max_bytes` into groups with at most 650 `agg_small_grads_max_group` values. 651 3) Otherwise, no repacking or grouping will happen. 652 653 Args: 654 all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or 655 "hierarchical_copy" are supported. 656 num_packs: see above. 657 agg_small_grads_max_bytes: see above. 658 agg_small_grads_max_group: see above. 659 """ 660 self._all_reduce_alg = all_reduce_alg 661 self._num_packs = num_packs 662 self._agg_small_grads_max_bytes = agg_small_grads_max_bytes 663 self._agg_small_grads_max_group = agg_small_grads_max_group 664 self._simple_cross_replica_ops = ReductionToOneDevice() 665 super(AllReduceCrossDeviceOps, self).__init__() 666 667 def reduce_implementation(self, reduce_op, per_replica_value, destinations): 668 if _devices_match(per_replica_value, destinations): 669 return self._batch_all_reduce(reduce_op, [per_replica_value])[0] 670 else: 671 return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, 672 destinations) 673 674 def batch_reduce_implementation(self, reduce_op, value_destination_pairs): 675 all_devices_match = _all_devices_match(value_destination_pairs) 676 contains_indexed_slices = cross_device_utils.contains_indexed_slices( 677 value_destination_pairs) 678 if (all_devices_match and not context.executing_eagerly() 679 and not contains_indexed_slices): 680 return self._batch_all_reduce(reduce_op, 681 [v[0] for v in value_destination_pairs]) 682 else: 683 if not all_devices_match: 684 logging.log_first_n(logging.WARN, 685 "Efficient batch_reduce is not supported if " 686 "destinations are different.", 687 10) 688 689 return [ 690 self.reduce_implementation(reduce_op, t, destinations=v) 691 for t, v in value_destination_pairs 692 ] 693 694 def _batch_all_reduce(self, reduce_op, per_replica_values): 695 """All-reduce algorithm in a batch.""" 696 dense_values, dense_indices, sparse_values, sparse_indices = ( 697 cross_device_utils.split_by_sparsity(per_replica_values)) 698 if dense_values: 699 dense_results = self._do_batch_all_reduce(reduce_op, dense_values) 700 else: 701 dense_results = [] 702 if sparse_values: 703 sparse_results = self._do_batch_all_reduce_sparse(reduce_op, 704 sparse_values) 705 else: 706 sparse_results = [] 707 return cross_device_utils.stitch_values(((dense_results, dense_indices), 708 (sparse_results, sparse_indices))) 709 710 def _do_batch_all_reduce(self, reduce_op, dense_values): 711 """Run batch all-reduces.""" 712 logging.log_first_n( 713 logging.INFO, "batch_all_reduce invoked for batches size = %d with " 714 "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " 715 "agg_small_grads_max_group = %d" % 716 (len(dense_values), self._all_reduce_alg, self._num_packs, 717 self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) 718 719 destinations = dense_values[0].devices 720 grouped = _group_value_by_device(dense_values) 721 722 device_grad_packs, tensor_packer = _pack_tensors( 723 grouped, self._num_packs, self._agg_small_grads_max_bytes, 724 self._agg_small_grads_max_group) 725 726 # The actual aggregation of the repacked gradients. Note that they are 727 # sharded among different aggregation trees. So it is important to strike 728 # the balance on num_splits. 729 if self._all_reduce_alg == "nccl": 730 # TODO(yuefengz): merge this into the all-reduce library. 731 reduced = cross_device_utils.aggregate_gradients_using_nccl( 732 device_grad_packs) 733 else: 734 # TODO(yuefengz): check that gpu ids in `destinations` are in ascending 735 # order. 736 reduced = ( 737 cross_device_utils.aggregate_gradients_using_hierarchical_copy( 738 destinations, device_grad_packs)) 739 740 reduced = _unpack_tensors(reduced, tensor_packer) 741 return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op) 742 743 def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values): 744 """Run batch all-reduce for sparse values.""" 745 logging.log_first_n( 746 logging.WARN, 747 "Efficient allreduce is not supported for %d IndexedSlices" % 748 len(sparse_values), 10) 749 # Use `sparse_values` as destinations to do all-reduces. It is effectively 750 # an allgather under the hood but not an efficient one. 751 return self._simple_cross_replica_ops.batch_reduce( 752 reduce_op, zip(sparse_values, sparse_values)) 753 754 755# For compatibility with code using the old name of `AllReduceCrossDeviceOps`. 756AllReduceCrossTowerOps = AllReduceCrossDeviceOps 757 758 759AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", 760 "alg shards limit") 761 762 763@tf_export("distribute.NcclAllReduce") 764class NcclAllReduce(AllReduceCrossDeviceOps): 765 """Reduction using NCCL all-reduce.""" 766 767 def __init__(self, num_packs=1): 768 """NCCL all-reduce implementation of CrossDeviceOps. 769 770 Before performing all-reduce, tensors will be repacked or aggregated for 771 more efficient cross-device transportation. 772 773 Args: 774 num_packs: values will be packed in this many splits. `num_packs` should 775 be greater than 0. 776 """ 777 assert num_packs > 0, ( 778 "NCLL all-reduce requires num_packs > 0, but {} is specified".format( 779 num_packs)) 780 super(NcclAllReduce, self).__init__( 781 all_reduce_alg="nccl", num_packs=num_packs) 782 783 784@tf_export("distribute.HierarchicalCopyAllReduce") 785class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): 786 """Reduction using hierarchical copy all-reduce. 787 788 This is a good reduction for configurations like Nvidia DGX-1. 789 """ 790 791 def __init__(self, num_packs=1): 792 """Hierarchical copy all-reduce implementation of CrossDeviceOps. 793 794 Before performing all-reduce, tensors will be repacked or aggregated for 795 more efficient cross-device transportation. 796 797 Args: 798 num_packs: values will be packed in this many splits. `num_packs` should 799 be greater than 0. 800 """ 801 super(HierarchicalCopyAllReduce, self).__init__( 802 all_reduce_alg="hierarchical_copy", 803 num_packs=num_packs) 804 805 806class MultiWorkerAllReduce(AllReduceCrossDeviceOps): 807 """All-reduce algorithms for distributed TensorFlow.""" 808 809 def __init__(self, 810 worker_devices, 811 num_gpus_per_worker, 812 all_reduce_spec=("pscpu/pscpu", 2, -1), 813 num_packs=0, 814 agg_small_grads_max_bytes=0, 815 agg_small_grads_max_group=10): 816 """Initialize the all-reduce algorithm. 817 818 Args: 819 worker_devices: a list of device strings for workers participating in 820 all-reduce. 821 num_gpus_per_worker: number of GPU devices per worker. 822 all_reduce_spec: a tuple or a named tuple or a list of tuples specifying 823 the all-reduce algorithm. 824 1. The first element of a tuple is the name of the all-reduce algorithm. 825 Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", 826 "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with 827 a "/" are hierarchical, so two all-reduces are executed, the first one 828 aggregates tensors within a worker and the second aggregates across 829 workers. 830 2. The second element of a tuple is the number of shards when doing 831 all-reduce. Let's say its values is M, each tensor after packing will be 832 split into M shards and then M parallel all-reduces would be performed 833 before finally they are concatenated backed into a complete tensor. 834 3. The third element is the maximum size of tensors that will be 835 applicable for the algorithm specified by the first element. For 836 example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], 837 tensors with size not larger than 1024 bytes will be applied a 2-shard 838 "nccl" all-reduce and other tensors will be applied a 2-shard 839 "pscpu/pscpu" algorithm. The third elements should be in increasing 840 order across tuples and end with -1 which indicates infinity. 841 num_packs: see AllReduceCrossDeviceOps. 842 agg_small_grads_max_bytes: see AllReduceCrossDeviceOps. 843 agg_small_grads_max_group: see AllReduceCrossDeviceOps. 844 """ 845 self._worker_devices = worker_devices 846 self._num_gpus_per_worker = num_gpus_per_worker 847 super(MultiWorkerAllReduce, self).__init__( 848 num_packs=num_packs, 849 agg_small_grads_max_bytes=agg_small_grads_max_bytes, 850 agg_small_grads_max_group=agg_small_grads_max_group) 851 852 def validate_and_complete_spec(spec): 853 """Validate and complete the all-reduce spec.""" 854 # TODO(yuefengz): support namedtuple. 855 if not isinstance(spec, tuple): 856 raise ValueError( 857 "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) 858 if not spec or len(spec) > 3: 859 raise ValueError( 860 "Too many elements in the all-reduce spec tuple: %r" % spec) 861 if len(spec) == 1: 862 return AllReduceSpecTuple(spec[0], 1, -1) 863 elif len(spec) == 2: 864 return AllReduceSpecTuple(spec[0], spec[1], -1) 865 else: 866 return AllReduceSpecTuple(*spec) 867 868 self._all_reduce_spec = [] 869 if isinstance(all_reduce_spec, six.string_types): 870 self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) 871 elif isinstance(all_reduce_spec, tuple): 872 self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) 873 elif isinstance(all_reduce_spec, list): 874 self._all_reduce_spec = [ 875 validate_and_complete_spec(spec) for spec in all_reduce_spec 876 ] 877 878 def _batch_all_reduce(self, reduce_op, per_replica_values): 879 """All-reduce algorithm in a batch.""" 880 logging.log_first_n( 881 logging.INFO, 882 "distributed batch_all_reduce invoked for batches size = %d with " 883 "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " 884 "and agg_small_grads_max_group = %d" % 885 (len(per_replica_values), self._all_reduce_spec, self._num_packs, 886 self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) 887 888 device_grads = _group_value_by_device(per_replica_values) 889 890 # The all-reduce library requires fully defined shapes. 891 # TODO(yuefengz): when tensor sharding is not needed, static shapes are not 892 # required as well. 893 for device_grad in device_grads: 894 for grad, _ in device_grad: 895 if not grad.shape.is_fully_defined(): 896 raise ValueError("Shape is unknown for node %r" % grad) 897 898 remaining_grads = device_grads 899 aggregated_grads = [] 900 for spec_tuple in self._all_reduce_spec: 901 if spec_tuple.limit < 0: 902 this_grads = remaining_grads 903 remaining_grads = [] 904 else: 905 (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size( 906 spec_tuple.limit, remaining_grads) 907 if this_grads: 908 device_grad_packs, tensor_packer = _pack_tensors( 909 this_grads, self._num_packs, self._agg_small_grads_max_bytes, 910 self._agg_small_grads_max_group) 911 range_agg_grads = cross_device_utils.sum_gradients_all_reduce( 912 self._worker_devices, device_grad_packs, len(self._worker_devices), 913 spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) 914 range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) 915 916 if not aggregated_grads: 917 aggregated_grads = range_agg_grads 918 else: 919 assert len(aggregated_grads) == len(range_agg_grads) 920 for i in range(len(aggregated_grads)): 921 aggregated_grads[i] += range_agg_grads[i] 922 assert not remaining_grads 923 924 return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0], 925 reduce_op) 926 927 928@tf_export("distribute.experimental.CollectiveCommunication") 929class CollectiveCommunication(enum.Enum): 930 """Communication choices for CollectiveOps. 931 932 * `AUTO`: Default to runtime's automatic choices. 933 * `RING`: TensorFlow's ring algorithms for all-reduce and 934 all-gather. 935 * `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for 936 all-gather. TODO(ayushd): add ncclAllGather implementation. 937 """ 938 AUTO = "AUTO" 939 RING = "RING" 940 NCCL = "NCCL" 941 942 943# TODO(yuefengz): support in-graph collective all-reduce. 944class CollectiveAllReduce(CrossDeviceOps): 945 """All-reduce cross device ops using collective ops. 946 947 In the between-graph replicated training, it will still do all-reduces across 948 all workers and then put results on the right destinations. 949 """ 950 951 def __init__(self, 952 num_workers=1, 953 num_gpus_per_worker=0, 954 all_reduce_merge_scope=32, 955 collective_keys=None): 956 """Initializes the object. 957 958 Args: 959 num_workers: number of workers in the between-graph replicated training. 960 num_gpus_per_worker: number of GPUs per worker. 961 all_reduce_merge_scope: size of groups into which to partition consecutive 962 gradients grouped under a common 'allreduce' name scope. This is useful 963 for some optimization of collective ops. 964 collective_keys: an optional CollectiveKey object. 965 """ 966 self._num_workers = num_workers 967 self._num_gpus_per_worker = num_gpus_per_worker 968 self._all_reduce_merge_scope = all_reduce_merge_scope 969 self._collective_keys = (collective_keys or 970 cross_device_utils.CollectiveKeys()) 971 super(CollectiveAllReduce, self).__init__() 972 973 # TODO(yuefengz, tucker): is indexed slices supported by collective ops? 974 def reduce_implementation(self, reduce_op, per_replica_value, destinations): 975 if cross_device_utils.contains_indexed_slices(per_replica_value): 976 raise ValueError( 977 "`IndexSlices` is not supported for Collective All-Reduce.") 978 979 all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0] 980 device_map, logical_device = get_device_map_from(destinations) 981 if (all_reduced.device_map is device_map and 982 all_reduced.logical_device == logical_device): 983 return all_reduced 984 devices = device_map.logical_to_actual_devices(logical_device) 985 index = [] 986 for d in devices: 987 if d in all_reduced.devices: 988 index.append(all_reduced.get(d)) 989 else: 990 # TODO(josh11b): Once we add support for model parallelism, get the 991 # copy from the corresponding replica instead of the primary. 992 with ops.control_dependencies(all_reduced.values), ops.device(d): 993 index.append(array_ops.identity(all_reduced.primary)) 994 995 return value_lib.Mirrored(device_map, index, logical_device) 996 997 def batch_reduce_implementation(self, reduce_op, value_destination_pairs): 998 if cross_device_utils.contains_indexed_slices(value_destination_pairs): 999 raise ValueError( 1000 "`IndexSlices` is not supported for Collective All-Reduce.") 1001 1002 all_devices_match = _all_devices_match(value_destination_pairs) 1003 if all_devices_match: 1004 return self._batch_all_reduce(reduce_op, 1005 [v[0] for v in value_destination_pairs]) 1006 else: 1007 if not all_devices_match: 1008 logging.log_first_n( 1009 logging.WARN, "Efficient batch_reduce is not supported if " 1010 "destinations are different.", 10) 1011 1012 return [ 1013 self.reduce_implementation(reduce_op, t, destinations=v) 1014 for t, v in value_destination_pairs 1015 ] 1016 1017 def _batch_all_reduce(self, reduce_op, per_replica_values): 1018 """All-reduce across all workers in a batch.""" 1019 1020 logging.log_first_n( 1021 logging.INFO, "Collective All-reduce invoked with batches size = %d, " 1022 "num_workers = %d" % (len(per_replica_values), self._num_workers), 10) 1023 1024 grouped_by_device = _group_value_by_device(per_replica_values) 1025 1026 grouped_by_var = list(zip(*grouped_by_device)) 1027 # grouped_by_var is grouped by variables and takes the following format: 1028 # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), 1029 # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), 1030 # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), 1031 # ... 1032 # ] 1033 chunked_gv = [ 1034 grouped_by_var[x:x + self._all_reduce_merge_scope] 1035 for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) 1036 ] 1037 1038 reduced_gv_list = [] 1039 for chunk in chunked_gv: 1040 with ops.name_scope("allreduce"): 1041 for grad_and_vars in chunk: 1042 scaled_grads = [g for g, _ in grad_and_vars] 1043 collective_reduced = cross_device_utils.build_collective_reduce( 1044 scaled_grads, self._num_workers, self._collective_keys, "Add", 1045 "Id") 1046 result = [] 1047 for (_, v), g in zip(grad_and_vars, collective_reduced): 1048 result.append([g, v]) 1049 reduced_gv_list.append(result) 1050 1051 new_device_grads = [list(x) for x in zip(*reduced_gv_list)] 1052 return _ungroup_and_make_mirrored( 1053 new_device_grads, 1054 per_replica_values[0], 1055 reduce_op, 1056 num_between_graph_workers=self._num_workers) 1057 1058 1059_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], 1060 [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] 1061 1062 1063def _has_dgx1_like_links(gpu_links): 1064 if not gpu_links: 1065 return False 1066 # TODO(yuefengz): figure out the right topology for hierarchical copy if 1067 # number of gpus are less than 8. 1068 if len(gpu_links) < 8: 1069 return False 1070 for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)): 1071 if (set(gpu_link) != set(dgx1_link) and 1072 set(gpu_link) != set(dgx1_link + [i])): 1073 return False 1074 return True 1075 1076 1077def _choose_all_reduce_algorithm(device_links): 1078 if _has_dgx1_like_links(device_links): 1079 return HierarchicalCopyAllReduce(num_packs=len(device_links)) 1080 else: 1081 return NcclAllReduce(num_packs=1) 1082 1083 1084def choose_the_best(devices, session_config=None): 1085 """Find the best subclass of CrossDeviceOps given a session config. 1086 1087 Args: 1088 devices: a list of devices passed to `tf.distribute.Strategy`. 1089 session_config: a `tf.ConfigProto` or `None`. If `None`, it will make 1090 decision based on all local devices. 1091 1092 Returns: 1093 A subclass of `CrossDeviceOps`. 1094 """ 1095 requested_devices = set([device_util.canonicalize(d) for d in devices]) 1096 machine_devices = device_lib.list_local_devices(session_config=session_config) 1097 using_devices = [] 1098 for d in machine_devices: 1099 if device_util.canonicalize(d.name) in requested_devices: 1100 using_devices.append(d) 1101 else: 1102 logging.info( 1103 "Device is available but not used by distribute strategy: %s", d.name) 1104 1105 if len(using_devices) != len(requested_devices): 1106 logging.warning("Not all devices in `tf.distribute.Strategy` are visible " 1107 "to TensorFlow.") 1108 return ReductionToOneDevice() 1109 1110 if any(d.device_type.lower() != "gpu" for d in using_devices): 1111 logging.warning("Not all devices in `tf.distribute.Strategy` are visible " 1112 "to TensorFlow.") 1113 return ReductionToOneDevice() 1114 1115 device_links = [[] for _ in range(len(using_devices))] 1116 for i, device in enumerate(using_devices): 1117 for link in device.locality.links.link: 1118 device_links[i].append(link.device_id) 1119 1120 return _choose_all_reduce_algorithm(device_links) 1121