# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading
import time
import weakref

from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

  This strategy implements synchronous distributed training across multiple
  workers, each with potentially multiple GPUs. Similar to
  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
  to each local device. The difference is that it uses a distributed collective
  implementation (e.g. all-reduce), so that multiple workers can work together.

  You need to launch your program on each worker and configure
  `cluster_resolver` correctly. For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable.
  An example `TF_CONFIG` on worker-0 of a two-worker cluster is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may
  use collectives internally, e.g. checkpointing and saving, since reading a
  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
  Therefore it's recommended to run exactly the same program on each worker.
  Dispatching based on the `task_type` or `task_id` of the worker is error-prone.

  `cluster_resolver.num_accelerators()` determines the number of GPUs the
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

  You need to save and checkpoint on all workers instead of just one. This is
  because variables with synchronization=ON_READ trigger aggregation during
  saving. It's recommended to save to a different path on each worker to avoid
  race conditions. Each worker saves the same thing. See the
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
  tutorial for examples.

  __Known Issues__

  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
  correct number of accelerators. The strategy uses all available GPUs if
  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
  or `None`.
  * In eager mode, the strategy needs to be created before calling any other
  TensorFlow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.
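
  # A minimal sketch of passing an explicit cluster resolver instead of relying
  # on TF_CONFIG. The cluster spec, task_type and task_id values below are
  # illustrative assumptions, not part of the original source:
  #
  #   cluster_spec = tf.train.ClusterSpec(
  #       {"worker": ["localhost:12345", "localhost:23456"]})
  #   resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
  #       cluster_spec, task_type="worker", task_id=0)
  #   strategy = tf.distribute.MultiWorkerMirroredStrategy(
  #       cluster_resolver=resolver)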

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross-device communication. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
    if communication_options is None:
      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))

    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)

  @classmethod
  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    As a multi-worker strategy,
    `tf.distribute.experimental.MultiWorkerMirroredStrategy` provides the
    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
    provides one in `__init__`, that instance is returned; if the user does
    not, a default `TFConfigClusterResolver` is provided.
    """
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True.
    # Some libraries perform such a check.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplementation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object."""
    communication_options = collective_util.Options(
        implementation=communication)
    super(CollectiveAllReduceStrategyV1, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpu_per_worker").set(self.extended._num_gpus_per_worker)


class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  """Implementation of CollectiveAllReduceStrategy."""

  # Whether to periodically check the health of the cluster. If any worker is
  # not reachable, collectives are aborted and the user program should get a
  # tf.errors.UnavailableError. It's required to restart in order to recover.
  _enable_check_health = True
  # Check health interval in seconds.
  _check_health_interval = 30
  # Timeout in seconds for the first health check. The first health check needs
  # to wait for the cluster, which may take a longer time.
  _check_health_initial_timeout = 0
  # Number of times to retry before considering the peer down.
  _check_health_retry_limit = 3
  # Timeout in seconds for each health check.
  _check_health_timeout = 10

  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
    assert isinstance(self._cross_device_ops,
                      cross_device_ops_lib.CollectiveAllReduce)

  def _initialize_strategy(self, cluster_resolver):
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)

  def _initialize_local(self, cluster_resolver, devices=None):
    """Initializes the object for local training."""
    self._is_chief = True
    self._num_workers = 1

    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup. "
                        "Some performance features may not be enabled.")
      self._collective_ops_configured = True

    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if devices:
      local_devices = devices
    else:
      if num_gpus:
        local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
      else:
        local_devices = ("/device:CPU:0",)

    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    self._cluster_spec = None
    self._task_type = None
    self._task_id = None
    self._id_in_cluster = 0

    # This is a mark to tell whether we are running with a standalone client or
    # an independent worker. Right now with a standalone client, the strategy
    # object is created as a local strategy and then turned into a multi-worker
    # strategy via the configure call.
    self._local_or_standalone_client_mode = True

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    logging.info(
        "Single-worker MultiWorkerMirroredStrategy with local_devices "
        "= %r, communication = %s", local_devices,
        self._communication_options.implementation)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Starting a std server in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    if self._enable_check_health:
      self._start_check_health_thread()

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)

  def __del__(self):
    self._stop_check_health_thread()

  def _input_workers_with_options(self, options=None):
    host_device = device_util.get_host_for_device(self._worker_device)
    if not options or options.experimental_prefetch_to_device:
      return input_lib.InputWorkers([(host_device, self.worker_devices)])
    else:
      return input_lib.InputWorkers([(
          host_device,
          [device_util.get_host_for_device(worker) for worker in
           self.worker_devices])])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    if replica_id == 0:  # First replica on each worker.
      assert device is not None
      assert primary_var is None

      def initial_value_fn():  # pylint: disable=g-missing-docstring
        # Only the first device participates in the broadcast of initial values.
        group_key = self._collective_keys.get_group_key([device])
        group_size = self._num_workers
        collective_instance_key = (
            self._collective_keys.get_instance_key(group_key, device))

        with ops.device(device):
          initial_value = kwargs["initial_value"]
          if callable(initial_value):
            initial_value = initial_value()
          if isinstance(initial_value, base.CheckpointInitialValue):
            initial_value = initial_value.wrapped_value
          assert not callable(initial_value)
          initial_value = ops.convert_to_tensor(
              initial_value, dtype=kwargs.get("dtype", None))

          if self._num_workers > 1:
            if self._is_chief:
              bcast_send = collective_ops.broadcast_send(
                  initial_value, initial_value.shape, initial_value.dtype,
                  group_size, group_key, collective_instance_key)
              with ops.control_dependencies([bcast_send]):
                return array_ops.identity(initial_value)
            else:
              return collective_ops.broadcast_recv(initial_value.shape,
                                                   initial_value.dtype,
                                                   group_size, group_key,
                                                   collective_instance_key)
          return initial_value

      return initial_value_fn
    else:
      return super(CollectiveAllReduceExtended,
                   self)._get_variable_creator_initial_value(
                       replica_id=replica_id,
                       device=device,
                       primary_var=primary_var,
                       **kwargs)

  def _make_input_context(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=self._id_in_cluster,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_context

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    input_context = self._make_input_context()
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn=dataset_fn,
        input_workers=self._input_workers_with_options(options),
        input_contexts=[input_context],
        strategy=self._container_strategy())

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    num_local_replicas = len(self.worker_devices)
    for local_replica_id in range(num_local_replicas):
      replica_id = (self._id_in_cluster * num_local_replicas +
                    local_replica_id)
      value_context = distribute_lib.ValueContext(
          replica_id, self._num_replicas_in_sync)
      per_replica_values.append(value_fn(value_context))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _make_dataset_iterator(self, dataset):
    """Distributes the dataset to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the input function to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the object.

    Args:
      session_config: a `tf.compat.v1.ConfigProto`
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type, such as "worker".
      task_id: the current task id.

    Raises:
      ValueError: if `task_type` is not in the `cluster_spec`.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker},
          rpc_layer=self._rpc_layer)
      self._initialize_multi_worker(cluster_resolver)
      assert isinstance(self._cross_device_ops,
                        cross_device_ops_lib.CollectiveAllReduce)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    # Enable the scoped allocator optimization for CollectiveOps. This
    # optimization converts many small all-reduces into fewer larger
    # all-reduces.
    rewrite_options = updated_config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
    # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
    # clear and then append.
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")

    if (not ops.executing_eagerly_outside_functions() and
        self._communication_options.implementation ==
        collective_util.CommunicationImplementation.NCCL):
      updated_config.experimental.collective_nccl = True

    if not self._cluster_spec:
      return updated_config

    assert self._task_type
    assert self._task_id is not None

    # Collective group leader is needed for collective ops to coordinate
    # workers.
    updated_config.experimental.collective_group_leader = (
        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
                                            self._task_id))

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
    updated_config.device_filters.append(
        "/job:%s/task:%d" % (self._task_type, self._task_id))

    return updated_config

  def _get_cross_device_ops(self, value):
    # CollectiveAllReduce works on a predefined set of devices. In most cases
    # they should be the compute devices, but certain use cases may reduce host
    # tensors as well (e.g. early stopping). We infer the cross_device_ops to
    # use based on the number of devices, since inputs don't always have device
    # annotations. The one for compute devices is preferred since we can
    # potentially leverage NCCL.
    if isinstance(value, values.DistributedValues):
      num_devices = len(value._values)  # pylint: disable=protected-access
    else:
      num_devices = 1
    if num_devices == len(self.worker_devices):
      return self._cross_device_ops
    else:
      return self._host_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)

    if (isinstance(value, values.DistributedValues) and
        len(self.worker_devices) == 1):
      value = value.values[0]

    # When there are multiple workers, we need to reduce across workers using
    # collective ops.
    if (not isinstance(value, values.DistributedValues) and
        self._num_workers == 1):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or `value`
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, len(self.worker_devices))
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and
      # due to the blocking nature of collective ops, execution will hang if
      # collective ops are launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = ds_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _check_health(self):
    while True:
      if self._check_health_thread_should_stop.is_set():
        return
      for job in self._cluster_spec.jobs:
        for task_id in range(self._cluster_spec.num_tasks(job)):
          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
          attempts = 0
          while True:
            attempts += 1
            try:
              context.context().check_collective_ops_peer_health(
                  peer, timeout_in_ms=self._check_health_timeout * 1000)
              # If check_collective_ops_peer_health doesn't raise an Exception,
              # the peer is healthy.
              break
            except (errors.UnavailableError, errors.FailedPreconditionError,
                    errors.DeadlineExceededError) as e:
              # TODO(b/151232436): Always raise UnavailableError when a peer
              # fails. Now there could be many kinds of errors:
              # - Unavailable: when the peer is not reachable, e.g. it's down.
              # - FailedPrecondition: when the peer has restarted.
              if attempts < self._check_health_retry_limit:
                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
                                self._check_health_retry_limit)
                continue
              logging.error(
                  "Cluster check alive failed, %s is down, "
                  "aborting collectives: %s", peer, e)
              context.context().abort_collective_ops(
                  errors.UNAVAILABLE,
                  "cluster check alive failed, {} is down".format(peer))
              return
            except Exception as e:  # pylint: disable=broad-except
              logging.error("Unexpected exception in check alive: %s", e)
              context.context().abort_collective_ops(
                  errors.INTERNAL,
                  "unexpected exception in check alive: %s" % e)
              return
      time.sleep(self._check_health_interval)

  def _start_check_health_thread(self):
    if not context.executing_eagerly():
      logging.info("Check health is only supported in eager.")
      return
    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
    # otherwise the check health may fail immediately.

    # Use array_ops.identity to create the dummy tensor so that we have a new
    # Tensor. If we use a constant, it may be a cached one from a
    # /job:localhost device, which will cause some code that relies on
    # tensor.device to error.
    #
    # TODO(b/151232436): change to an explicit barrier if we have it.
    dummy_value = array_ops.identity([])
    logging.info("Waiting for the cluster, timeout = %s",
                 self._check_health_initial_timeout or "inf")
    try:
      self._host_cross_device_ops.reduce(
          reduce_util.ReduceOp.SUM,
          dummy_value,
          dummy_value,
          options=collective_util.Options(
              timeout_seconds=self._check_health_initial_timeout,
              implementation=collective_util.CommunicationImplementation.RING))
      if context.is_async():
        context.async_wait()
    except errors.DeadlineExceededError:
      raise RuntimeError(
          "Timeout waiting for the cluster, timeout is %d seconds" %
          self._check_health_initial_timeout)
    logging.info("Cluster is ready.")
    self._check_health_thread_should_stop = threading.Event()
    # Start the thread as a daemon to avoid it blocking the program from
    # exiting. We try our best to shut down the thread, but __del__ is not
    # guaranteed to be called when the program exits.
    self._check_health_thread = threading.Thread(
        target=self._check_health,
        daemon=True)
    self._check_health_thread.start()

  def _stop_check_health_thread(self):
    if getattr(self, "_check_health_thread", None):
      logging.info("stopping check health thread")
      self._check_health_thread_should_stop.set()
      self._check_health_thread.join()
      self._check_health_thread = None
      logging.info("check health thread stopped")

  def _warn_nccl_no_gpu(self):
    if ((self._communication_options.implementation ==
         collective_util.CommunicationImplementation.NCCL) and
        self._num_gpus_per_worker == 0):
      logging.warning("Enabled NCCL communication but no GPUs detected/"
                      "specified.")

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._num_workers > 1

  @property
  def experimental_between_graph(self):
    return True

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  @property
  def _num_replicas_in_sync(self):
    return len(self.worker_devices) * self._num_workers

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_replica_id_in_sync_group(self, replica_id):
    return self._id_in_cluster * len(self.worker_devices) + replica_id

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return (replica_id_in_sync_group -
            self._id_in_cluster * len(self.worker_devices))
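

# Worked example of the replica id mapping in `_get_replica_id_in_sync_group`
# and `_get_local_replica_id` above (the cluster sizes here are hypothetical):
# with 2 workers and 2 local devices per worker, the worker whose
# `_id_in_cluster` is 1 maps local replica 0 to global replica 1 * 2 + 0 == 2,
# and `_get_local_replica_id(2)` recovers the local id 0.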