# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" import copy import threading import time import weakref from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import input_util from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import tpu_strategy_util from tensorflow.python.trackable import base from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export # pylint: disable=line-too-long @tf_export("distribute.MultiWorkerMirroredStrategy", v1=[]) class CollectiveAllReduceStrategy(distribute_lib.Strategy): """A distribution strategy for synchronous training on multiple workers. This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to `tf.distribute.MirroredStrategy`, it replicates all variables and computations to each local device. The difference is that it uses a distributed collective implementation (e.g. all-reduce), so that multiple workers can work together. You need to launch your program on each worker and configure `cluster_resolver` correctly. For example, if you are using `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to have its corresponding `task_type` and `task_id` set in the `TF_CONFIG` environment variable. An example TF_CONFIG on worker-0 of a two worker cluster is: ``` TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }' ``` Your program runs on each worker as-is. Note that collectives require each worker to participate. All `tf.distribute` and non `tf.distribute` API may use collectives internally, e.g. checkpointing and saving since reading a `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value. Therefore it's recommended to run exactly the same program on each worker. Dispatching based on `task_type` or `task_id` of the worker is error-prone. `cluster_resolver.num_accelerators()` determines the number of GPUs the strategy uses. If it's zero, the strategy uses the CPU. All workers need to use the same number of devices, otherwise the behavior is undefined. This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy` instead. After setting up TF_CONFIG, using this strategy is similar to using `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`. ``` strategy = tf.distribute.MultiWorkerMirroredStrategy() with strategy.scope(): model = tf.keras.Sequential([ tf.keras.layers.Dense(2, input_shape=(5,)), ]) optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) def dataset_fn(ctx): x = np.random.random((2, 5)).astype(np.float32) y = np.random.randint(2, size=(2, 1)) dataset = tf.data.Dataset.from_tensor_slices((x, y)) return dataset.repeat().batch(1, drop_remainder=True) dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) model.compile() model.fit(dist_dataset) ``` You can also write your own training loop: ``` @tf.function def train_step(iterator): def step_fn(inputs): features, labels = inputs with tf.GradientTape() as tape: logits = model(features, training=True) loss = tf.keras.losses.sparse_categorical_crossentropy( labels, logits) grads = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) strategy.run(step_fn, args=(next(iterator),)) for _ in range(NUM_STEP): train_step(iterator) ``` See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a detailed tutorial. __Saving__ You need to save and checkpoint on all workers instead of just one. This is because variables whose synchronization=ON_READ triggers aggregation during saving. It's recommended to save to a different path on each worker to avoid race conditions. Each worker saves the same thing. See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading) tutorial for examples. __Known Issues__ * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the correct number of accelerators. The strategy uses all available GPUs if `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver` or `None`. * In eager mode, the strategy needs to be created before calling any other Tensorflow API. """ # pylint: enable=line-too-long # TODO(anjalisridhar): Update our guides with examples showing how we can use # the cluster_resolver argument. # The starting number for collective keys. This should only be set in tests. _collective_key_base = 0 def __init__(self, cluster_resolver=None, communication_options=None): """Creates the strategy. Args: cluster_resolver: optional `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. communication_options: optional `tf.distribute.experimental.CommunicationOptions`. This configures the default options for cross device communications. It can be overridden by options provided to the communication APIs like `tf.distribute.ReplicaContext.all_reduce`. See `tf.distribute.experimental.CommunicationOptions` for details. """ if communication_options is None: communication_options = collective_util.Options() super(CollectiveAllReduceStrategy, self).__init__( CollectiveAllReduceExtended( self, cluster_resolver=cluster_resolver, communication_options=communication_options)) distribute_lib.distribution_strategy_gauge.get_cell("V2").set( "MultiWorkerMirroredStrategy") # pylint: disable=protected-access distribute_lib.distribution_strategy_replica_gauge.get_cell( "num_workers").set(self.extended._num_workers) distribute_lib.distribution_strategy_replica_gauge.get_cell( "num_replicas_per_worker").set(self.extended._num_devices_per_worker) @classmethod def _from_local_devices(cls, devices, communication_options=None): """A convenience method to create an object with a list of devices.""" obj = cls(communication_options=communication_options) obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access return obj @property def cluster_resolver(self): """Returns the cluster resolver associated with this strategy. As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy` provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one in `__init__`, that instance is returned; if the user does not, a default `TFConfigClusterResolver` is provided. """ return self.extended._cluster_resolver # pylint: disable=protected-access class _CollectiveAllReduceStrategyExperimentalMeta(type): @classmethod def __instancecheck__(cls, instance): # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(), # tf.distribute.experimental.MultiWorkerMirroredStrategy). Some libraries is # performing such check. return isinstance(instance, CollectiveAllReduceStrategy) @tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[]) class _CollectiveAllReduceStrategyExperimental( CollectiveAllReduceStrategy, metaclass=_CollectiveAllReduceStrategyExperimentalMeta): __doc__ = CollectiveAllReduceStrategy.__doc__ @deprecation.deprecated( None, "use distribute.MultiWorkerMirroredStrategy instead") def __init__(self, communication=collective_util.CommunicationImplementation.AUTO, cluster_resolver=None): """Creates the strategy. Args: communication: optional `tf.distribute.experimental.CommunicationImplementation`. This is a hint on the preferred collective communication implementation. Possible values include `AUTO`, `RING`, and `NCCL`. cluster_resolver: optional `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. """ communication_options = collective_util.Options( implementation=communication) super(_CollectiveAllReduceStrategyExperimental, self).__init__(cluster_resolver, communication_options) @classmethod def _from_local_devices( cls, devices, communication=collective_util.CommunicationImplementation.AUTO): """A convenience method to create an object with a list of devices.""" obj = cls(communication) obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access return obj _CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__ @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1): __doc__ = CollectiveAllReduceStrategy.__doc__ # The starting number for collective keys. This should only be set in tests. _collective_key_base = 0 def __init__(self, communication=collective_util.CommunicationImplementation.AUTO, cluster_resolver=None): """Initializes the object.""" communication_options = collective_util.Options( implementation=communication) super(CollectiveAllReduceStrategyV1, self).__init__( CollectiveAllReduceExtended( self, cluster_resolver=cluster_resolver, communication_options=communication_options)) distribute_lib.distribution_strategy_gauge.get_cell("V1").set( "MultiWorkerMirroredStrategy") # pylint: disable=protected-access distribute_lib.distribution_strategy_replica_gauge.get_cell( "num_workers").set(self.extended._num_workers) distribute_lib.distribution_strategy_replica_gauge.get_cell( "num_gpu_per_worker").set( self.extended._num_devices_per_worker if self.extended._local_device_type == "GPU" else 0) def _is_gpu_device(device): return tf_device.DeviceSpec.from_string(device).device_type == "GPU" class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): """Implementation of CollectiveAllReduceStrategy.""" # Whether to perdically check the health of the cluster. If any worker is not # reachable, collectives are aborted and the user program should get a # tf.errors.UnavailableError. It's required to restart in order to recover. _enable_check_health = True # Check health interval in seconds. _check_health_interval = 30 # Timeout in seconds for the first check health. The first check health needs # to wait for cluster, which may make a longer time. _check_health_initial_timeout = 0 # Times to retry before considering the peer is down. _check_health_retry_limit = 3 # Timeout in seconds the each check health. _check_health_timeout = 10 def __init__(self, container_strategy, cluster_resolver, communication_options, devices=None): if not isinstance(communication_options, collective_util.Options): raise ValueError("communication_options must be an instance of " "tf.distribute.experimental.CommunicationOptions") if cluster_resolver and devices: raise ValueError( "cluster_resolver and devices cannot be set at the same time") self._cluster_resolver = cluster_resolver or TFConfigClusterResolver() if not isinstance(self._cluster_resolver, ClusterResolver): raise ValueError("cluster_resolver must be an instance of " "tf.distribute.cluster_resolver.ClusterResolver") distribute_lib.StrategyExtendedV1.__init__(self, container_strategy) self._communication_options = communication_options self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access self._initialize_strategy(self._cluster_resolver, devices=devices) self._cfer_fn_cache = weakref.WeakKeyDictionary() self.experimental_enable_get_next_as_optional = True assert isinstance(self._cross_device_ops, cross_device_ops_lib.CollectiveAllReduce) def _use_merge_call(self): # We currently only disable merge_call when XLA is used to compile the `fn` # passed to `strategy.run` and all devices are GPU. return not control_flow_util.GraphOrParentsInXlaContext( ops.get_default_graph()) or not all( [_is_gpu_device(d) for d in self._devices]) def _initialize_strategy(self, cluster_resolver, devices): # If devices are provided or cluster_spec is not specified, initialize # single worker. Otherwise initialize multi workers. if devices or not cluster_resolver.cluster_spec().as_dict(): self._initialize_local(cluster_resolver, devices=devices) else: self._initialize_multi_worker(cluster_resolver) def _initialize_local_devices(self, cluster_resolver, worker_device): # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in # some cases. if isinstance(cluster_resolver, TFConfigClusterResolver): num_gpus = context.num_gpus() num_tpus = 0 else: num_gpus = cluster_resolver.num_accelerators().get("GPU", 0) num_tpus = cluster_resolver.num_accelerators().get("TPU", 0) if num_gpus: local_device_type = "GPU" num_local_devices = num_gpus elif num_tpus: local_device_type = "TPU" num_local_devices = num_tpus else: local_device_type = "CPU" num_local_devices = 1 local_devices = tuple( f"{worker_device}/device:{local_device_type}:{i}" for i in range(num_local_devices)) return local_devices, local_device_type def _initialize_local(self, cluster_resolver, devices=None): """Initializes the object for local training.""" self._is_chief = True self._num_workers = 1 if ops.executing_eagerly_outside_functions(): try: context.context().configure_collective_ops( scoped_allocator_enabled_ops=("CollectiveReduce",)) except RuntimeError: logging.warning("Collective ops is not configured at program startup. " "Some performance features may not be enabled.") self._collective_ops_configured = True if devices: local_devices = devices if "GPU" in devices[0]: local_device_type = "GPU" elif "TPU" in devices[0]: local_device_type = "TPU" else: local_device_type = "CPU" else: local_devices, local_device_type = self._initialize_local_devices( cluster_resolver, worker_device="") self._worker_device = device_util.canonicalize("/device:CPU:0") self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) self._collective_keys = cross_device_utils.CollectiveKeys( group_key_start=1 + self._collective_key_base) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices), options=self._communication_options, collective_keys=self._collective_keys) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, options=self._communication_options, collective_keys=self._collective_keys) super(CollectiveAllReduceExtended, self)._initialize_single_worker( local_devices) self._cluster_spec = None self._task_type = None self._task_id = None self._id_in_cluster = 0 # This is a mark to tell whether we are running with standalone client or # independent worker. Right now with standalone client, strategy object is # created as local strategy and then turn into multi-worker strategy via # configure call. self._local_or_standalone_client_mode = True # Save the num_devices_per_worker and rpc_layer for configure method. self._num_devices_per_worker = len(local_devices) self._local_device_type = local_device_type self._rpc_layer = cluster_resolver.rpc_layer self._warn_nccl_no_gpu() logging.info( "Single-worker MultiWorkerMirroredStrategy with local_devices " "= %r, communication = %s", local_devices, self._communication_options.implementation) def _initialize_multi_worker(self, cluster_resolver): """Initializes the object for multi-worker training.""" cluster_spec = multi_worker_util.normalize_cluster_spec( cluster_resolver.cluster_spec()) task_type = cluster_resolver.task_type task_id = cluster_resolver.task_id if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`.") self._cluster_spec = cluster_spec self._task_type = task_type self._task_id = task_id self._id_in_cluster = multi_worker_util.id_in_cluster( self._cluster_spec, self._task_type, self._task_id) self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found " "in `cluster_spec`.") self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) self._worker_device = "/job:%s/task:%d" % (task_type, task_id) self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) if (ops.executing_eagerly_outside_functions() and not getattr(self, "_local_or_standalone_client_mode", False)): context.context().configure_collective_ops( collective_leader=multi_worker_util.collective_leader( cluster_spec, task_type, task_id), scoped_allocator_enabled_ops=("CollectiveReduce",), device_filters=("/job:%s/task:%d" % (task_type, task_id),)) self._collective_ops_configured = True if context.context().coordination_service is None: coordinated_jobs = ["chief", "worker"] if task_type in coordinated_jobs: context.context().configure_coordination_service( service_type="standalone", service_leader=multi_worker_util.coordination_leader( cluster_spec), coordinated_jobs=coordinated_jobs) # Starting a std server in eager mode and in independent worker mode. if (context.executing_eagerly() and not getattr(self, "_std_server_started", False) and not getattr(self, "_local_or_standalone_client_mode", False)): # Checking _local_or_standalone_client_mode as well because we should not # create the std server in standalone client mode. config_proto = copy.deepcopy(context.context().config) config_proto = self._update_config_proto(config_proto) # If coordination service is enabled, use its internal heartbeat to detect # peer failures instead of the Python-level health check. if config_proto.experimental.coordination_config.service_type: self._enable_check_health = False if hasattr(cluster_resolver, "port"): port = cluster_resolver.port else: port = 0 server_def = tensorflow_server_pb2.ServerDef( cluster=cluster_spec.as_cluster_def(), default_session_config=config_proto, job_name=task_type, task_index=task_id, protocol=cluster_resolver.rpc_layer or "grpc", port=port) context.context().enable_collective_ops(server_def) self._std_server_started = True # The `ensure_initialized` is needed before calling # `context.context().devices()`. context.context().ensure_initialized() logging.info( "Enabled multi-worker collective ops with available devices: %r", context.context().devices()) # TODO(yuefengz): The `num_gpus` is only for this particular task. It # assumes all workers have the same number of GPUs. We should remove this # assumption by querying all tasks for their numbers of GPUs. # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in # some cases. local_devices, local_device_type = self._initialize_local_devices( cluster_resolver, self._worker_device) if local_device_type == "TPU": tpu_strategy_util.initialize_tpu_system() self._collective_keys = cross_device_utils.CollectiveKeys( group_key_start=1 + self._collective_key_base) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices) * self._num_workers, options=self._communication_options, collective_keys=self._collective_keys) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, options=self._communication_options, collective_keys=self._collective_keys) super(CollectiveAllReduceExtended, self)._initialize_single_worker( local_devices) # Add a default device so that ops without specified devices will not end up # on other workers. self._default_device = "/job:%s/task:%d" % (task_type, task_id) # Save the num_devices_per_worker and rpc_layer for configure method. self._num_devices_per_worker = len(local_devices) self._local_device_type = local_device_type self._rpc_layer = cluster_resolver.rpc_layer self._warn_nccl_no_gpu() if self._enable_check_health and context.executing_eagerly(): self._start_check_health_thread() else: logging.info("Check health not enabled.") logging.info( "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, " "task_id = %r, num_workers = %r, local_devices = %r, " "communication = %s", cluster_spec.as_dict(), task_type, task_id, self._num_workers, local_devices, self._communication_options.implementation) def __del__(self): self._stop_check_health_thread() def _input_workers_with_options(self, options=None): host_device = device_util.get_host_for_device(self._worker_device) if not options or options.experimental_fetch_to_device: return input_lib.InputWorkers([(host_device, self.worker_devices)]) else: return input_lib.InputWorkers([( host_device, [device_util.get_host_for_device(worker) for worker in self.worker_devices])]) @property def _input_workers(self): return self._input_workers_with_options() def _get_variable_creator_initial_value(self, replica_id, device, primary_var, **kwargs): if replica_id == 0: # First replica on each worker. assert device is not None assert primary_var is None def initial_value_fn(): # pylint: disable=g-missing-docstring # Only the first device participates in the broadcast of initial values. group_key = self._collective_keys.get_group_key([device]) group_size = self._num_workers collective_instance_key = ( self._collective_keys.get_instance_key(group_key, device)) with ops.device(device): initial_value = kwargs["initial_value"] if callable(initial_value): initial_value = initial_value() if isinstance(initial_value, base.CheckpointInitialValue): initial_value = initial_value.wrapped_value assert not callable(initial_value) initial_value = ops.convert_to_tensor( initial_value, dtype=kwargs.get("dtype", None)) if self._num_workers > 1: if self._is_chief: bcast_send = collective_ops.broadcast_send( initial_value, initial_value.shape, initial_value.dtype, group_size, group_key, collective_instance_key) with ops.control_dependencies([bcast_send]): return array_ops.identity(initial_value) else: return collective_ops.broadcast_recv(initial_value.shape, initial_value.dtype, group_size, group_key, collective_instance_key) return initial_value return initial_value_fn else: return super(CollectiveAllReduceExtended, self)._get_variable_creator_initial_value( replica_id=replica_id, device=device, primary_var=primary_var, **kwargs) def _make_input_context(self): input_context = distribute_lib.InputContext( num_input_pipelines=self._num_workers, input_pipeline_id=self._id_in_cluster, num_replicas_in_sync=self._num_replicas_in_sync) return input_context def _experimental_distribute_dataset(self, dataset, options): if (options and options.experimental_replication_mode == distribute_lib.InputReplicationMode.PER_REPLICA): raise NotImplementedError( "InputReplicationMode.PER_REPLICA " "is only supported in " "`distribute_datasets_from_function` " "of tf.distribute.MirroredStrategy" ) input_context = self._make_input_context() return input_util.get_distributed_dataset( dataset, self._input_workers_with_options(options), self._container_strategy(), num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context, options=options) def _distribute_datasets_from_function(self, dataset_fn, options): if (options and options.experimental_replication_mode == distribute_lib.InputReplicationMode.PER_REPLICA): raise NotImplementedError( "InputReplicationMode.PER_REPLICA " "is only supported in " "`distribute_datasets_from_function` " "of tf.distribute.MirroredStrategy") input_context = self._make_input_context() return input_util.get_distributed_datasets_from_function( dataset_fn=dataset_fn, input_workers=self._input_workers_with_options(options), input_contexts=[input_context], strategy=self._container_strategy(), options=options) def _experimental_distribute_values_from_function(self, value_fn): per_replica_values = [] num_local_replicas = len(self.worker_devices) for local_replica_id in range(num_local_replicas): replica_id = (self._id_in_cluster * num_local_replicas + local_replica_id) value_context = distribute_lib.ValueContext( replica_id, self._num_replicas_in_sync) per_replica_values.append(value_fn(value_context)) return distribute_utils.regroup(per_replica_values, always_wrap=True) def _make_dataset_iterator(self, dataset): """Distributes the dataset to each local GPU.""" input_context = self._make_input_context() return input_lib_v1.DatasetIterator( dataset, self._input_workers, self._container_strategy(), num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context) def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): """Distributes the input function to each local GPU.""" input_context = self._make_input_context() return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, [input_context], self._container_strategy()) def _configure(self, session_config=None, cluster_spec=None, task_type=None, task_id=None): """Configures the object. Args: session_config: a `tf.compat.v1.ConfigProto` cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the cluster configurations. task_type: the current task type, such as "worker". task_id: the current task id. Raises: ValueError: if `task_type` is not in the `cluster_spec`. """ if cluster_spec: cluster_resolver = SimpleClusterResolver( cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), task_type=task_type, task_id=task_id, num_accelerators={ self._local_device_type: self._num_devices_per_worker}, rpc_layer=self._rpc_layer) self._initialize_multi_worker(cluster_resolver) assert isinstance(self._cross_device_ops, cross_device_ops_lib.CollectiveAllReduce) if session_config: session_config.CopyFrom(self._update_config_proto(session_config)) def _update_config_proto(self, config_proto): updated_config = copy.deepcopy(config_proto) # Enable the scoped allocator optimization for CollectiveOps. This # optimization converts many small all-reduces into fewer larger # all-reduces. rewrite_options = updated_config.graph_options.rewrite_options rewrite_options.scoped_allocator_optimization = ( rewriter_config_pb2.RewriterConfig.ON) # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we # clear and then append. del rewrite_options.scoped_allocator_opts.enable_op[:] rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") if (not ops.executing_eagerly_outside_functions() and self._communication_options.implementation == collective_util.CommunicationImplementation.NCCL): updated_config.experimental.collective_nccl = True if not self._cluster_spec: return updated_config assert self._task_type assert self._task_id is not None # Collective group leader is needed for collective ops to coordinate # workers. updated_config.experimental.collective_group_leader = ( multi_worker_util.collective_leader(self._cluster_spec, self._task_type, self._task_id)) # The device filters prevent communication between workers. del updated_config.device_filters[:] updated_config.device_filters.append( "/job:%s/task:%d" % (self._task_type, self._task_id)) return updated_config def _get_cross_device_ops(self, value): # CollectiveAllReduce works on a predefined set of devices. In most cases # they should be the compute devices, but certain use cases may reduce host # tensors as well (e.g. early stopping). We infer the cross_device_ops to # use based on the number of devices, since inputs don't always have device # annotations. The compute devices one is preferred since we can potentially # leverage NCCL. if isinstance(value, values.DistributedValues): num_devices = len(value._values) # pylint: disable=protected-access else: num_devices = 1 if num_devices == len(self.worker_devices): return self._cross_device_ops else: return self._host_cross_device_ops def _gather_to_implementation(self, value, destinations, axis, options): return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access value, destinations=destinations, axis=axis, options=options) def _reduce_to(self, reduce_op, value, destinations, options): if (isinstance(value, values.Mirrored) and reduce_op == reduce_util.ReduceOp.MEAN): return value assert not isinstance(value, values.Mirrored) if (isinstance(value, values.DistributedValues) and len(self.worker_devices) == 1): value = value.values[0] # When there are multiple workers, we need to reduce across workers using # collective ops. if (not isinstance(value, values.DistributedValues) and self._num_workers == 1): # This function handles reducing values that are not PerReplica or # Mirrored values. For example, the same value could be present on all # replicas in which case `value` would be a single value or value could # be 0. return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, value, destinations, len(self.worker_devices)) return self._get_cross_device_ops(value).reduce( reduce_op, value, destinations=destinations, options=self._communication_options.merge(options)) def _replica_ctx_all_reduce(self, reduce_op, value, options=None): """Implements `StrategyExtendedV2._replica_ctx_all_reduce`.""" # This implementation avoids using `merge_call` and just launches collective # ops in one replica. if options is None: options = collective_util.Options() if context.executing_eagerly(): # In eager mode, falls back to the default implemenation that uses # `merge_call`. Replica functions are running sequentially in eager mode, # and due to the blocking nature of collective ops, execution will hang if # collective ops are to be launched sequentially. return super()._replica_ctx_all_reduce(reduce_op, value, options) replica_context = ds_context.get_replica_context() assert replica_context, ( "`StrategyExtended._replica_ctx_all_reduce` must be called in a " "replica context") return self._cross_device_ops._all_reduce( # pylint: disable=protected-access reduce_op, value, replica_context._replica_id, # pylint: disable=protected-access options) def _check_health(self): while True: if self._check_health_thread_should_stop.is_set(): return for job in self._cluster_spec.jobs: for task_id in range(self._cluster_spec.num_tasks(job)): peer = "/job:{}/replica:0/task:{}".format(job, task_id) attempts = 0 while True: attempts += 1 try: context.context().check_collective_ops_peer_health( peer, timeout_in_ms=self._check_health_timeout * 1000) # If check_collective_ops_peer_health doesn't raise an Exception, # the peer is healthy. break except (errors.UnavailableError, errors.FailedPreconditionError, errors.DeadlineExceededError) as e: # TODO(b/151232436): Always raise UnavailableError when a peer # fails. Now there could be many kinds of errors: # - Unavailable: when the peer is not reachable, e.g. it's down. # - FailedPrecondition: when the peer has restarted. if attempts < self._check_health_retry_limit: logging.warning("%s seems down, retrying %d/%d", peer, attempts, self._check_health_retry_limit) continue logging.error( "Cluster check alive failed, %s is down, " "aborting collectives: %s", peer, e) context.context().abort_collective_ops( errors.UNAVAILABLE, "cluster check alive failed, {} is down".format(peer)) return except Exception as e: # pylint: disable=broad-except logging.error("Unexpected exception in check alive: %s", e) context.context().abort_collective_ops( errors.INTERNAL, "unexecpted exception in check alive: %s" % e) return time.sleep(self._check_health_interval) def _start_check_health_thread(self): # Use a dummy all-reduce as a barrier to wait for all workers to be up, # otherwise the check health may fail immediately. # Use array_ops.identity to create the dummy tensor so that we have a new # Tensor. If we use constant it may be a cached from on a /job:localhost # device, which will cause some code that relies on tensor.device to error. # # TODO(b/151232436): change to an explicit barrier if we have it. dummy_value = array_ops.identity([]) logging.info("Waiting for the cluster, timeout = %s", self._check_health_initial_timeout or "inf") try: self._host_cross_device_ops.reduce( reduce_util.ReduceOp.SUM, dummy_value, dummy_value, options=collective_util.Options( timeout_seconds=self._check_health_initial_timeout, implementation=collective_util.CommunicationImplementation.RING)) if context.is_async(): context.async_wait() except errors.DeadlineExceededError: raise RuntimeError( "Timeout waiting for the cluster, timeout is %d seconds" % self._check_health_initial_timeout) logging.info("Cluster is ready.") self._check_health_thread_should_stop = threading.Event() # Start the thread as daemon to avoid it blocking the program from exiting. # We try best to shutdown the thread but __del__ is not guaranteed to be # called when program exists. self._check_health_thread = threading.Thread( target=self._check_health, daemon=True) self._check_health_thread.start() def _stop_check_health_thread(self): if getattr(self, "_check_health_thread", None): logging.info("stopping check health thread") self._check_health_thread_should_stop.set() self._check_health_thread.join() self._check_health_thread = None logging.info("check health thread stopped") def _warn_nccl_no_gpu(self): if ((self._communication_options.implementation == collective_util.CommunicationImplementation.NCCL) and self._local_device_type != "GPU"): logging.warning("Enabled NCCL communication but no GPUs detected/" "specified.") def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" return self._num_workers > 1 @property def experimental_between_graph(self): return True @property def experimental_should_init(self): return True @property def should_checkpoint(self): return self._is_chief @property def should_save_summary(self): return self._is_chief @property def _num_replicas_in_sync(self): return len(self.worker_devices) * self._num_workers # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. `make_input_fn_iterator` assumes per-replica batching. Returns: Boolean. """ return True def _get_replica_id_in_sync_group(self, replica_id): return self._id_in_cluster * len(self.worker_devices) + replica_id def _get_local_replica_id(self, replica_id_in_sync_group): return (replica_id_in_sync_group - self._id_in_cluster * len(self.worker_devices)) def __deepcopy__(self, memo): # We check the check health thread instead of whether we are in eager mode # to limit the backward incompatibility. if hasattr(self, "_check_health_thread"): raise ValueError( "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. " "If you're using Estimator and see this error message, call " "tf.compat.v1.disable_eager_execution() at the beginning of your " "program") # Otherwise, do a regular deepcopy. cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): setattr(result, k, copy.deepcopy(v, memo)) return result