# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading
import time
import weakref

from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

  This strategy implements synchronous distributed training across multiple
  workers, each with potentially multiple GPUs. Similar to
  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
  to each local device. The difference is that it uses a distributed collective
  implementation (e.g. all-reduce), so that multiple workers can work together.

  You need to launch your program on each worker and configure
  `cluster_resolver` correctly.
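  A `tf.distribute.cluster_resolver.ClusterResolver` provides the cluster spec
  as well as the `task_type` and `task_id` of the current process, which is how
  each worker knows its role in the cluster.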
  For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable. An example TF_CONFIG on worker-0 of a two-worker cluster
  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may
  use collectives internally, e.g. checkpointing and saving, since reading a
  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
  Therefore it's recommended to run exactly the same program on each worker.
  Dispatching based on `task_type` or `task_id` of the worker is error-prone.

  `cluster_resolver.num_accelerators()` determines the number of GPUs the
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  iterator = iter(dist_dataset)
  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

  You need to save and checkpoint on all workers instead of just one. This is
  because variables with `synchronization=ON_READ` trigger aggregation during
  saving. It's recommended to save to a different path on each worker to avoid
  race conditions. Each worker saves the same thing. See the
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
  tutorial for examples.

  __Known Issues__

  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
  correct number of accelerators. The strategy uses all available GPUs if
  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
  or `None`.
  * In eager mode, the strategy needs to be created before calling any other
  TensorFlow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross-device communication. It can be overridden
        by options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
    if communication_options is None:
      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))

    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)

  @classmethod
  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
    provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
    the user provides one in `__init__`, that instance is returned; if the user
    does not, a default `TFConfigClusterResolver` is provided.
    """
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make
    # isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    #            tf.distribute.experimental.MultiWorkerMirroredStrategy)
    # return True. Some libraries perform such a check.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.
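
    This class is deprecated; prefer `tf.distribute.MultiWorkerMirroredStrategy`,
    which takes a `tf.distribute.experimental.CommunicationOptions` object
    instead of a bare `communication` value. A minimal sketch of an equivalent
    construction (assuming the RING implementation is desired):

    ```
    options = tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.RING)
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=options)
    ```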

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplementation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object."""
    communication_options = collective_util.Options(
        implementation=communication)
    super(CollectiveAllReduceStrategyV1, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpu_per_worker").set(self.extended._num_gpus_per_worker)


class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  """Implementation of CollectiveAllReduceStrategy."""

  # Whether to periodically check the health of the cluster. If any worker is
  # not reachable, collectives are aborted and the user program should get a
  # tf.errors.UnavailableError. The program must be restarted to recover.
  _enable_check_health = True
  # Check health interval in seconds.
  _check_health_interval = 30
  # Timeout in seconds for the first health check. The first check needs to
  # wait for the cluster to be ready, which may take longer.
  _check_health_initial_timeout = 0
  # Number of times to retry before considering the peer down.
  _check_health_retry_limit = 3
  # Timeout in seconds for each health check.
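  # A peer that does not respond within this window counts as one failed
  # attempt; after `_check_health_retry_limit` consecutive failures the peer
  # is considered down and collectives are aborted.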
  _check_health_timeout = 10

  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
    assert isinstance(self._cross_device_ops,
                      cross_device_ops_lib.CollectiveAllReduce)

  def _use_merge_call(self):
    """XLA is not supported for multi-worker strategy."""
    return True

  def _initialize_strategy(self, cluster_resolver):
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)

  def _initialize_local(self, cluster_resolver, devices=None):
    """Initializes the object for local training."""
    self._is_chief = True
    self._num_workers = 1

    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup. "
                        "Some performance features may not be enabled.")
      self._collective_ops_configured = True

    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if devices:
      local_devices = devices
    else:
      if num_gpus:
        local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
      else:
        local_devices = ("/device:CPU:0",)

    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    self._cluster_spec = None
    self._task_type = None
    self._task_id = None
    self._id_in_cluster = 0

    # This is a mark to tell whether we are running with standalone client or
    # independent worker. Right now with standalone client, the strategy object
    # is created as a local strategy and then turned into a multi-worker
    # strategy via the `configure` call.
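    # `_initialize_multi_worker` checks this attribute and, when it is set,
    # skips configuring collective ops and starting the std server.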
    self._local_or_standalone_client_mode = True

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    logging.info(
        "Single-worker MultiWorkerMirroredStrategy with local_devices "
        "= %r, communication = %s", local_devices,
        self._communication_options.implementation)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Starting a std server in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      # If coordination service is enabled, use its internal heartbeat to detect
      # peer failures instead of the Python-level health check.
      if config_proto.experimental.coordination_service:
        self._enable_check_health = False

      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs.
    # We should remove this assumption by querying all tasks for their numbers
    # of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = 0
      devices = context.context().devices()
      for d in devices:
        device_spec = pydev.DeviceSpec.from_string(d)
        if (device_spec.job == task_type and device_spec.task == task_id and
            device_spec.device_type == "GPU"):
          num_gpus += 1
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    if self._enable_check_health and context.executing_eagerly():
      self._start_check_health_thread()
    else:
      logging.info("Check health not enabled.")

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)

  def __del__(self):
    self._stop_check_health_thread()

  def _input_workers_with_options(self, options=None):
    host_device = device_util.get_host_for_device(self._worker_device)
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(host_device, self.worker_devices)])
    else:
      return input_lib.InputWorkers([(
          host_device,
          [device_util.get_host_for_device(worker) for worker in
           self.worker_devices])])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    if replica_id == 0:  # First replica on each worker.
      assert device is not None
      assert primary_var is None

      def initial_value_fn():  # pylint: disable=g-missing-docstring
        # Only the first device participates in the broadcast of initial values.
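        # The chief worker broadcast-sends its value and every other worker
        # receives it, so all workers start from identical initial values.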
        group_key = self._collective_keys.get_group_key([device])
        group_size = self._num_workers
        collective_instance_key = (
            self._collective_keys.get_instance_key(group_key, device))

        with ops.device(device):
          initial_value = kwargs["initial_value"]
          if callable(initial_value):
            initial_value = initial_value()
          if isinstance(initial_value, base.CheckpointInitialValue):
            initial_value = initial_value.wrapped_value
          assert not callable(initial_value)
          initial_value = ops.convert_to_tensor(
              initial_value, dtype=kwargs.get("dtype", None))

          if self._num_workers > 1:
            if self._is_chief:
              bcast_send = collective_ops.broadcast_send(
                  initial_value, initial_value.shape, initial_value.dtype,
                  group_size, group_key, collective_instance_key)
              with ops.control_dependencies([bcast_send]):
                return array_ops.identity(initial_value)
            else:
              return collective_ops.broadcast_recv(initial_value.shape,
                                                   initial_value.dtype,
                                                   group_size, group_key,
                                                   collective_instance_key)
          return initial_value

      return initial_value_fn
    else:
      return super(CollectiveAllReduceExtended,
                   self)._get_variable_creator_initial_value(
                       replica_id=replica_id,
                       device=device,
                       primary_var=primary_var,
                       **kwargs)

  def _make_input_context(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=self._id_in_cluster,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_context

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn=dataset_fn,
        input_workers=self._input_workers_with_options(options),
        input_contexts=[input_context],
        strategy=self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    num_local_replicas = len(self.worker_devices)
    for local_replica_id in range(num_local_replicas):
      replica_id = (self._id_in_cluster * num_local_replicas +
                    local_replica_id)
      value_context = distribute_lib.ValueContext(
          replica_id, self._num_replicas_in_sync)
      per_replica_values.append(value_fn(value_context))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _make_dataset_iterator(self, dataset):
    """Distributes the dataset to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the input function to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the object.

    Args:
      session_config: a `tf.compat.v1.ConfigProto`
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type, such as "worker".
      task_id: the current task id.

    Raises:
      ValueError: if `task_type` is not in the `cluster_spec`.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker},
          rpc_layer=self._rpc_layer)
      self._initialize_multi_worker(cluster_resolver)
      assert isinstance(self._cross_device_ops,
                        cross_device_ops_lib.CollectiveAllReduce)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    # Enable the scoped allocator optimization for CollectiveOps. This
    # optimization converts many small all-reduces into fewer larger
    # all-reduces.
    rewrite_options = updated_config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
    # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
    # clear and then append.
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")

    if (not ops.executing_eagerly_outside_functions() and
        self._communication_options.implementation ==
        collective_util.CommunicationImplementation.NCCL):
      updated_config.experimental.collective_nccl = True

    if not self._cluster_spec:
      return updated_config

    assert self._task_type
    assert self._task_id is not None

    # Collective group leader is needed for collective ops to coordinate
    # workers.
    updated_config.experimental.collective_group_leader = (
        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
                                            self._task_id))

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
    updated_config.device_filters.append(
        "/job:%s/task:%d" % (self._task_type, self._task_id))

    return updated_config

  def _get_cross_device_ops(self, value):
    # CollectiveAllReduce works on a predefined set of devices. In most cases
    # they should be the compute devices, but certain use cases may reduce host
    # tensors as well (e.g. early stopping).
    # We infer the cross_device_ops to use based on the number of devices,
    # since inputs don't always have device annotations. The compute-device
    # one is preferred since we can potentially leverage NCCL.
    if isinstance(value, values.DistributedValues):
      num_devices = len(value._values)  # pylint: disable=protected-access
    else:
      num_devices = 1
    if num_devices == len(self.worker_devices):
      return self._cross_device_ops
    else:
      return self._host_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)

    if (isinstance(value, values.DistributedValues) and
        len(self.worker_devices) == 1):
      value = value.values[0]

    # When there are multiple workers, we need to reduce across workers using
    # collective ops.
    if (not isinstance(value, values.DistributedValues) and
        self._num_workers == 1):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or value
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, len(self.worker_devices))
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and
      # due to the blocking nature of collective ops, execution will hang if
      # collective ops are launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = ds_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _check_health(self):
    while True:
      if self._check_health_thread_should_stop.is_set():
        return
      for job in self._cluster_spec.jobs:
        for task_id in range(self._cluster_spec.num_tasks(job)):
          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
          attempts = 0
          while True:
            attempts += 1
            try:
              context.context().check_collective_ops_peer_health(
                  peer, timeout_in_ms=self._check_health_timeout * 1000)
              # If check_collective_ops_peer_health doesn't raise an Exception,
              # the peer is healthy.
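              # Stop retrying this peer and move on to the next task.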
              break
            except (errors.UnavailableError, errors.FailedPreconditionError,
                    errors.DeadlineExceededError) as e:
              # TODO(b/151232436): Always raise UnavailableError when a peer
              # fails. Now there could be many kinds of errors:
              # - Unavailable: when the peer is not reachable, e.g. it's down.
              # - FailedPrecondition: when the peer has restarted.
              if attempts < self._check_health_retry_limit:
                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
                                self._check_health_retry_limit)
                continue
              logging.error(
                  "Cluster check alive failed, %s is down, "
                  "aborting collectives: %s", peer, e)
              context.context().abort_collective_ops(
                  errors.UNAVAILABLE,
                  "cluster check alive failed, {} is down".format(peer))
              return
            except Exception as e:  # pylint: disable=broad-except
              logging.error("Unexpected exception in check alive: %s", e)
              context.context().abort_collective_ops(
                  errors.INTERNAL,
                  "unexpected exception in check alive: %s" % e)
              return
      time.sleep(self._check_health_interval)

  def _start_check_health_thread(self):
    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
    # otherwise the check health may fail immediately.

    # Use array_ops.identity to create the dummy tensor so that we have a new
    # Tensor. If we use constant it may be a cached one from a /job:localhost
    # device, which will cause some code that relies on tensor.device to error.
    #
    # TODO(b/151232436): change to an explicit barrier if we have it.
    dummy_value = array_ops.identity([])
    logging.info("Waiting for the cluster, timeout = %s",
                 self._check_health_initial_timeout or "inf")
    try:
      self._host_cross_device_ops.reduce(
          reduce_util.ReduceOp.SUM,
          dummy_value,
          dummy_value,
          options=collective_util.Options(
              timeout_seconds=self._check_health_initial_timeout,
              implementation=collective_util.CommunicationImplementation.RING))
      if context.is_async():
        context.async_wait()
    except errors.DeadlineExceededError:
      raise RuntimeError(
          "Timeout waiting for the cluster, timeout is %d seconds" %
          self._check_health_initial_timeout)
    logging.info("Cluster is ready.")
    self._check_health_thread_should_stop = threading.Event()
    # Start the thread as daemon to avoid it blocking the program from exiting.
    # We try our best to shut down the thread, but __del__ is not guaranteed to
    # be called when the program exits.
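    # The thread loops in `_check_health` until `_stop_check_health_thread`
    # sets `_check_health_thread_should_stop`.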
    self._check_health_thread = threading.Thread(
        target=self._check_health,
        daemon=True)
    self._check_health_thread.start()

  def _stop_check_health_thread(self):
    if getattr(self, "_check_health_thread", None):
      logging.info("stopping check health thread")
      self._check_health_thread_should_stop.set()
      self._check_health_thread.join()
      self._check_health_thread = None
      logging.info("check health thread stopped")

  def _warn_nccl_no_gpu(self):
    if ((self._communication_options.implementation ==
         collective_util.CommunicationImplementation.NCCL) and
        self._num_gpus_per_worker == 0):
      logging.warning("Enabled NCCL communication but no GPUs detected/"
                      "specified.")

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._num_workers > 1

  @property
  def experimental_between_graph(self):
    return True

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  @property
  def _num_replicas_in_sync(self):
    return len(self.worker_devices) * self._num_workers

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_replica_id_in_sync_group(self, replica_id):
    return self._id_in_cluster * len(self.worker_devices) + replica_id

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return (replica_id_in_sync_group -
            self._id_in_cluster * len(self.worker_devices))

  def __deepcopy__(self, memo):
    # We check the check health thread instead of whether we are in eager mode
    # to limit backward incompatibility.
    if hasattr(self, "_check_health_thread"):
      raise ValueError(
          "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
          "If you're using Estimator and see this error message, call "
          "tf.compat.v1.disable_eager_execution() at the beginning of your "
          "program")
    # Otherwise, do a regular deepcopy.
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      setattr(result, k, copy.deepcopy(v, memo))
    return result