# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, ops.IndexedSlices,
       ps_values.AggregatingVariable, six.string_types,
       tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be a `DistributedValues` object, "
                     "a tf.Variable object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations cannot be empty")


def reduce_non_distributed_value(reduce_op,
                                 value,
                                 destinations,
                                 num_replicas_in_graph,
                                 canonicalize_devices=True):
  """Reduces a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas. We call this as part of assign functions for
    # MirroredVariables, and summing up identical values across replicas is not
    # clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(
        value, destinations, canonicalize_devices=canonicalize_devices)


def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected an object that is not a tuple or "
                     "list." % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  elif hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))
  else:
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have a device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor in the input list into a PerReplica object."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  """Validates that value_destination_pairs are valid."""
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations, canonicalize_devices=True):
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  if canonicalize_devices:
    if isinstance(destinations, six.string_types):
      return (device_util.resolve(destinations),)
    return (device_util.resolve(destinations.device),)

  # Let placer canonicalize and resolve destination devices.
  if isinstance(destinations, six.string_types):
    return (device_util.canonicalize_without_job_and_task(destinations),)
  return (device_util.canonicalize_without_job_and_task(destinations.device),)


def _devices_match(left, right, canonicalize_devices=True):
  return left is right or set(get_devices_from(
      left, canonicalize_devices)) == set(
          get_devices_from(right, canonicalize_devices))


def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
  if not all(
      _devices_match(v, d, canonicalize_devices)
      for v, d in value_destination_pairs):
    return False
  if not all(
      _devices_match(v, value_destination_pairs[0][0], canonicalize_devices)
      for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value,
                     destinations,
                     always_mirrored=False,
                     canonicalize_devices=True):
  """Broadcasts `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations, canonicalize_devices)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced


def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenates all values in the DistributedValues input and returns."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      gathered = array_ops.concat(all_values, axis)
  return gathered
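

# An illustrative sketch (added for exposition; not used by the library) of
# what `_simple_reduce` computes: every component of the `PerReplica` value is
# copied to `reduce_to_device` and accumulated there. The helper name
# `_simple_reduce_example` is ours, and it assumes eager execution with a
# visible CPU device.
def _simple_reduce_example():
  # Two replica components holding [1., 1.] and [2., 2.].
  per_replica = value_lib.PerReplica(
      (array_ops.ones([2]), 2. * array_ops.ones([2])))
  # SUM-reduces the components onto the CPU, yielding [3., 3.] there.
  return _simple_reduce(per_replica, "/device:CPU:0", math_ops.add_n,
                        reduce_util.ReduceOp.SUM)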


@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    self._canonicalize_devices = True

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default; the value may be overridden by subclasses.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if options is None:
      options = collective_util.Options()
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, ops.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    if options is None:
      options = collective_util.Options()
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs, self._canonicalize_devices) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)
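
  # NOTE: an illustrative sketch of the usage patterns described in the
  # docstrings above. Passing the same value as both the reduced value and
  # `destinations` makes `reduce` behave as an all-reduce, and several
  # (value, destinations) pairs can be reduced in one `batch_reduce` call so
  # that implementations may batch the cross-device communication.
  # `cross_device_ops`, `per_replica_grads`, `grads_a`/`grads_b` and
  # `var_a`/`var_b` are placeholders for user-created objects:
  #
  #   mirrored_sum = cross_device_ops.reduce(
  #       tf.distribute.ReduceOp.SUM,
  #       per_replica_grads,
  #       destinations=per_replica_grads)
  #
  #   reduced_a, reduced_b = cross_device_ops.batch_reduce(
  #       tf.distribute.ReduceOp.SUM,
  #       [(grads_a, var_a), (grads_b, var_b)])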

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor`-like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a
        `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_reduce method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor`-like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(
        tensor,
        destinations,
        always_mirrored=True,
        canonicalize_devices=self._canonicalize_devices)

  # ========================== Collective APIs ================================
  #
  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
  # the cross-replica context, collective APIs are to be called in the replica
  # context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors or `IndexedSlices`. The
    implementation should generally batch the all-reduces when possible.
    `options` can be set to hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: Value to be reduced. A tensor or a nested structure of tensors or
        `IndexedSlices`.
      replica_id: An integer indicating the id of the replica on which this
        all_reduce is called. This is the local replica id that ranges from 0
        to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
      the reduced values. The structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")


@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation. If None,
        `tf.math.add_n` is used.
606 """ 607 self.reduce_to_device = reduce_to_device 608 self.accumulation_fn = accumulation_fn or math_ops.add_n 609 super(ReductionToOneDevice, self).__init__() 610 611 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 612 options): 613 del options # Unused. 614 if check_destinations(destinations): 615 devices = get_devices_from(destinations, self._canonicalize_devices) 616 else: 617 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 618 reduce_to_device = self.reduce_to_device or devices[0] 619 logging.log_first_n( 620 logging.INFO, 621 "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10) 622 reduced = _simple_reduce(per_replica_value, reduce_to_device, 623 self.accumulation_fn, reduce_op) 624 return self.broadcast(reduced, destinations) 625 626 def _gather_implementation(self, per_replica_value, destinations, axis, 627 options): 628 del options # Unused. 629 if check_destinations(destinations): 630 devices = get_devices_from(destinations, self._canonicalize_devices) 631 else: 632 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 633 reduce_to_device = self.reduce_to_device or devices[0] 634 logging.log_first_n( 635 logging.INFO, 636 "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10) 637 gathered = _simple_gather(per_replica_value, reduce_to_device, axis) 638 return self.broadcast(gathered, destinations) 639 640 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 641 options): 642 return [ 643 self.reduce_implementation( 644 reduce_op, t, destinations=v, options=options) 645 for t, v in value_destination_pairs 646 ] 647 648 649def _group_value_by_device(per_replica_values): 650 """Group values into sublists by their devices. 651 652 This grouping is needed to call the all-reduce library because it expects a 653 list of the following form: 654 [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], 655 [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], 656 [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], 657 ... 658 ] 659 660 Args: 661 per_replica_values: a list of PerReplica objects. 662 663 Returns: 664 a list of lists, each sublist has components for its corresponding device of 665 PerReplica objects, paired with a None. 666 """ 667 destinations = per_replica_values[0]._devices # pylint: disable=protected-access 668 grouped = [[] for _ in range(len(destinations))] 669 for per_replica_value in per_replica_values: 670 # pylint: disable=protected-access 671 for i, v in enumerate(per_replica_value.values): 672 assert per_replica_value._devices == destinations 673 grouped[i].append((v, None)) 674 return grouped 675 676 677def _ungroup_and_make_mirrored(grouped_reduced, 678 destinations, 679 reduce_op, 680 num_between_graph_workers=1): 681 """Ungroup results from all-reduce and make Mirrored objects. 682 683 Each all-reduce result will be divided by the number of destinations before 684 Mirrored objects are created if reduce_op is "mean". 685 686 Args: 687 grouped_reduced: a list of lists, each sublist has components for each 688 device, paired with a None. It is the result from 689 cross_device_utils.aggregate_gradients_using*. 690 destinations: a value to colocate the result with. 691 reduce_op: Indicates how values will be aggregated. Accepted values 692 are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 


def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a list of PerReplica objects.

  Returns:
    a list of lists, each sublist has the components of the PerReplica objects
    for its corresponding device, paired with a None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # pylint: disable=protected-access
    for i, v in enumerate(per_replica_value.values):
      assert per_replica_value._devices == destinations
      grouped[i].append((v, None))
  return grouped


def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes remove static shapes. So
        # if all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        # We should remove the need for variables in
        # aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads


def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced
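

# A worked example of the packing arithmetic above: with num_packs=2 and
# per-device gradients of shapes [2, 2] and [3] (7 elements in total),
# `_pack_tensors` flattens and concatenates them into a single 7-element
# tensor and splits it into packs of sizes 3 and 4 (the last pack absorbs the
# remainder); `_unpack_tensors` concatenates the reduced packs back together
# and restores the original [2, 2] and [3] shapes.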


class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
                    "supported. Falling back to gather on one device and "
                    "then broadcast. We're working on a more efficient "
                    "implementation.")
    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
                                          options)


# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps


AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")


@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
              num_packs))
    super(NcclAllReduce, self).__init__(
        all_reduce_alg="nccl", num_packs=num_packs)


@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This is a reduction created for Nvidia DGX-1, and it assumes GPUs connect
  like those on a DGX-1 machine. If you have different GPU inter-connections,
  it is likely that it would be slower than
  `tf.distribute.ReductionToOneDevice`.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `HierarchicalCopyAllReduce` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "HierarchicalCopy requires num_packs >= 0, but {} is specified"
          .format(num_packs))
    super(HierarchicalCopyAllReduce, self).__init__(
        all_reduce_alg="hierarchical_copy",
        num_packs=num_packs)


# TODO(crccw): remove after migrating all callers.
CollectiveCommunication = collective_util.CommunicationImplementation
CommunicationImplementation = collective_util.CommunicationImplementation
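

# NOTE: `CollectiveAllReduce` below is normally constructed by
# `tf.distribute.MultiWorkerMirroredStrategy` rather than by end users. An
# illustrative sketch (the numbers are placeholders) of tuning its behavior
# through the public API:
#
#   options = tf.distribute.experimental.CommunicationOptions(
#       bytes_per_pack=50 * 1024 * 1024,
#       implementation=(
#           tf.distribute.experimental.CommunicationImplementation.NCCL))
#   strategy = tf.distribute.MultiWorkerMirroredStrategy(
#       communication_options=options)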


# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put the results on the right destinations.
  """

  def __init__(self,
               devices,
               group_size,
               collective_keys=None,
               canonicalize_devices=True):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKeys object.
      canonicalize_devices: Whether to canonicalize devices for workers or not.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collective_*.
    #
    # In a multi-threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0  device1
    #   user-thread-1  device0  device1
    # In eager mode, we use one thread per device to launch collective ops, so
    # the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    if canonicalize_devices:
      self._devices = tuple(device_util.canonicalize(d) for d in devices)
    else:
      self._devices = tuple(
          device_util.canonicalize_without_job_and_task(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    self._pool = multiprocessing.pool.ThreadPool(len(self._devices))

    super(CollectiveAllReduce, self).__init__()
    self._canonicalize_devices = canonicalize_devices

  @property
  def _num_between_graph_workers(self):
    # Currently we only support equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce."""
    # TODO(b/122840926): reuse this method in _batch_all_reduce.
    flat_values = nest.flatten(value)

    implementation = options.implementation.value
    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
    # NCCL launches are always ordered.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(flat_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    launcher = self._launchers[replica_id]
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(flat_values))
    dense_results = []
    sparse_results = []

    if dense_values:
      # Reverse the lists so that there's a better chance that values follow
      # the order in which they are calculated (e.g. when they're gradients),
      # so as to overlap calculation with communication. However, this may not
      # be optimal for cases like gradients of complicated non-sequential
      # models.
      #
      # Note that we reverse the list before packing so that the first pack
      # won't be too small, since it's more likely for first few packs to have
      # long queuing time due to concurrent intense computation.
      #
      # TODO(b/147393503): explore solutions for optimal ordering.
      dense_values.reverse()
      packs = cross_device_utils.group_by_size(dense_values,
                                               options.bytes_per_pack)

      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
            "group_size = %d, implementation = %s, num_packs = %d",
            len(dense_values), len(self._launchers), self._group_size,
            implementation, len(packs))

      dense_results = launcher.batch_all_reduce(packs, implementation,
                                                options.timeout_seconds)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(dense_results):
          with ops.device(self._devices[replica_id]):
            dense_results[i] = v / self._group_size
      dense_results.reverse()

    if sparse_values:
      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce IndexedSlices: %d all_reduces, "
            "num_devices = %d, group_size = %d, implementation = %s",
            len(sparse_values), len(self._launchers), self._group_size,
            implementation)

      for indexed_slice in sparse_values:
        sparse_results.append(
            launcher.all_reduce_indexed_slices(indexed_slice, implementation,
                                               options.timeout_seconds))

      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(sparse_results):
          with ops.device(self._devices[replica_id]):
            sparse_results[i] = ops.IndexedSlices(
                values=sparse_results[i].values / self._group_size,
                indices=sparse_results[i].indices,
                dense_shape=sparse_results[i].dense_shape)

    flat_results = cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
    return nest.pack_sequence_as(value, flat_results)

  def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
                                     options):
    """All reduce a list of per_replica_value."""
    values_by_device = [[] for _ in self._devices]
    num_devices = len(self._devices)
    for per_replica in per_replica_values:
      for i in range(num_devices):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():

      def thread_fn(device_id):
        with context.eager_mode():
          return self._all_reduce(reduce_op, values_by_device[device_id],
                                  device_id, options)

      with self._lock:
        outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(num_devices):
          outputs_by_device.append(
              self._all_reduce(reduce_op, values_by_device[i], i, options))

    result = []
    for values in zip(*outputs_by_device):
      result.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    return result

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._all_reduce_per_replica_values(reduce_op,
                                                      [per_replica_value],
                                                      options)[0]
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs,
                                           self._canonicalize_devices)
    if all_devices_match:
      return self._all_reduce_per_replica_values(
          reduce_op, [v[0] for v in value_destination_pairs], options)
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]
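
  # NOTE: an illustrative sketch of the gather semantics implemented below:
  # gathering per-replica tensors of shapes [2, 3] and [4, 3] with axis=0
  # produces a single [6, 3] tensor, which is then mirrored to every
  # destination device.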

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gathers multiple per-replica values."""
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s, " %
        (batch_size, len(self._devices), self._group_size, implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, collective_keys,
                               self._canonicalize_devices)


def select_cross_device_ops(devices, session_config=None):
  """Finds the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    A subclass of `CrossDeviceOps`.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    if len(logical_gpus) != len(physical_gpus):
      logging.warning("NCCL is not supported when using virtual GPUs, "
                      "falling back to reduction to one device")
      return ReductionToOneDevice()

    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)
  using_devices = set()
  for d in machine_devices:
    if device_util.canonicalize(d.name) in requested_devices:
      using_devices.add(d.name)

  if len(using_devices) != len(requested_devices):
    logging.warning(
        "Some requested devices in `tf.distribute.Strategy` are not visible "
        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))

  if any("gpu" not in d.lower() for d in requested_devices):
    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
                    "not using nccl allreduce.")
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
    return ReductionToOneDevice()