# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
       six.string_types, tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be a `DistributedValues` object, "
                     "a `tf.Variable` object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations cannot be empty")


def reduce_non_distributed_value(
    reduce_op, value, destinations, num_replicas_in_graph):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas: this is called as part of assign functions for
    # MirroredVariables, and summing up identical values across replicas is
    # not clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(value, destinations)


def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected an object that is not a tuple or list."
                     % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  elif hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))
  else:
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have a device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor into a PerReplica object in the input list."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  """Validates that `value_destination_pairs` is well formed."""
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations):
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  elif isinstance(destinations, six.string_types):
    return (device_util.resolve(destinations),)
  return (device_util.resolve(destinations.device),)


def _devices_match(left, right):
  return left is right or set(get_devices_from(left)) == set(
      get_devices_from(right))


def _all_devices_match(value_destination_pairs):
  if not all(_devices_match(v, d) for v, d in value_destination_pairs):
    return False
  if not all(_devices_match(v, value_destination_pairs[0][0])
             for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced


def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenates all values in the DistributedValues input and returns."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      gathered = array_ops.concat(all_values, axis)
  return gathered


@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    pass

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default; the value may be overridden by subclasses.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if options is None:
      options = collective_util.Options()
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, ops.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    if options is None:
      options = collective_util.Options()
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut when all PerReplica objects only contain one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor`-like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcast to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_reduce method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor`-like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcast to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(tensor, destinations, always_mirrored=True)

  # ========================== Collective APIs ================================
  #
  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
  # the cross-replica context, collective APIs are to be called in the replica
  # context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors. The implementation should
    generally batch the all-reduces when possible. `options` can be set to
    hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined. Allows using string representation of the enum such as
        "SUM", "MEAN".
      value: Value to be reduced. A tensor or a nested structure of tensors.
      replica_id: An integer indicating the id of the replica where this
        all_reduce is called under. This is the local replica id that ranges
        from 0 to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor or a nested structure of tensors with the reduced values. The
      structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")


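# Example (illustrative sketch, not part of this module's API): a concrete
# `CrossDeviceOps` implementation can be used directly in the cross-replica
# context, assuming `import tensorflow as tf` and a `per_replica_value`
# produced by a `tf.distribute` strategy. The reduce op and device string
# below are assumptions chosen for illustration only.
#
#   cross_ops = tf.distribute.ReductionToOneDevice()
#   reduced = cross_ops.reduce(
#       tf.distribute.ReduceOp.SUM,
#       per_replica_value,          # a tf.distribute.DistributedValues
#       destinations="/gpu:0")

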
@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation. If None,
        `tf.math.add_n` is used.
    """
    self.reduce_to_device = reduce_to_device
    self.accumulation_fn = accumulation_fn or math_ops.add_n
    super(ReductionToOneDevice, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations)
    else:
      devices = get_devices_from(per_replica_value)
    reduce_to_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
    reduced = _simple_reduce(per_replica_value, reduce_to_device,
                             self.accumulation_fn, reduce_op)
    return self.broadcast(reduced, destinations)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations)
    else:
      devices = get_devices_from(per_replica_value)
    reduce_to_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10)
    gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
    return self.broadcast(gathered, destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    return [
        self.reduce_implementation(
            reduce_op, t, destinations=v, options=options)
        for t, v in value_destination_pairs
    ]


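# Example (illustrative sketch): the intermediate device and the accumulation
# function of `ReductionToOneDevice` can be customized; the device string
# below is an assumption for illustration.
#
#   cross_ops = tf.distribute.ReductionToOneDevice(reduce_to_device="/cpu:0")
#   strategy = tf.distribute.MirroredStrategy(cross_device_ops=cross_ops)

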
def _group_value_by_device(per_replica_values):
  """Groups values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a list of PerReplica objects.

  Returns:
    a list of lists, where each sublist contains, for the corresponding device,
    the components of the PerReplica objects, each paired with a None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # pylint: disable=protected-access
    for i, v in enumerate(per_replica_value.values):
      assert per_replica_value._devices == destinations
      grouped[i].append((v, None))
  return grouped


def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroups results from all-reduce and makes Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


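# Example (illustrative sketch of the grouping convention above): with two
# PerReplica values t0 and t1, each with components on gpu0 and gpu1,
# `_group_value_by_device` produces
#
#   [[(t0_gpu0, None), (t1_gpu0, None)],
#    [(t0_gpu1, None), (t1_gpu1, None)]]
#
# and, after the per-device all-reduce, `_ungroup_and_make_mirrored` turns the
# results back into [Mirrored(t0), Mirrored(t1)], dividing by the number of
# replicas when reduce_op is MEAN.

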
class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes lose static shapes. So if
        # all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        # We should remove the need for variables in
        # aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads


def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced


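# Example (illustrative numbers): packing with `_pack_tensors(device_grads,
# num_packs=3)` on a concatenated gradient of 10 elements gives
# split_size = 10 // 3 = 3 and split_size_last = 10 - 3 * 2 = 4, so the packs
# have sizes [3, 3, 4]; `_unpack_tensors` later reverses the concat/split to
# restore the original shapes.

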
class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use; currently only "nccl"
        and "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or hierarchical-copy all-reduce, source and destination
    # devices should match, and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
                    "supported. Falling back to gather on one device and "
                    "then broadcast. We're working on a more efficient "
                    "implementation.")
    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
                                          options)


# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps


AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")


@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
              num_packs))
    super(NcclAllReduce, self).__init__(
        all_reduce_alg="nccl", num_packs=num_packs)


@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This is a reduction created for Nvidia DGX-1, and it assumes GPUs are
  connected like those on a DGX-1 machine. If you have different GPU
  inter-connections, it is likely to be slower than
  `tf.distribute.ReductionToOneDevice`.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `HierarchicalCopyAllReduce` in
  `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "HierarchicalCopy requires num_packs >= 0, but {} is specified"
          .format(num_packs))
    super(HierarchicalCopyAllReduce, self).__init__(
        all_reduce_alg="hierarchical_copy",
        num_packs=num_packs)


# TODO(crccw): remove after migrating all callers.
CollectiveCommunication = collective_util.CommunicationImplementation
CommunicationImplementation = collective_util.CommunicationImplementation


# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self, devices, group_size, collective_keys=None):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKey object.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collective_*.
    #
    # In a multi-threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0  device1
    #   user-thread-1  device0  device1
    # In eager mode, we use one thread per device to launch collective ops, so
    # the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    self._pool = multiprocessing.pool.ThreadPool(len(self._devices))

    super(CollectiveAllReduce, self).__init__()

  @property
  def _num_between_graph_workers(self):
    # Currently we only support an equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce."""
    # TODO(b/122840926): reuse this method in _batch_all_reduce.
    flat_values = nest.flatten(value)

    if isinstance(flat_values[0], ops.IndexedSlices):
      raise NotImplementedError("all_reduce doesn't support IndexedSlices.")

    batch_size = len(flat_values)

    implementation = options.implementation.value
    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradients.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    # Reverse the list so that there's a better chance that values follow
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for the first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    flat_values.reverse()
    packs = cross_device_utils.group_by_size(flat_values,
                                             options.bytes_per_pack)

    launcher = self._launchers[replica_id]
    if not context.executing_eagerly() and replica_id == 0:
      logging.info(
          "Collective all_reduce: %d all-reduces, num_devices = %d, "
          "group_size = %d, implementation = %s, num_packs = %d", batch_size,
          len(self._launchers), self._group_size, implementation, len(packs))
    flat_results = launcher.batch_all_reduce(packs, implementation,
                                             options.timeout_seconds)

    if reduce_op == reduce_util.ReduceOp.MEAN:
      for i, v in enumerate(flat_results):
        flat_results[i] = v / self._group_size
    flat_results.reverse()

    return nest.pack_sequence_as(value, flat_results)

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                         options)[0]
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access the component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs],
                                    options)
    else:
      logging.log_first_n(
          logging.WARN, "Efficient batch_reduce is not supported if "
          "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values, options):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
                                                      options)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values, options)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))

  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options):
    """All-reduce across all workers in a batch."""

    batch_size = len(per_replica_values)
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1 since we don't have a way
    # to order NCCL launches. We're hoping that there's only one batched
    # all-reduce, which is the gradients.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL if and only if we can order collectives deterministically.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    # Reverse the list so that there's a better chance that values follow
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for the first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    values_by_device = [[] for _ in range(len(self._devices))]
    for per_replica in reversed(per_replica_values):
      for i in range(len(self._devices)):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():

      def thread_fn(device_id):
        with context.eager_mode():
          packs = cross_device_utils.group_by_size(values_by_device[device_id],
                                                   options.bytes_per_pack)
          return self._launchers[device_id].batch_all_reduce(
              packs, implementation, options.timeout_seconds)

      num_devices = len(self._devices)
      with self._lock:
        outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(len(self._devices)):
          packs = cross_device_utils.group_by_size(
              values_by_device[i], options.bytes_per_pack)
          if i == 0:
            logging.info(
                "Collective batch_all_reduce: %d all-reduces, num_devices = %d,"
                " group_size = %d, implementation = %s, num_packs = %d",
                batch_size, len(self._launchers), self._group_size,
                implementation, len(packs))
          outputs_by_device.append(self._launchers[i].batch_all_reduce(
              packs, implementation, options.timeout_seconds))

    mirrored = []
    for values in zip(*outputs_by_device):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        values = list(values)
        for i, v in enumerate(values):
          with ops.device(v.device):
            values[i] = v / self._group_size
      mirrored.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    # Reverse the order of the reduced values to recover the input order.
    return list(reversed(mirrored))

  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, group_size = %d" %
        (len(per_replica_values), self._group_size), 10)

    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(per_replica_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    gathered_values = []
    with self._lock:
      for per_replica in per_replica_values:
        outputs = []
        for i in range(len(self._devices)):
          outputs.append(self._launchers[i].all_reduce_indexed_slices(
              per_replica.values[i], implementation, options.timeout_seconds))
        gathered_values.append(outputs)

    mirrored = []
    for value in gathered_values:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # Assume each worker has the same number of replicas.
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i].values = value[i].values / self._group_size
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access the component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gathers multiple per-replica values."""
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s" %
        (batch_size, len(self._devices), self._group_size, implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, collective_keys)


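# Example (illustrative sketch): `CollectiveAllReduce` is constructed with the
# worker-local devices and the global group size. With 2 workers and 2 GPUs
# per worker, each worker would build it roughly as below; the device strings
# are assumptions for illustration only.
#
#   collective_ops = CollectiveAllReduce(
#       devices=["/gpu:0", "/gpu:1"], group_size=4)

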
def select_cross_device_ops(devices, session_config=None):
  """Finds the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    A subclass of `CrossDeviceOps`.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    if len(logical_gpus) != len(physical_gpus):
      logging.warning("NCCL is not supported when using virtual GPUs, falling "
                      "back to reduction to one device.")
      return ReductionToOneDevice()

    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)
  using_devices = set()
  for d in machine_devices:
    if device_util.canonicalize(d.name) in requested_devices:
      using_devices.add(d.name)

  if len(using_devices) != len(requested_devices):
    logging.warning(
        "Some requested devices in `tf.distribute.Strategy` are not visible "
        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))

  if any("gpu" not in d.lower() for d in requested_devices):
    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
                    "not using nccl allreduce.")
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
    return ReductionToOneDevice()
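

# Example (illustrative sketch): `select_cross_device_ops` is typically called
# with the device list handed to a strategy; the device strings below are
# assumptions. It returns `NcclAllReduce(num_packs=1)` when all requested
# devices are GPUs and NCCL kernels are registered, and
# `ReductionToOneDevice()` otherwise.
#
#   cross_ops = select_cross_device_ops(["/gpu:0", "/gpu:1"])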