# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading
import time
import weakref

from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

  This strategy implements synchronous distributed training across multiple
  workers, each with potentially multiple GPUs. Similar to
  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
  to each local device. The difference is that it uses a distributed collective
  implementation (e.g. all-reduce), so that multiple workers can work together.

  You need to launch your program on each worker and configure
  `cluster_resolver` correctly. For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```
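
  The only difference in `TF_CONFIG` on worker-1 of the same cluster is the
  task index:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 1} }'
  ```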

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. Both `tf.distribute` and non-`tf.distribute` APIs may
  use collectives internally, e.g. checkpointing and saving, since reading a
  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
  Therefore it's recommended to run exactly the same program on each worker.
  Dispatching based on the `task_type` or `task_id` of the worker is
  error-prone.

  `cluster_resolver.num_accelerators()` determines the number of GPUs the
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
  model.fit(dist_dataset, epochs=1, steps_per_epoch=10)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  iterator = iter(dist_dataset)
  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

  You need to save and checkpoint on all workers instead of just one. This is
  because variables with `synchronization=ON_READ` trigger aggregation during
  saving. It's recommended to save to a different path on each worker to avoid
  race conditions. Each worker saves the same thing. See the
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
  tutorial for examples.
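
  For example, a minimal sketch (the per-worker directory layout here is
  illustrative, not required by the API):

  ```
  task_id = strategy.cluster_resolver.task_id
  # Every worker saves; each one writes to its own directory.
  model.save('/tmp/my_model/worker_%d' % task_id)
  ```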

  __Known Issues__

  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
  correct number of accelerators. The strategy uses all available GPUs if
  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
  or `None`.
  * In eager mode, the strategy needs to be created before calling any other
  TensorFlow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

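    For example, a minimal sketch that prefers NCCL for all-reduces (NCCL is
    only a preference hint and requires GPUs to take effect):

    ```
    options = tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=options)
    ```
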
    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
    if communication_options is None:
      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))

    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)

  @classmethod
  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
    provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
    the user provides one in `__init__`, that instance is returned; if the user
    does not, a default `TFConfigClusterResolver` is provided.
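
    For example, a minimal sketch (assuming `TF_CONFIG` is set as shown in the
    class docstring):

    ```
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    resolver = strategy.cluster_resolver  # TFConfigClusterResolver by default
    print(resolver.task_type, resolver.task_id)
    ```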
    """
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True, as
    # some libraries perform such a check.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplementation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object."""
    communication_options = collective_util.Options(
        implementation=communication)
    super(CollectiveAllReduceStrategyV1, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpu_per_worker").set(self.extended._num_gpus_per_worker)


class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  """Implementation of CollectiveAllReduceStrategy."""

  # Whether to periodically check the health of the cluster. If any worker is
  # not reachable, collectives are aborted and the user program should get a
  # tf.errors.UnavailableError. A restart is required to recover.
  _enable_check_health = True
  # Health check interval in seconds.
  _check_health_interval = 30
  # Timeout in seconds for the first health check, which needs to wait for the
  # cluster to come up and may therefore take longer.
  _check_health_initial_timeout = 0
  # Number of times to retry before considering the peer down.
  _check_health_retry_limit = 3
  # Timeout in seconds for each health check.
  _check_health_timeout = 10
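  # For example, with the defaults above an unreachable peer is probed up to
  # _check_health_retry_limit times with a _check_health_timeout-second timeout
  # each, so it is treated as down after roughly 3 * 10 = 30 seconds of failed
  # checks and collectives are aborted.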

  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
    assert isinstance(self._cross_device_ops,
                      cross_device_ops_lib.CollectiveAllReduce)

  def _use_merge_call(self):
    """XLA is not supported for multi-worker strategy."""
    return True

  def _initialize_strategy(self, cluster_resolver):
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)

  def _initialize_local(self, cluster_resolver, devices=None):
    """Initializes the object for local training."""
    self._is_chief = True
    self._num_workers = 1

    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup. "
                        "Some performance features may not be enabled.")
      self._collective_ops_configured = True

    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if devices:
      local_devices = devices
    else:
      if num_gpus:
        local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
      else:
        local_devices = ("/device:CPU:0",)

    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per-host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    self._cluster_spec = None
    self._task_type = None
    self._task_id = None
    self._id_in_cluster = 0

    # This is a mark to tell whether we are running with a standalone client or
    # an independent worker. Right now with a standalone client, the strategy
    # object is created as a local strategy and then turned into a multi-worker
    # strategy via the `configure` call.
    self._local_or_standalone_client_mode = True

    # Save the num_gpus_per_worker and rpc_layer for the configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    logging.info(
        "Single-worker MultiWorkerMirroredStrategy with local_devices "
        "= %r, communication = %s", local_devices,
        self._communication_options.implementation)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Start a std server when in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      # If coordination service is enabled, use its internal heartbeat to detect
      # peer failures instead of the Python-level health check.
      if config_proto.experimental.coordination_service:
        self._enable_check_health = False

      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = 0
      devices = context.context().devices()
      for d in devices:
        device_spec = pydev.DeviceSpec.from_string(d)
        if (device_spec.job == task_type and device_spec.task == task_id and
            device_spec.device_type == "GPU"):
          num_gpus += 1
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per-host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for the configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    if self._enable_check_health and context.executing_eagerly():
      self._start_check_health_thread()
    else:
      logging.info("Check health not enabled.")

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)

  def __del__(self):
    self._stop_check_health_thread()

  def _input_workers_with_options(self, options=None):
    host_device = device_util.get_host_for_device(self._worker_device)
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(host_device, self.worker_devices)])
    else:
      return input_lib.InputWorkers([(
          host_device,
          [device_util.get_host_for_device(worker) for worker in
           self.worker_devices])])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    if replica_id == 0:  # First replica on each worker.
      assert device is not None
      assert primary_var is None

      def initial_value_fn():  # pylint: disable=g-missing-docstring
        # Only the first device participates in the broadcast of initial values.
        group_key = self._collective_keys.get_group_key([device])
        group_size = self._num_workers
        collective_instance_key = (
            self._collective_keys.get_instance_key(group_key, device))

        with ops.device(device):
          initial_value = kwargs["initial_value"]
          if callable(initial_value):
            initial_value = initial_value()
          if isinstance(initial_value, base.CheckpointInitialValue):
            initial_value = initial_value.wrapped_value
          assert not callable(initial_value)
          initial_value = ops.convert_to_tensor(
              initial_value, dtype=kwargs.get("dtype", None))

          if self._num_workers > 1:
            if self._is_chief:
              bcast_send = collective_ops.broadcast_send(
                  initial_value, initial_value.shape, initial_value.dtype,
                  group_size, group_key, collective_instance_key)
              with ops.control_dependencies([bcast_send]):
                return array_ops.identity(initial_value)
            else:
              return collective_ops.broadcast_recv(initial_value.shape,
                                                   initial_value.dtype,
                                                   group_size, group_key,
                                                   collective_instance_key)
          return initial_value

      return initial_value_fn
    else:
      return super(CollectiveAllReduceExtended,
                   self)._get_variable_creator_initial_value(
                       replica_id=replica_id,
                       device=device,
                       primary_var=primary_var,
                       **kwargs)

  def _make_input_context(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=self._id_in_cluster,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_context

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
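    # For example, a sketch of user-facing usage (not executed here): the
    # `dataset_fn` passed to `Strategy.distribute_datasets_from_function`
    # receives the `tf.distribute.InputContext` created below and can use it
    # to shard and batch per worker; `GLOBAL_BATCH_SIZE` and the dataset
    # source are placeholders:
    #
    #   def dataset_fn(input_context):
    #     batch_size = input_context.get_per_replica_batch_size(
    #         GLOBAL_BATCH_SIZE)
    #     dataset = tf.data.Dataset.range(100)  # placeholder source
    #     dataset = dataset.shard(input_context.num_input_pipelines,
    #                             input_context.input_pipeline_id)
    #     return dataset.batch(batch_size)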
    input_context = self._make_input_context()
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn=dataset_fn,
        input_workers=self._input_workers_with_options(options),
        input_contexts=[input_context],
        strategy=self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    num_local_replicas = len(self.worker_devices)
    for local_replica_id in range(num_local_replicas):
      replica_id = (self._id_in_cluster * num_local_replicas +
                    local_replica_id)
      value_context = distribute_lib.ValueContext(
          replica_id, self._num_replicas_in_sync)
      per_replica_values.append(value_fn(value_context))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _make_dataset_iterator(self, dataset):
    """Distributes the dataset to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the input function to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the object.

    Args:
      session_config: a `tf.compat.v1.ConfigProto`
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type, such as "worker".
      task_id: the current task id.

    Raises:
      ValueError: if `task_type` is not in the `cluster_spec`.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker},
          rpc_layer=self._rpc_layer)
      self._initialize_multi_worker(cluster_resolver)
      assert isinstance(self._cross_device_ops,
                        cross_device_ops_lib.CollectiveAllReduce)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    # Enable the scoped allocator optimization for CollectiveOps.  This
    # optimization converts many small all-reduces into fewer larger
    # all-reduces.
    rewrite_options = updated_config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
    # ["CollectiveReduce"].  Since we can't assign to a repeated proto field, we
    # clear and then append.
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")

    if (not ops.executing_eagerly_outside_functions() and
        self._communication_options.implementation ==
        collective_util.CommunicationImplementation.NCCL):
      updated_config.experimental.collective_nccl = True

    if not self._cluster_spec:
      return updated_config

    assert self._task_type
    assert self._task_id is not None

    # Collective group leader is needed for collective ops to coordinate
    # workers.
    updated_config.experimental.collective_group_leader = (
        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
                                            self._task_id))

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
    updated_config.device_filters.append(
        "/job:%s/task:%d" % (self._task_type, self._task_id))

    return updated_config

  def _get_cross_device_ops(self, value):
    # CollectiveAllReduce works on a predefined set of devices. In most cases
    # they should be the compute devices, but certain use cases may reduce host
    # tensors as well (e.g. early stopping). We infer the cross_device_ops to
    # use based on the number of devices, since inputs don't always have device
    # annotations. The compute-devices one is preferred since we can potentially
    # leverage NCCL.
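    # For example, with two local GPUs a PerReplica value containing two
    # components is reduced with the GPU-based ops (possibly NCCL), while a
    # single host tensor falls through to `_host_cross_device_ops`.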
    if isinstance(value, values.DistributedValues):
      num_devices = len(value._values)  # pylint: disable=protected-access
    else:
      num_devices = 1
    if num_devices == len(self.worker_devices):
      return self._cross_device_ops
    else:
      return self._host_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)

    if (isinstance(value, values.DistributedValues) and
        len(self.worker_devices) == 1):
      value = value.values[0]

    # When there are multiple workers, we need to reduce across workers using
    # collective ops.
    if (not isinstance(value, values.DistributedValues) and
        self._num_workers == 1):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or `value`
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, len(self.worker_devices))
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and due
      # to the blocking nature of collective ops, execution will hang if
      # collective ops are launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = ds_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _check_health(self):
    while True:
      if self._check_health_thread_should_stop.is_set():
        return
      for job in self._cluster_spec.jobs:
        for task_id in range(self._cluster_spec.num_tasks(job)):
          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
          attempts = 0
          while True:
            attempts += 1
            try:
              context.context().check_collective_ops_peer_health(
                  peer, timeout_in_ms=self._check_health_timeout * 1000)
              # If check_collective_ops_peer_health doesn't raise an Exception,
              # the peer is healthy.
              break
            except (errors.UnavailableError, errors.FailedPreconditionError,
                    errors.DeadlineExceededError) as e:
              # TODO(b/151232436): Always raise UnavailableError when a peer
              # fails. Now there could be many kinds of errors:
              # - Unavailable: when the peer is not reachable, e.g. it's down.
              # - FailedPrecondition: when the peer has restarted.
              if attempts < self._check_health_retry_limit:
                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
                                self._check_health_retry_limit)
                continue
              logging.error(
                  "Cluster check alive failed, %s is down, "
                  "aborting collectives: %s", peer, e)
              context.context().abort_collective_ops(
                  errors.UNAVAILABLE,
                  "cluster check alive failed, {} is down".format(peer))
              return
            except Exception as e:  # pylint: disable=broad-except
              logging.error("Unexpected exception in check alive: %s", e)
              context.context().abort_collective_ops(
                  errors.INTERNAL,
                  "unexpected exception in check alive: %s" % e)
              return
      time.sleep(self._check_health_interval)

  def _start_check_health_thread(self):
    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
    # otherwise the check health may fail immediately.

    # Use array_ops.identity to create the dummy tensor so that we have a new
    # Tensor. If we use a constant, it may be a cached one on a /job:localhost
    # device, which will cause some code that relies on tensor.device to error.
    #
    # TODO(b/151232436): change to an explicit barrier if we have it.
    dummy_value = array_ops.identity([])
    logging.info("Waiting for the cluster, timeout = %s",
                 self._check_health_initial_timeout or "inf")
    try:
      self._host_cross_device_ops.reduce(
          reduce_util.ReduceOp.SUM,
          dummy_value,
          dummy_value,
          options=collective_util.Options(
              timeout_seconds=self._check_health_initial_timeout,
              implementation=collective_util.CommunicationImplementation.RING))
      if context.is_async():
        context.async_wait()
    except errors.DeadlineExceededError:
      raise RuntimeError(
          "Timeout waiting for the cluster, timeout is %d seconds" %
          self._check_health_initial_timeout)
    logging.info("Cluster is ready.")
    self._check_health_thread_should_stop = threading.Event()
    # Start the thread as a daemon to avoid it blocking the program from
    # exiting. We try our best to shut down the thread, but __del__ is not
    # guaranteed to be called when the program exits.
    self._check_health_thread = threading.Thread(
        target=self._check_health,
        daemon=True)
    self._check_health_thread.start()

  def _stop_check_health_thread(self):
    if getattr(self, "_check_health_thread", None):
      logging.info("stopping check health thread")
      self._check_health_thread_should_stop.set()
      self._check_health_thread.join()
      self._check_health_thread = None
      logging.info("check health thread stopped")

  def _warn_nccl_no_gpu(self):
    if ((self._communication_options.implementation ==
         collective_util.CommunicationImplementation.NCCL) and
        self._num_gpus_per_worker == 0):
      logging.warning("Enabled NCCL communication but no GPUs detected/"
                      "specified.")

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._num_workers > 1

  @property
  def experimental_between_graph(self):
    return True

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  @property
  def _num_replicas_in_sync(self):
    return len(self.worker_devices) * self._num_workers

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_replica_id_in_sync_group(self, replica_id):
    return self._id_in_cluster * len(self.worker_devices) + replica_id

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return (replica_id_in_sync_group -
            self._id_in_cluster * len(self.worker_devices))
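  # For example, with two workers (and no chief) each running two local
  # replicas, local replica 0 on worker 1 has global replica id
  # 1 * 2 + 0 = 2, and _get_local_replica_id(2) maps it back to 0 on that
  # worker.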

  def __deepcopy__(self, memo):
    # We check the check health thread instead of whether we are in eager mode
    # to limit the backward incompatibility.
    if hasattr(self, "_check_health_thread"):
      raise ValueError(
          "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
          "If you're using Estimator and see this error message, call "
          "tf.compat.v1.disable_eager_execution() at the beginning of your "
          "program")
    # Otherwise, do a regular deepcopy.
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      setattr(result, k, copy.deepcopy(v, memo))
    return result
970