# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python import tf2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
57  """Checks whether the devices list is for single or multi-worker.
58
59  Args:
60    devices: a list of device strings or tf.config.LogicalDevice objects, for
61      either local or for remote devices.
62
63  Returns:
64    a boolean indicating whether these device strings are for local or for
65    remote.
66
67  Raises:
68    ValueError: if device strings are not consistent.
69  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local device string cannot have job specified other "
                     "than 'localhost'")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have task specified.")

  return num_workers == 1


def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices


def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict mapping from task_type to a list of device-string lists, one per
    task, in ascending order of task_id.
  """
  assert not _is_device_list_single_worker(devices)
  device_dict = {}

  for d in devices:
    d_spec = tf_device.DeviceSpec.from_string(d)

    # Create an entry for the task_type.
    if d_spec.job not in device_dict:
      device_dict[d_spec.job] = []

    # Fill the device list for task_type until it covers the task_id.
    while len(device_dict[d_spec.job]) <= d_spec.task:
      device_dict[d_spec.job].append([])

    device_dict[d_spec.job][d_spec.task].append(d)

  return device_dict


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def _infer_num_gpus_per_worker(devices):
  """Infers the number of GPUs on each worker.

  Currently, to make multi-worker cross device ops work, we need all workers
  to have the same number of GPUs.

  Args:
    devices: a list of device strings; can be either local devices or remote
      devices.

  Returns:
    number of GPUs per worker.

  Raises:
    ValueError: if workers have different numbers of GPUs, or if the GPU
    indices on a worker are not consecutive starting from 0.
  """
  if _is_device_list_single_worker(devices):
    return sum(1 for d in devices if _is_gpu_device(d))
  else:
    device_dict = _group_device_list(devices)
    num_gpus = None
    for _, devices_in_task in device_dict.items():
      for device_in_task in devices_in_task:
        if num_gpus is None:
          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))

        # Verify other workers have the same number of GPUs.
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus


def all_local_devices(num_gpus=None):
  """Returns all local GPUs (up to `num_gpus`), falling back to local CPUs."""
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  """Returns cluster devices from TF_CONFIG, else all local devices."""
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one
  machine with multiple GPUs. For TPUs, use
  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument
  of the strategy, then it will use all the available GPUs. If no GPUs are
  found, it will use the available CPUs. Note that TensorFlow treats all CPUs
  on a machine as a single device, and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  While using distribution strategies, all variable creation should be done
  within the strategy's scope. This replicates the variables across all the
  replicas and keeps them in sync using an all-reduce algorithm.

  Variables created inside a `tf.function`-wrapped function under a
  `MirroredStrategy` scope are still `MirroredVariables`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset
  across the replicas when writing your own training loop. If you are using
  the `.fit` and `.compile` methods available in `tf.keras`, then `tf.keras`
  will handle the distribution for you.

  For example:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`.  If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that exploits
      the particular hardware is available.
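
  For example, to pick a different cross-device communication implementation
  (a sketch; `tf.distribute.HierarchicalCopyAllReduce` is one of the provided
  `CrossDeviceOps` implementations):

  ```python
  strategy = tf.distribute.MirroredStrategy(
      devices=["/gpu:0", "/gpu:1"],
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```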
275  """
276
277  # Only set this in tests.
278  _collective_key_base = 0
279
280  def __init__(self, devices=None, cross_device_ops=None):
281    extended = MirroredExtended(
282        self, devices=devices, cross_device_ops=cross_device_ops)
283    super(MirroredStrategy, self).__init__(extended)
284    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
285        "MirroredStrategy")
286
287
288@tf_export(v1=["distribute.MirroredStrategy"])
289class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring
290
291  __doc__ = MirroredStrategy.__doc__
292
293  # Only set this in tests.
294  _collective_key_base = 0
295
296  def __init__(self, devices=None, cross_device_ops=None):
297    extended = MirroredExtended(
298        self, devices=devices, cross_device_ops=cross_device_ops)
299    super(MirroredStrategyV1, self).__init__(extended)
300    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
301        "MirroredStrategy")
302
303
304# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
305class MirroredExtended(distribute_lib.StrategyExtendedV1):
306  """Implementation of MirroredStrategy."""
307
  # If this is set to True, use NCCL collective ops instead of NCCL cross
  # device ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If you are executing in eager mode, only the single machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._initialize_strategy(devices)
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      self._collective_ops = self._make_collective_ops(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._collective_ops_in_use = True
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _make_collective_ops(self, devices):
    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)  # pylint: disable=protected-access
    return cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        collective_keys=self._collective_keys)

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures that when we enter strategy.scope() we use the correct default
    # device.
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
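    # A usage sketch (illustrative): callers pass a `tf.distribute.InputOptions`
    # instance, e.g.
    #   opts = tf.distribute.InputOptions(
    #       experimental_replication_mode=
    #           tf.distribute.InputReplicationMode.PER_REPLICA)
    # and the workers/devices used for input are derived from those options.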
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_fetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           input_contexts,
                                           self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
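    # A minimal usage sketch (illustrative): with two replicas,
    #   strategy.experimental_distribute_values_from_function(
    #       lambda ctx: tf.constant(ctx.replica_id_in_sync_group))
    # returns a distributed value holding 0 on replica 0 and 1 on replica 1.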
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, e.g.
    # to create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if not self._use_merge_call():
      return self._collective_ops

    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()
      else:
        return self._collective_ops

    return self._cross_device_ops or self._inferred_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    def get_values(value):
      if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on all
        # replicas, in which case `value` would be a single value, or `value`
        # could be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)
      if self._use_merge_call() and self._collective_ops_in_use and ((
          not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
          any("cpu" in d.lower()
              for d in cross_device_ops_lib.get_devices_from(destinations)))):
        return cross_device_ops_lib.ReductionToOneDevice().reduce(
            reduce_op, value, destinations)
      return self._get_cross_device_ops(value).reduce(
          reduce_op,
          value,
          destinations=destinations,
          options=self._communication_options.merge(options))

    return nest.map_structure(get_values, value)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly() or (
        not tf2.enabled()) or self._use_merge_call():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and
      # due to the blocking nature of collective ops, execution would hang if
      # collective ops were launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._get_cross_device_ops(value)._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _replica_ctx_update(self, var, fn, args, kwargs, group):
    if self._use_merge_call():
      return super()._replica_ctx_update(var, fn, args, kwargs, group)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
    return result

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
902