# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or remote devices.

  Returns:
    a boolean indicating whether these devices are local (single-worker) or
    remote (multi-worker).

  Raises:
    ValueError: if device strings are not consistent.
  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local device string cannot have job specified other "
                     "than 'localhost'")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have task specified.")

  return num_workers == 1


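# Editorial sketch (not part of the original module) illustrating how
# `_is_device_list_single_worker` classifies device lists; the device strings
# below are hypothetical examples:
#
#   _is_device_list_single_worker(["/gpu:0", "/gpu:1"])
#   # -> True: no job is specified, so the devices are treated as local.
#
#   _is_device_list_single_worker(
#       ["/job:worker/task:0/device:GPU:0", "/job:worker/task:1/device:GPU:0"])
#   # -> False: two distinct (job, task, replica) tuples mean multi-worker.

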
def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices


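# Editorial sketch (not part of the original module) of the device list
# produced for a hypothetical two-worker cluster spec with one GPU per worker:
#
#   _cluster_spec_to_device_list(
#       {"chief": ["host0:2222"], "worker": ["host1:2222"]},
#       num_gpus_per_worker=1)
#   # -> ["/job:chief/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:0"]

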
def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict mapping each task_type to a list of device-string lists, one per
    task_id, in ascending order of task_id.
  """
  assert not _is_device_list_single_worker(devices)
  device_dict = {}

  for d in devices:
    d_spec = tf_device.DeviceSpec.from_string(d)

    # Create an entry for the task_type.
    if d_spec.job not in device_dict:
      device_dict[d_spec.job] = []

    # Fill the device list for task_type until it covers the task_id.
    while len(device_dict[d_spec.job]) <= d_spec.task:
      device_dict[d_spec.job].append([])

    device_dict[d_spec.job][d_spec.task].append(d)

  return device_dict


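# Editorial sketch (hypothetical devices, not part of the original module) of
# the grouping performed by `_group_device_list`:
#
#   _group_device_list(["/job:worker/task:1/device:GPU:0",
#                       "/job:worker/task:0/device:GPU:0"])
#   # -> {"worker": [["/job:worker/task:0/device:GPU:0"],
#   #                ["/job:worker/task:1/device:GPU:0"]]}

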
def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def _infer_num_gpus_per_worker(devices):
  """Infers the number of GPUs on each worker.

  Currently, to make multi-worker cross device ops work, we need all workers to
  have the same number of GPUs.

  Args:
    devices: a list of device strings, can be either local devices or remote
      devices.

  Returns:
    number of GPUs per worker.

  Raises:
    ValueError: if workers have different numbers of GPUs, or if GPU indices
      are not consecutive starting from 0.
  """
  if _is_device_list_single_worker(devices):
    return sum(1 for d in devices if _is_gpu_device(d))
  else:
    device_dict = _group_device_list(devices)
    num_gpus = None
    for _, devices_in_task in device_dict.items():
      for device_in_task in devices_in_task:
        if num_gpus is None:
          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))

        # Verify other workers have the same number of GPUs.
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus


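# Editorial sketch (hypothetical device strings, not part of the original
# module):
#
#   _infer_num_gpus_per_worker(["/gpu:0", "/gpu:1", "/cpu:0"])
#   # -> 2 for a local (single-worker) device list.
#
#   _infer_num_gpus_per_worker(["/job:worker/task:0/device:GPU:0",
#                               "/job:worker/task:1/device:GPU:0"])
#   # -> 1, since each worker contributes exactly one GPU.

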
def all_local_devices(num_gpus=None):
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one
  machine with multiple GPUs. For TPUs, use
  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument of
  the strategy then it will use all the available GPUs. If no GPUs are found, it
  will use the available CPUs. Note that TensorFlow treats all CPUs on a
  machine as a single device, and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  When using distribution strategies, all variable creation should be done
  within the strategy's scope. This replicates the variables across all the
  replicas and keeps them in sync using an all-reduce algorithm.

  Variables created inside a function wrapped with `tf.function`, when the
  function is called within the strategy's scope, are still `MirroredVariables`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset across
  the replicas when writing your own training loop. If you are using the `.fit`
  and `.compile` methods available in `tf.keras`, then `tf.keras` will handle
  the distribution for you.

  For example:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`.  If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is not
      set, `NcclAllReduce()` will be used by default.  One would customize this
      if NCCL isn't available or if a special implementation that exploits
      the particular hardware is available.
  """

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")


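# Editorial usage sketch (added for illustration, not part of the original
# module): the `cross_device_ops` argument documented above can be overridden
# when NCCL is unavailable, e.g. with the public
# `tf.distribute.HierarchicalCopyAllReduce` or
# `tf.distribute.ReductionToOneDevice` implementations:
#
#   strategy = tf.distribute.MirroredStrategy(
#       devices=["GPU:0", "GPU:1"],
#       cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
#   with strategy.scope():
#     v = tf.Variable(0.)  # mirrored onto both GPUs

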
@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring

  __doc__ = MirroredStrategy.__doc__

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MirroredStrategy")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  # If this is set to True, use NCCL collective ops instead of NCCL cross device
  # ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If executing in eager mode, only the single-machine code path is
          # supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    if self._prefer_collective_ops:
      self._communication_options = collective_util.Options(
          implementation=collective_util.CommunicationImplementation.NCCL)
    else:
      self._communication_options = collective_util.Options()
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._use_collective_ops(devices)
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _use_collective_ops(self, devices):
    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup."
                        " Some performance features may not be enabled.")

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)  # pylint: disable=protected-access
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        collective_keys=self._collective_keys)
    self._collective_ops_in_use = True

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_prefetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           input_contexts,
                                           self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

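  # Editorial usage sketch (added for illustration; not part of the original
  # implementation) of the public API this method backs,
  # `tf.distribute.Strategy.experimental_distribute_values_from_function`:
  #
  #   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  #   def value_fn(ctx):
  #     # One tensor per replica, tagged with the replica id.
  #     return tf.constant(ctx.replica_id_in_sync_group)
  #   distributed_values = (
  #       strategy.experimental_distribute_values_from_function(value_fn))
  #   # `distributed_values` can then be passed to `strategy.run(...)`.
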
  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()

    return self._cross_device_ops or self._inferred_cross_device_ops

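  # Editorial sketch (not part of the original implementation) of the dispatch
  # implemented above, with a hypothetical value:
  #
  #   ops_used = self._get_cross_device_ops(tf.constant(1, dtype=tf.int32))
  #   # -> ReductionToOneDevice() when collective ops are in use (int32 values
  #   #    take the fallback path), otherwise the explicitly configured or
  #   #    inferred cross-device ops.
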
  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    if self._collective_ops_in_use and (
        (not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
         any("cpu" in d.lower()
             for d in cross_device_ops_lib.get_devices_from(destinations)))):
      return cross_device_ops_lib.ReductionToOneDevice().reduce(
          reduce_op, value, destinations)
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("inputs to batch_reduce_to must be either all on the "
                         "host or all on the compute devices")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
        var.aggregation != variables_lib.VariableAggregation.NONE):
      distribute_utils.assert_mirrored(args)
      distribute_utils.assert_mirrored(kwargs)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val._values  # pylint: disable=protected-access
    return (val,)

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id