# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
16"""Library for running a computation across multiple devices.
17
18The intent of this library is that you can write an algorithm in a stylized way
19and it will be usable with a variety of different `tf.distribute.Strategy`
20implementations. Each descendant will implement a different strategy for
21distributing the algorithm across multiple devices/machines.  Furthermore, these
22changes can be hidden inside the specific layers and other library classes that
23need special treatment to run in a distributed setting, so that most users'
24model definition code can run unchanged. The `tf.distribute.Strategy` API works
25the same way with eager and graph execution.
26
27*Guides*
28
29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
31
32*Tutorials*
33
34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
35
  The tutorials cover how to use `tf.distribute.Strategy` to do distributed
  training with native Keras APIs, custom training loops,
  and Estimator APIs. They also cover how to save and load models when using
  `tf.distribute.Strategy`.

*Glossary*

* _Data parallelism_ is where we run multiple copies of the model
  on different slices of the input data. This is in contrast to
  _model parallelism_ where we divide up a single copy of a model
  across multiple devices.
  Note: we only support data parallelism for now, but
  hope to add support for model parallelism in the future.
* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
  devices on a single machine, or be connected to devices on multiple
  machines. Devices used to run computations are called _worker devices_.
  Devices used to store variables are _parameter devices_. For some strategies,
  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
  will be the same (see mirrored variables below). For others they will be
  different. For example, `tf.distribute.experimental.CentralStorageStrategy`
  puts the variables on a single device (which may be a worker device or may be
  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
  variables on separate machines called _parameter servers_ (see below).
* A _replica_ is one copy of the model, running on one slice of the
  input data. Right now each replica is executed on its own
  worker device, but once we add support for model parallelism
  a replica may span multiple worker devices.
* A _host_ is the CPU device on a machine with worker devices, typically
  used for running input pipelines.
* A _worker_ is defined to be the physical machine(s) containing the physical
  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
  worker may contain one or more replicas, but contains at least one
  replica. Typically one worker will correspond to one machine, but in the case
  of very large models with model parallelism, one worker may span multiple
  machines. We typically run one input pipeline per worker, feeding all the
  replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
  each replica are aggregated together before updating the model variables. This
  is in contrast to _asynchronous_, or _async_ training, where each replica
  updates the model variables independently. You may also have replicas
  partitioned into groups which are in sync within each group but async between
  groups.
* _Parameter servers_: These are machines that hold a single copy of
  parameters/variables, used by some strategies (right now just
  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
  to operate on a variable retrieve it at the beginning of a step and send an
  update to be applied at the end of the step. These can in principle support
  either sync or async training, but right now we only have support for async
  training with parameter servers. Compare to
  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
  on a single device on the same machine (and does sync training), and
  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
  (see below).

* _Replica context_ vs. _Cross-replica context_ vs. _Update context_

  A _replica context_ applies
  when you execute the computation function that was called with `strategy.run`.
  Conceptually, you're in replica context when executing the computation
  function that is being replicated.

  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
  call.

  A _cross-replica context_ is entered when you enter a `strategy.scope`. This
  is useful for calling `tf.distribute.Strategy` methods which operate across
  the replicas (like `reduce_to()`). By default you start in a _replica context_
  (the "default single _replica context_") and then some methods can switch you
  back and forth.

* _Distributed value_: A distributed value is represented by the base class
  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
  to represent values on multiple devices, and it contains a map from replica id
  to values. Two representative kinds of `tf.distribute.DistributedValues` are
  "PerReplica" and "Mirrored" values.

  "PerReplica" values exist on the worker
  devices, with a different value for each replica. They are produced by
  iterating through a distributed dataset returned by
  `tf.distribute.Strategy.experimental_distribute_dataset` and
  `tf.distribute.Strategy.distribute_datasets_from_function`. They
  are also the typical result returned by
  `tf.distribute.Strategy.run`.

  "Mirrored" values are like "PerReplica" values, except we know that the
  values on all replicas are the same. We can safely read a "Mirrored" value in
  a cross-replica context by using the value on any replica (see the sketch at
  the end of this glossary).

* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
  replicas, like `strategy.run(fn, args=[w])` with an
  argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
  have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
  `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
  device `d0`, `fn(w1)` on device `d1`, etc.  It then merges the return
  values from `fn()`, which leads to one common object if the returned values
  are the same object from every replica, or a `DistributedValues` object
  otherwise.

* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
  multiple values into one value, like "sum" or "mean". If a strategy is doing
  sync training, we will perform a reduction on the gradients to a parameter
  from all replicas before applying the update. _All-reduce_ is an algorithm for
  performing a reduction on values from multiple devices and making the result
  available on all of those devices.

* _Mirrored variables_: These are variables that are created on multiple
  devices, where we keep the variables in sync by applying the same
  updates to every copy. Mirrored variables are created with
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
  Normally they are only used in synchronous training.

* _SyncOnRead variables_

  _SyncOnRead variables_ are created by
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
  they are created on multiple devices. In replica context, each
  component variable on the local replica can perform reads and writes without
  synchronization with each other. When the
  _SyncOnRead variable_ is read in cross-replica context, the values from
  component variables are aggregated and returned.

  _SyncOnRead variables_ add a lot of configuration complexity to the
  underlying logic, so we do not encourage users to instantiate and use
  _SyncOnRead variables_ on their own. We have mainly used _SyncOnRead
  variables_ for use cases such as batch norm and metrics. For performance
  reasons, we often don't need to keep these statistics in sync every step and
  they can be accumulated on each replica independently. The only time we want
  to sync them is when reporting or checkpointing, which typically happens in
  cross-replica context. _SyncOnRead variables_ are also often used by advanced
  users who want to control when variable values are aggregated. For example,
  users sometimes want to maintain gradients independently on each replica for a
  couple of steps without aggregation.

* _Distribute-aware layers_

  Layers are generally called in a replica context, except when defining a
  Keras functional model. `tf.distribute.in_cross_replica_context` will let you
  determine which case you are in. The `tf.distribute.get_replica_context`
  function returns the default replica context when called outside a strategy
  scope, `None` when called inside a strategy scope but outside
  `tf.distribute.Strategy.run` (i.e. in a cross-replica context), and a
  `tf.distribute.ReplicaContext` object when called within a
  `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
  `all_reduce` method for aggregating across all replicas.


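For illustration, here is a minimal sketch that ties several of the glossary
terms together (the two GPU device names are only an example; any set of local
devices works):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

with strategy.scope():          # cross-replica context
  v = tf.Variable(1.0)          # a mirrored variable, one copy per device

@tf.function
def step_fn(x):                 # runs in replica context, once per replica
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x * v)

per_replica = strategy.run(step_fn, args=(tf.constant(2.0),))
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
```
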
Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope and that provides the same API with
reasonable default behavior.
"""
# pylint: enable=line-too-long

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import threading
import weakref

import six

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


# ------------------------------------------------------------------------------
# Context tracking whether in a strategy.update() or .update_non_slot() call.


_update_replica_id = threading.local()


def get_update_replica_id():
  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
  try:
    return _update_replica_id.current
  except AttributeError:
    return None


class UpdateContext(object):
  """Context manager when you are in `update()` or `update_non_slot()`."""

  __slots__ = ["_replica_id", "_old_replica_id"]

  def __init__(self, replica_id):
    self._replica_id = replica_id
    self._old_replica_id = None

  def __enter__(self):
    self._old_replica_id = get_update_replica_id()
    _update_replica_id.current = self._replica_id

  def __exit__(self, exception_type, exception_value, traceback):
    del exception_type, exception_value, traceback
    _update_replica_id.current = self._old_replica_id


# ------------------------------------------------------------------------------
# Public utility functions.


@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.

  This is used to decide whether loss should be scaled in optimizer (used only
  for estimator + v1 optimizer use case).

  Returns:
    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
  """
  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
    # If we are not in Estimator context then return 'SUM'. We do not need to
    # scale loss in the optimizer.
    return reduce_util.ReduceOp.SUM
  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
  if (last_reduction == losses_impl.Reduction.SUM or
      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
    return reduce_util.ReduceOp.SUM
  return reduce_util.ReduceOp.MEAN


# ------------------------------------------------------------------------------
# Internal API for validating the current thread mode


def _require_cross_replica_or_default_context_extended(extended,
                                                       error_message=None):
  """Verify in cross-replica context."""
  context = _get_per_thread_mode()
  cross_replica = context.cross_replica_context
  if cross_replica is not None and cross_replica.extended is extended:
    return
  if context is _get_default_replica_mode():
    return
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  # We have an error to report, figure out the right message.
  if context.strategy is not strategy:
    _wrong_strategy_scope(strategy, context)
  assert cross_replica is None
  if not error_message:
    error_message = ("Method requires being in cross-replica context, use "
                     "get_replica_context().merge_call()")
  raise RuntimeError(error_message)


def _wrong_strategy_scope(strategy, context):
  # Figure out the right error message.
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError(
        'Need to be inside "with strategy.scope()" for %s' %
        (strategy,))
  else:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (context.strategy, strategy))


def require_replica_context(replica_ctx):
  """Verify in `replica_ctx` replica context."""
  context = _get_per_thread_mode()
  if context.replica_context is replica_ctx: return
  # We have an error to report, figure out the right message.
  if context.replica_context is None:
    raise RuntimeError("Need to be inside `call_for_each_replica()`")
  if context.strategy is replica_ctx.strategy:
    # Two different ReplicaContexts with the same tf.distribute.Strategy.
    raise RuntimeError("Mismatching ReplicaContext.")
  raise RuntimeError(
      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
      (context.strategy, replica_ctx.strategy))


def _require_strategy_scope_strategy(strategy):
  """Verify in a `strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy is strategy: return
  _wrong_strategy_scope(strategy, context)


def _require_strategy_scope_extended(extended):
  """Verify in a `distribution_strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy.extended is extended: return
  # Report error.
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  _wrong_strategy_scope(strategy, context)


# ------------------------------------------------------------------------------
# Internal context managers used to implement the DistributionStrategy
# base class


class _CurrentDistributionContext(object):
  """Context manager setting the current `tf.distribute.Strategy`.

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._resource_creator_scope:
        nest.map_structure(lambda scope: scope.__enter__(),
                           self._resource_creator_scope)
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of `with` scope."),
          e)

    if self._resource_creator_scope:
      try:
        self._resource_creator_scope.__exit__(exception_type, exception_value,
                                              traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of `with` "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)
    _pop_per_thread_mode()


# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
  """Replication mode for the input function.

  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as there are workers.
    Replicas will dequeue from the local Dataset on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  * `PER_REPLICA`: The input function will be called on each replica separately.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
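
  For illustration, a minimal sketch of selecting a replication mode through
  `tf.distribute.InputOptions` (the `strategy` and `dataset_fn` names are
  placeholders for your own objects):

  ```python
  options = tf.distribute.InputOptions(
      experimental_replication_mode=tf.distribute.InputReplicationMode.PER_WORKER)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn, options)
  ```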
464  """
465  PER_WORKER = "PER_WORKER"
466  PER_REPLICA = "PER_REPLICA"
467
468
@tf_export("distribute.InputContext")
class InputContext(object):
  """A class wrapping information needed by an input function.

  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the local batch
  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (e.g. shard the input pipeline, use a different input
  source, etc.).
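
  For illustration, a minimal sketch of the information this context carries
  (the values below are arbitrary examples):

  ```python
  ctx = tf.distribute.InputContext(
      num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=4)
  per_replica_batch = ctx.get_per_replica_batch_size(64)  # 64 // 4 == 16
  ```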
480  """
481
482  __slots__ = [
483      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
484  ]
485
486  def __init__(self,
487               num_input_pipelines=1,
488               input_pipeline_id=0,
489               num_replicas_in_sync=1):
490    """Initializes an InputContext object.
491
492    Args:
493      num_input_pipelines: the number of input pipelines in a cluster.
494      input_pipeline_id: the current input pipeline id, should be an int in
495        [0,`num_input_pipelines`).
496      num_replicas_in_sync: the number of replicas that are in sync.
497    """
498    self._num_input_pipelines = num_input_pipelines
499    self._input_pipeline_id = input_pipeline_id
500    self._num_replicas_in_sync = num_replicas_in_sync
501
502  @property
503  def num_replicas_in_sync(self):
504    """Returns the number of compute replicas in sync."""
505    return self._num_replicas_in_sync
506
507  @property
508  def input_pipeline_id(self):
509    """Returns the input pipeline ID."""
510    return self._input_pipeline_id
511
512  @property
513  def num_input_pipelines(self):
514    """Returns the number of input pipelines."""
515    return self._num_input_pipelines
516
517  def get_per_replica_batch_size(self, global_batch_size):
518    """Returns the per-replica batch size.
519
520    Args:
521      global_batch_size: the global batch size which should be divisible by
522        `num_replicas_in_sync`.
523
524    Returns:
525      the per-replica batch size.
526
527    Raises:
528      ValueError: if `global_batch_size` not divisible by
529        `num_replicas_in_sync`.
530    """
531    if global_batch_size % self._num_replicas_in_sync != 0:
532      raise ValueError("The `global_batch_size` %r is not divisible by "
533                       "`num_replicas_in_sync` %r " %
534                       (global_batch_size, self._num_replicas_in_sync))
535    return global_batch_size // self._num_replicas_in_sync
536
537  def __str__(self):
538    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
539        self.input_pipeline_id, self.num_input_pipelines)
540

@tf_export("distribute.experimental.ValueContext", v1=[])
class ValueContext(object):
  """A class wrapping information needed by a distribute function.

  This is a context class that is passed to the `value_fn` in
  `strategy.experimental_distribute_values_from_function` and contains
  information about the compute replicas. The `num_replicas_in_sync` and
  `replica_id` can be used to customize the value on each replica.

  Example usage:

  1. Directly constructed.

  >>> def value_fn(context):
  ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
  >>> context = tf.distribute.experimental.ValueContext(
  ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
  >>> per_replica_value = value_fn(context)
  >>> per_replica_value
  0.5

  2. Passed in by `experimental_distribute_values_from_function`.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> def value_fn(value_context):
  ...   return value_context.num_replicas_in_sync
  >>> distributed_values = (
  ...      strategy.experimental_distribute_values_from_function(
  ...        value_fn))
  >>> local_result = strategy.experimental_local_results(distributed_values)
  >>> local_result
  (2, 2)

  """

  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]

  def __init__(self,
               replica_id_in_sync_group=0,
               num_replicas_in_sync=1):
582    """Initializes an ValueContext object.
583
584    Args:
585      replica_id_in_sync_group: the current replica_id, should be an int in
586        [0,`num_replicas_in_sync`).
587      num_replicas_in_sync: the number of replicas that are in sync.
588    """
    self._replica_id_in_sync_group = replica_id_in_sync_group
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def replica_id_in_sync_group(self):
    """Returns the replica ID."""
    return self._replica_id_in_sync_group

  def __str__(self):
    return (("tf.distribute.ValueContext(replica id {}, "
             " total replicas in sync: ""{})")
            .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))


@tf_export("distribute.RunOptions")
class RunOptions(
    collections.namedtuple("RunOptions", [
        "experimental_enable_dynamic_batch_size",
        "experimental_bucketizing_dynamic_shape",
        "experimental_xla_options",
    ])):
  """Run options for `strategy.run`.

  This can be used to hold some strategy specific configs.

  Attributes:
    experimental_enable_dynamic_batch_size: Boolean. Only applies to
      TPUStrategy. Defaults to True. If True, TPUStrategy will enable the
      dynamic padder to support dynamic batch sizes for the inputs. Otherwise
      only static shape inputs are allowed.
    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
      TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
      bucketize inputs passed into `run` if the input shape is
      dynamic. This is a performance optimization to reduce XLA recompilation,
      which should not have an impact on correctness.
    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
      TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
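
  For illustration, a minimal sketch of passing run options to `strategy.run`
  (the `strategy`, `replica_fn` and `inputs` names are placeholders):

  ```python
  options = tf.distribute.RunOptions(
      experimental_enable_dynamic_batch_size=False)
  strategy.run(replica_fn, args=(inputs,), options=options)
  ```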
631  """
632
633  def __new__(cls,
634              experimental_enable_dynamic_batch_size=True,
635              experimental_bucketizing_dynamic_shape=False,
636              experimental_xla_options=None):
637    return super(RunOptions,
638                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
639                              experimental_bucketizing_dynamic_shape,
640                              experimental_xla_options)
641
642
@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_fetch_to_device",
        "experimental_replication_mode",
        "experimental_place_dataset_on_device",
        "experimental_per_replica_buffer_size",
    ])):
  """Run options for `experimental_distribute_dataset(s_from_function)`.

  This can be used to hold some strategy specific configs.

  ```python
  # Setup TPUStrategy
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  dataset = tf.data.Dataset.range(16)
  distributed_dataset_on_host = (
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1)))
  ```

  Attributes:
    experimental_fetch_to_device: Boolean. If True, dataset
      elements will be prefetched to accelerator device memory. When False,
      dataset elements are prefetched to host device memory. Must be False when
      using the TPUEmbedding API. experimental_fetch_to_device can only be used
      with experimental_replication_mode=PER_WORKER. Default behavior is the
      same as setting it to True.
    experimental_replication_mode: Replication mode for the input function.
      Currently, InputReplicationMode.PER_REPLICA is only supported with
      tf.distribute.MirroredStrategy.experimental_distribute_datasets_from_function.
      The default value is InputReplicationMode.PER_WORKER.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
      the dataset will be placed on the device, otherwise it will remain on the
      host. experimental_place_dataset_on_device=True can only be used with
      experimental_replication_mode=PER_REPLICA.
    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
      prefetch buffer size in the replica device memory. Users can set it
      to 0 to completely disable prefetching behavior, or a number greater than
      1 to enable larger buffer sizes. Note that this option is still
      valid with `experimental_fetch_to_device=False`.
  """

  def __new__(cls,
              experimental_fetch_to_device=None,
              experimental_replication_mode=InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1):
    if experimental_fetch_to_device is None:
      experimental_fetch_to_device = True

    return super(InputOptions,
                 cls).__new__(cls, experimental_fetch_to_device,
                              experimental_replication_mode,
                              experimental_place_dataset_on_device,
                              experimental_per_replica_buffer_size)

# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.


# Base class for v1 Strategy and v2 Strategy classes. For APIs specific to
# v1/v2 Strategy, add to implementing classes of StrategyBase.
# pylint: disable=line-too-long
class StrategyBase(object):
  """A state & compute distribution policy on a list of devices.

  See [the guide](https://www.tensorflow.org/guide/distributed_training)
  for an overview and examples. See `tf.distribute.StrategyExtended` and
  [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
  for a glossary of concepts mentioned on this page such as "per-replica",
  _replica_, and _reduce_.

  In short:

  * To use it with Keras `compile`/`fit`,
    [please
    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
  * You may pass a descendant of `tf.distribute.Strategy` to
    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
    should distribute its computation. See
    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
    strategy should be used when building and executing your model.
    (This puts you in the "cross-replica context" for this strategy, which
    means the strategy is put in control of things like variable placement.)
  * If you are writing a custom training loop, you will need to call a few more
    methods,
    [see the
    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):

      * Start by creating a `tf.data.Dataset` normally.
      * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
        a `tf.data.Dataset` to something that produces "per-replica" values.
        If you want to manually specify how the dataset should be partitioned
        across replicas, use
        `tf.distribute.Strategy.distribute_datasets_from_function`
        instead.
      * Use `tf.distribute.Strategy.run` to run a function
        once per replica, taking values that may be "per-replica" (e.g.
        from a `tf.distribute.DistributedDataset` object) and returning
        "per-replica" values.
        This function is executed in "replica context", which means each
        operation is performed separately on each replica.
      * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
        convert the resulting "per-replica" values into ordinary `Tensor`s.

  A custom training loop can be as simple as:

  ```
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  This takes an ordinary `dataset` and `replica_fn` and runs it
  distributed using a particular `tf.distribute.Strategy` named
  `my_strategy` above. Any variables created in `replica_fn` are created
  using `my_strategy`'s policy, and library functions called by
  `replica_fn` can use the `get_replica_context()` API to implement
  distributed-specific behavior.

  You can use the `reduce` API to aggregate results across replicas and use
  this as a return value from one iteration over a
  `tf.distribute.DistributedDataset`. Or
  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
  accumulate metrics across steps in a given epoch.

  See the
  [custom training loop
  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
  for a more detailed example.

  Note: `tf.distribute.Strategy` currently does not support TensorFlow's
  partitioned variables (where a single variable is split across multiple
  devices).
  """
  # pylint: enable=line-too-long

  # TODO(josh11b): Partitioned computations, state; sharding
  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling

  def __init__(self, extended):
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

    if not hasattr(extended, "_retrace_functions_for_each_device"):
      # pylint: disable=protected-access
      # `extended._retrace_functions_for_each_device` dictates
      # whether the same function will be retraced when it is called on
      # different devices.
      try:
        extended._retrace_functions_for_each_device = (
            len(extended.worker_devices) > 1)
        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
            self.num_replicas_in_sync)
      except:  # pylint: disable=bare-except
        # Default for the case where extended.worker_devices can't return
        # a sensible value.
        extended._retrace_functions_for_each_device = True

    # Below are the dicts of axis(int) -> `tf.function`.
    self._mean_reduce_helper_fns = {}
    self._reduce_sum_fns = {}

    # Whether this strategy is designed to work with `ClusterCoordinator`.
    self._should_use_with_coordinator = False

  @property
  def extended(self):
    """`tf.distribute.StrategyExtended` with additional methods."""
    return self._extended

  @tf_contextlib.contextmanager
  def _scale_loss_for_estimator_enabled(self):
    """Scope which sets a flag used for scaling losses in optimizer.

    Yields:
      `_scale_loss_for_estimator_enabled` is a context manager with a
      side effect, but doesn't return a value.
    """
    self._scale_loss_for_estimator = True
    try:
      yield
    finally:
      self._scale_loss_for_estimator = False

  # pylint: disable=line-too-long
  def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

    _What happens when Strategy.scope is entered?_

    * `strategy` is installed in the global context as the "current" strategy.
      Inside this scope, `tf.distribute.get_strategy()` will now return this
      strategy. Outside this scope, it returns the default no-op strategy.
    * Entering the scope also enters the "cross-replica context". See
      `tf.distribute.StrategyExtended` for an explanation on cross-replica and
      replica contexts.
    * Variable creation inside `scope` is intercepted by the strategy. Each
      strategy defines how it wants to affect the variable creation. Sync
      strategies like `MirroredStrategy`, `TPUStrategy` and
      `MultiWorkerMirroredStrategy` create variables replicated on each replica,
      whereas `ParameterServerStrategy` creates variables on the parameter
      servers. This is done using a custom `tf.variable_creator_scope`.
    * In some strategies, a default device scope may also be entered: in
      `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
      entered on each worker.

    Note: Entering a scope does not automatically distribute a computation, except
      in the case of high level training frameworks like Keras `model.fit`. If
      you're not using `model.fit`, you
      need to use the `strategy.run` API to explicitly distribute that computation.
      See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).


    _What should be in scope and what should be outside?_

    There are a number of requirements on what needs to happen inside the scope.
    However, in places where we have information about which strategy is in use,
    we often enter the scope for the user, so they don't have to do it
    explicitly (i.e. calling those either inside or outside the scope is OK).

    * Anything that creates variables that should be distributed variables
      must be called in a `strategy.scope`. This can be accomplished either by
      directly calling the variable creating function within the scope context,
      or by relying on another API like `strategy.run` or `keras.Model.fit` to
      automatically enter it for you. Any variable that is created outside scope
      will not be distributed and may have performance implications. Some common
      objects that create variables in TF are Models, Optimizers, Metrics. Such
      objects should always be initialized in the scope, and any functions
      that may lazily create variables (e.g., `Model.__call__()`, tracing a
      `tf.function`, etc.) should similarly be called within scope. Another
      source of variable creation can be a checkpoint restore - when variables
      are created lazily. Note that any variable created inside a strategy
      captures the strategy information. So reading and writing to these
      variables outside the `strategy.scope` can also work seamlessly, without
      the user having to enter the scope.
    * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) that
      require being in a strategy's scope enter the scope automatically, which
      means when using those APIs you don't need to explicitly enter the scope
      yourself.
    * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
      object captures the scope information. When high level training framework
      methods such as `model.compile`, `model.fit`, etc. are then called, the
      captured scope will be automatically entered, and the associated strategy
      will be used to distribute the training etc. See a detailed example in
      [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
      WARNING: Simply calling `model(..)` does not automatically enter the
      captured scope -- only high level training framework APIs support this
      behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
      and `model.save` can all be called inside or outside the scope.
    * The following can be either inside or outside the scope:
        * Creating the input datasets
        * Defining `tf.function`s that represent your training step
        * Saving APIs such as `tf.saved_model.save`. Loading creates variables,
          so that should go inside the scope if you want to train the model in a
          distributed way.
        * Checkpoint saving. As mentioned above - `checkpoint.restore` may
          sometimes need to be inside scope if it creates variables.
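
    For illustration, a minimal sketch of the common pattern (the model and
    optimizer shown are placeholders for your own):

    ```python
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      optimizer = tf.keras.optimizers.SGD()
    # Dataset creation and `tf.function` definitions can live outside the scope.
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(8)
    ```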

    Returns:
      A context manager.
    """
    return self._extended._scope(self)  # pylint: disable=protected-access
  # pylint: enable=line-too-long

  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
  @deprecated(None, "use extended.colocate_vars_with() instead.")
  def colocate_vars_with(self, colocate_with_variable):
    """DEPRECATED: use extended.colocate_vars_with() instead."""
    return self._extended.colocate_vars_with(colocate_with_variable)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_dataset_iterator(self, dataset):
    """DEPRECATED TF 1.x ONLY."""
    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_input_fn_iterator(self,
                             input_fn,
                             replication_mode=InputReplicationMode.PER_WORKER):
    """DEPRECATED TF 1.x ONLY."""
    if replication_mode != InputReplicationMode.PER_WORKER:
      raise ValueError(
          "Input replication mode not supported: %r" % replication_mode)
    with self.scope():
      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
          input_fn, replication_mode=replication_mode)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  @deprecated(None, "use run() instead")
  def experimental_run(self, fn, input_iterator=None):
    """DEPRECATED TF 1.x ONLY."""
    with self.scope():
      args = (input_iterator.get_next(),) if input_iterator is not None else ()
    return self.run(fn, args=args)

  def experimental_distribute_dataset(self, dataset, options=None):
    # pylint: disable=line-too-long
    """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.

    The returned `tf.distribute.DistributedDataset` can be iterated over
    similar to regular datasets.
    NOTE: The user cannot add any more transformations to a
    `tf.distribute.DistributedDataset`. You can only create an iterator or
    examine the `tf.TypeSpec` of the data generated by it. See API docs of
    `tf.distribute.DistributedDataset` to learn more.

    The following is an example:

    >>> global_batch_size = 2
    >>> # Passing the devices is optional.
    ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
    >>> # Create a dataset
    ... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
    >>> # Distribute that dataset
    ... dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> @tf.function
    ... def replica_fn(input):
    ...   return input*2
    >>> result = []
    >>> # Iterate over the `tf.distribute.DistributedDataset`
    ... for x in dist_dataset:
    ...   # process dataset elements
    ...   result.append(strategy.run(replica_fn, args=(x,)))
    >>> print(result)
    [PerReplica:{
      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
    }, PerReplica:{
      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
    }]


    Three key actions happening under the hood of this method are batching,
    sharding, and prefetching.

    In the code snippet above, `dataset` is batched by `global_batch_size`, and
    calling `experimental_distribute_dataset` on it rebatches `dataset` to a
    new batch size that is equal to the global batch size divided by the number
    of replicas in sync. We iterate through it using a Pythonic for loop.
    `x` is a `tf.distribute.DistributedValues` containing data for all replicas,
    and each replica gets data of the new batch size.
    `tf.distribute.Strategy.run` will take care of feeding the right per-replica
    data in `x` to the right `replica_fn` executed on each replica.

    Sharding covers autosharding across multiple workers and within every
    worker. First, in multi-worker distributed training (i.e. when you use
    `tf.distribute.experimental.MultiWorkerMirroredStrategy`
    or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
    workers means that each worker is assigned a subset of the entire dataset
    (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
    ensure that at each step, a global batch size of non-overlapping dataset
    elements will be processed by each worker. Autosharding has a couple of
    different options that can be specified using
    `tf.data.experimental.DistributeOptions`. Then, sharding within each worker
    means the method will split the data among all the worker devices (if more
    than one is present). This will happen regardless of multi-worker
    autosharding.

    Note: for autosharding across multiple workers, the default mode is
    `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
    will attempt to shard the input dataset by files if the dataset is
    being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
    `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
    where each of the workers will read the entire dataset and only process the
    shard assigned to it. However, if you have less than one input file per
    worker, we suggest that you disable dataset autosharding across workers by
    setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
    `tf.data.experimental.AutoShardPolicy.OFF`.
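
    For illustration, a minimal sketch of turning autosharding off before
    distributing a dataset (the `dataset` and `strategy` names are
    placeholders):

    ```python
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dist_dataset = strategy.experimental_distribute_dataset(
        dataset.with_options(options))
    ```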

    By default, this method adds a prefetch transformation at the end of the
    user provided `tf.data.Dataset` instance. The `buffer_size` argument of the
    prefetch transformation is equal to the number of replicas in
    sync.

    If the above batch splitting and dataset sharding logic is undesirable,
    please use
    `tf.distribute.Strategy.distribute_datasets_from_function`
    instead, which does not do any automatic batching or sharding for you.

    Note: If you are using TPUStrategy, the order in which the data is processed
    by the workers when using
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function` is
    not guaranteed. This is typically required if you are using
    `tf.distribute` to scale prediction. You can however insert an index for
    each element in the batch and order outputs accordingly. Refer to [this
    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
    for an example of how to order outputs.

    Note: Stateful dataset transformations are currently not supported with
    `tf.distribute.experimental_distribute_dataset` or
    `tf.distribute.distribute_datasets_from_function`. Any stateful
    ops that the dataset may have are currently ignored. For example, if your
    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
    then you have a dataset graph that depends on state (i.e. the random seed) on
    the local machine where the python process is being executed.

    For a tutorial on more usage and properties of this method, refer to the
    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).

    Args:
      dataset: `tf.data.Dataset` that will be sharded across all replicas using
        the rules stated above.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A `tf.distribute.DistributedDataset`.
    """
    distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "distribute_dataset").increase_by(1)
    # pylint: enable=line-too-long
    return self._extended._experimental_distribute_dataset(dataset, options)  # pylint: disable=protected-access

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    # pylint: disable=line-too-long
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    The argument `dataset_fn` that users pass in is an input function that has a
    `tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
    instance. It is expected that the returned dataset from `dataset_fn` is
    already batched by per-replica batch size (i.e. global batch size divided by
    the number of replicas in sync) and sharded.
    `tf.distribute.Strategy.distribute_datasets_from_function` does
    not batch or shard the `tf.data.Dataset` instance
    returned from the input function. `dataset_fn` will be called on the CPU
    device of each of the workers and each generates a dataset where every
    replica on that worker will dequeue one batch of inputs (i.e. if a worker
    has two replicas, two batches will be dequeued from the `Dataset` every
    step).

    This method can be used for several purposes. First, it allows you to
    specify your own batching and sharding logic. (In contrast,
    `tf.distribute.experimental_distribute_dataset` does batching and sharding
    for you.) For example, where
    `experimental_distribute_dataset` is unable to shard the input files, this
    method might be used to manually shard the dataset (avoiding the slow
    fallback behavior in `experimental_distribute_dataset`). In cases where the
    dataset is infinite, this sharding can be done by creating dataset replicas
    that differ only in their random seed.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed.

    You can use the `element_spec` property of the
    `tf.distribute.DistributedDataset` returned by this API to query the
    `tf.TypeSpec` of the elements returned by the iterator. This can be used to
    set the `input_signature` property of a `tf.function`. Follow
    `tf.distribute.DistributedDataset.element_spec` to see an example.

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.
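
    For illustration, a minimal sketch of such an input function (the
    `strategy` object, the global batch size of 64, and the dataset contents
    are placeholders):

    ```python
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(64)
      d = tf.data.Dataset.from_tensor_slices(tf.range(1024.))
      d = d.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
      return d.batch(batch_size)

    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
    ```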

    Note: If you are using TPUStrategy, the order in which the data is processed
    by the workers when using
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function` is
    not guaranteed. This is typically required if you are using
    `tf.distribute` to scale prediction. You can however insert an index for
    each element in the batch and order outputs accordingly. Refer to [this
    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
    for an example of how to order outputs.

    Note: Stateful dataset transformations are currently not supported with
    `tf.distribute.experimental_distribute_dataset` or
    `tf.distribute.distribute_datasets_from_function`. Any stateful
    ops that the dataset may have are currently ignored. For example, if your
    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
    then you have a dataset graph that depends on state (i.e. the random seed) on
    the local machine where the python process is being executed.
1164
1165    For a tutorial on more usage and properties of this method, refer to the
1166    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
1167    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1168
1169    Args:
1170      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
1171        returning a `tf.data.Dataset`.
1172      options: `tf.distribute.InputOptions` used to control options on how this
1173        dataset is distributed.
1174
1175    Returns:
1176      A `tf.distribute.DistributedDataset`.
1177    """
1178    distribution_strategy_input_api_counter.get_cell(
1179        self.__class__.__name__,
1180        "distribute_datasets_from_function").increase_by(1)
1181    # pylint: enable=line-too-long
1182    return self._extended._distribute_datasets_from_function(  # pylint: disable=protected-access
1183        dataset_fn, options)
1184
1185  # TODO(b/162776748): Remove deprecated symbol.
1186  @doc_controls.do_not_doc_inheritable
1187  @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
1188  def experimental_distribute_datasets_from_function(self,
1189                                                     dataset_fn,
1190                                                     options=None):
1191    return self.distribute_datasets_from_function(dataset_fn, options)
1192
1193  def run(self, fn, args=(), kwargs=None, options=None):
1194    """Invokes `fn` on each replica, with the given arguments.
1195
1196    This method is the primary way to distribute your computation with a
1197    `tf.distribute.Strategy` object. It invokes `fn` on each replica. If `args` or `kwargs`
1198    have `tf.distribute.DistributedValues`, such as those produced by a
1199    `tf.distribute.DistributedDataset` from
1200    `tf.distribute.Strategy.experimental_distribute_dataset` or
1201    `tf.distribute.Strategy.distribute_datasets_from_function`,
1202    when `fn` is executed on a particular replica, it will be executed with the
1203    component of `tf.distribute.DistributedValues` that corresponds to that
1204    replica.
1205
1206    `fn` is invoked under a replica context. `fn` may call
1207    `tf.distribute.get_replica_context()` to access members such as
1208    `all_reduce`. Please see the module-level docstring of tf.distribute for the
1209    concept of replica context.
1210
1211    All arguments in `args` or `kwargs` can be a nested structure of tensors,
1212    e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
1213    the `fn` invoked on each replica. Or `args` or `kwargs` can be
1214    `tf.distribute.DistributedValues` containing tensors or composite tensors,
1215    e.g. `tf.sparse.SparseTensor` or `tf.RaggedTensor`, in which case each `fn` call
1216    will get the component of a `tf.distribute.DistributedValues` corresponding
1217    to its replica. Note that arbitrary Python values that are not of the types
1218    above are not supported.
1219
1220    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
1221    whether eager execution is enabled, `fn` may be called one or more times. If
1222    `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
1223    called inside a `tf.function` (eager execution is disabled inside a
1224    `tf.function` by default), `fn` is called once per replica to generate a
1225    TensorFlow graph, which will then be reused for execution with new inputs.
1226    Otherwise, if eager execution is enabled, `fn` will be called once per
1227    replica every step just like regular Python code.
1228
1229    Example usage:
1230
1231    1. Constant tensor input.
1232
1233    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1234    >>> tensor_input = tf.constant(3.0)
1235    >>> @tf.function
1236    ... def replica_fn(input):
1237    ...   return input*2.0
1238    >>> result = strategy.run(replica_fn, args=(tensor_input,))
1239    >>> result
1240    PerReplica:{
1241      0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1242      1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
1243    }
1244
1245    2. DistributedValues input.
1246
1247    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1248    >>> @tf.function
1249    ... def run():
1250    ...   def value_fn(value_context):
1251    ...     return value_context.num_replicas_in_sync
1252    ...   distributed_values = (
1253    ...     strategy.experimental_distribute_values_from_function(
1254    ...       value_fn))
1255    ...   def replica_fn2(input):
1256    ...     return input*2
1257    ...   return strategy.run(replica_fn2, args=(distributed_values,))
1258    >>> result = run()
1259    >>> result
1260    <tf.Tensor: shape=(), dtype=int32, numpy=4>
1261
1262    3. Use `tf.distribute.ReplicaContext` to allreduce values.
1263
1264    >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
1265    >>> @tf.function
1266    ... def run():
1267    ...    def value_fn(value_context):
1268    ...      return tf.constant(value_context.replica_id_in_sync_group)
1269    ...    distributed_values = (
1270    ...        strategy.experimental_distribute_values_from_function(
1271    ...            value_fn))
1272    ...    def replica_fn(input):
1273    ...      return tf.distribute.get_replica_context().all_reduce("sum", input)
1274    ...    return strategy.run(replica_fn, args=(distributed_values,))
1275    >>> result = run()
1276    >>> result
1277    PerReplica:{
1278      0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1279      1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
1280    }
1281
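    As an additional sketch (not one of the numbered examples above), `args`
    may also be a nested structure of plain tensors, in which case the same
    structure is passed to every replica:

    ```
    nested_input = {"x": tf.constant(3.0), "y": [tf.constant(2.0)]}
    result = strategy.run(
        lambda d: d["x"] + d["y"][0], args=(nested_input,))
    ```
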
1282    Args:
1283      fn: The function to run on each replica.
1284      args: Optional positional arguments to `fn`. Its elements can be tensors,
1285        nested structures of tensors, or `tf.distribute.DistributedValues`.
1286      kwargs: Optional keyword arguments to `fn`. Its elements can be tensors,
1287        nested structures of tensors, or `tf.distribute.DistributedValues`.
1288      options: An optional instance of `tf.distribute.RunOptions` specifying
1289        the options to run `fn`.
1290
1291    Returns:
1292      Merged return value of `fn` across replicas. The structure of the return
1293      value is the same as the return value from `fn`. Each element in the
1294      structure can either be a `tf.distribute.DistributedValues` or a `Tensor`
1295      (for example, if running on a single replica).
1296    """
1297    del options
1298
1299    if not isinstance(args, (list, tuple)):
1300      raise ValueError(
1301          "positional args must be a list or tuple, got {}".format(type(args)))
1302
1303    with self.scope():
1304      # tf.distribute supports Eager functions, so AutoGraph should not be
1305      # applied when the caller is also in Eager mode.
1306      fn = autograph.tf_convert(
1307          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1308      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
1309
1310  def reduce(self, reduce_op, value, axis):
1311    """Reduce `value` across replicas and return result on current device.
1312
1313    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1314    >>> def step_fn():
1315    ...   i = tf.distribute.get_replica_context().replica_id_in_sync_group
1316    ...   return tf.identity(i)
1317    >>>
1318    >>> per_replica_result = strategy.run(step_fn)
1319    >>> total = strategy.reduce("SUM", per_replica_result, axis=None)
1320    >>> total
1321    <tf.Tensor: shape=(), dtype=int32, numpy=1>
1322
1323    To see how this would look with multiple replicas, consider the same
1324    example with MirroredStrategy with 2 GPUs:
1325
1326    ```python
1327    strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1328    def step_fn():
1329      i = tf.distribute.get_replica_context().replica_id_in_sync_group
1330      return tf.identity(i)
1331
1332    per_replica_result = strategy.run(step_fn)
1333    # Check the devices on which the per-replica results are placed:
1334    strategy.experimental_local_results(per_replica_result)[0].device
1335    # /job:localhost/replica:0/task:0/device:GPU:0
1336    strategy.experimental_local_results(per_replica_result)[1].device
1337    # /job:localhost/replica:0/task:0/device:GPU:1
1338
1339    total = strategy.reduce("SUM", per_replica_result, axis=None)
1340    # Check the device on which the reduced result is placed:
1341    total.device
1342    # /job:localhost/replica:0/task:0/device:CPU:0
1343
1344    ```
1345
1346    This API is typically used for aggregating the results returned from
1347    different replicas, for reporting etc. For example, loss computed from
1348    different replicas can be averaged using this API before printing.
1349
1350    Note: The result is copied to the "current" device, which would typically
1351    be the CPU of the worker on which the program is running. For `TPUStrategy`,
1352    it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
1353    this is the CPU of each worker.
1354
1355    There are a number of different tf.distribute APIs for reducing values
1356    across replicas:
1357    * `tf.distribute.ReplicaContext.all_reduce`: This differs from
1358    `Strategy.reduce` in that it is for replica context and does
1359    not copy the results to the host device. `all_reduce` should be typically
1360    used for reductions inside the training step such as gradients.
1361    * `tf.distribute.StrategyExtended.reduce_to` and
1362    `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
1363    advanced versions of `Strategy.reduce` as they allow customizing the
1364    destination of the result. They are also called in the cross-replica context.
1365
1366    _What should axis be?_
1367
1368    Given a per-replica value returned by `run`, say a
1369    per-example loss, the batch will be divided across all the replicas.  This
1370    function allows you to aggregate across replicas and optionally also across
1371    batch elements by specifying the axis parameter accordingly.
1372
1373    For example, if you have a global batch size of 8 and 2
1374    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
1375    `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
1376    aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
1377    This is useful when each replica is computing a scalar or some other value
1378    that doesn't have a "batch" dimension (like a gradient or loss).
1379    ```
1380    strategy.reduce("sum", per_replica_result, axis=None)
1381    ```
1382
1383    Sometimes, you will want to aggregate across both the global batch _and_
1384    all replicas. You can get this behavior by specifying the batch
1385    dimension as the `axis`, typically `axis=0`. In this case it would return a
1386    scalar `0+1+2+3+4+5+6+7`.
1387    ```
1388    strategy.reduce("sum", per_replica_result, axis=0)
1389    ```
1390
1391    If there is a last partial batch, you will need to specify an axis so
1392    that the resulting shape is consistent across replicas. So if the last
1393    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
1394    would get a shape mismatch unless you specify `axis=0`. If you specify
1395    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
1396    denominator of 6. Contrast this with first computing `reduce_mean` to get a
1397    scalar value on each replica and then using this function to average those
1398    means, which would weigh some values `1/8` and others `1/4`.
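
    As a sketch of this partial-batch case (assuming `per_replica_result`
    holds the batches `[0, 1, 2, 3]` and `[4, 5]` described above):
    ```
    # axis=0 lets MEAN use the true denominator of 6:
    # (0 + 1 + 2 + 3 + 4 + 5) / 6
    strategy.reduce("MEAN", per_replica_result, axis=0)
    ```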
1399
1400    Args:
1401      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
1402        be combined. Allows using string representation of the enum such as
1403        "SUM", "MEAN".
1404      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1405        `Strategy.run`, to be combined into a single tensor. It can also be a
1406        regular tensor when used with `OneDeviceStrategy` or default strategy.
1407      axis: specifies the dimension to reduce along within each
1408        replica's tensor. Should typically be set to the batch dimension, or
1409        `None` to only reduce across replicas (e.g. if the tensor has no batch
1410        dimension).
1411
1412    Returns:
1413      A `Tensor`.
1414    """
1415    # TODO(josh11b): support `value` being a nest.
1416    _require_cross_replica_or_default_context_extended(self._extended)
1417    if isinstance(reduce_op, six.string_types):
1418      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1419    if axis is None:
1420      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1421    if reduce_op == reduce_util.ReduceOp.SUM:
1422
1423      def reduce_sum(v):
1424        return math_ops.reduce_sum(v, axis=axis)
1425
1426      if eager_context.executing_eagerly():
1427        # As some strategies (e.g. TPUStrategy) don't support pure eager
1428        # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be
1429        # run from eager mode. Cache the tf.function by `axis` to avoid
1430        # tracing the same function again.
1431        if axis not in self._reduce_sum_fns:
1432
1433          def reduce_sum_fn(v):
1434            return self.run(reduce_sum, args=(v,))
1435
1436          self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn)
1437        value = self._reduce_sum_fns[axis](value)
1438      else:
1439        value = self.run(reduce_sum, args=(value,))
1440
1441      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1442    if reduce_op != reduce_util.ReduceOp.MEAN:
1443      raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
1444                      "not: %r" % reduce_op)
1445
1446    def mean_reduce_helper(v, axes=axis):
1447      """Computes the numerator and denominator on each replica."""
1448      numer = math_ops.reduce_sum(v, axis=axes)
1449      def dimension(axis):
1450        if v.shape.rank is not None:
1451          # Note(joshl): We support axis < 0 to be consistent with the
1452          # tf.math.reduce_* operations.
1453          if axis < 0:
1454            if axis + v.shape.rank < 0:
1455              raise ValueError(
1456                  "`axis` = %r out of range for `value` with rank %d" %
1457                  (axis, v.shape.rank))
1458            axis += v.shape.rank
1459          elif axis >= v.shape.rank:
1460            raise ValueError(
1461                "`axis` = %r out of range for `value` with rank %d" %
1462                (axis, v.shape.rank))
1463          # TF v2 returns `None` for unknown dimensions and an integer for
1464          # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
1465          # or tensor_shape.Dimension(integer). `dimension_value` hides this
1466          # difference, always returning `None` or an integer.
1467          dim = tensor_shape.dimension_value(v.shape[axis])
1468          if dim is not None:
1469            # By returning a python value in the static shape case, we can
1470            # maybe get a fast path for reducing the denominator.
1471            # TODO(b/151871486): Remove array_ops.identity after we fallback to
1472            # simple reduction if inputs are all on CPU.
1473            return array_ops.identity(
1474                constant_op.constant(dim, dtype=dtypes.int64))
1475        elif axis < 0:
1476          axis = axis + array_ops.rank(v)
1477        # TODO(b/151871486): Remove array_ops.identity after we fallback to
1478        # simple reduction if inputs are all on CPU.
1479        return array_ops.identity(
1480            array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
1481      if isinstance(axis, six.integer_types):
1482        denom = dimension(axis)
1483      elif isinstance(axis, (tuple, list)):
1484        denom = math_ops.reduce_prod([dimension(a) for a in axes])
1485      else:
1486        raise TypeError(
1487            "Expected `axis` to be an integer, tuple or list not: %r" % axis)
1488      # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
1489      # reduce is complete?
1490      return numer, denom
1491
1492    if eager_context.executing_eagerly():
1493      # As some strategies (e.g. TPUStrategy) don't support pure eager
1494      # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can
1495      # be run from eager mode. Cache the tf.function by `axis` to avoid
1496      # tracing the same function again.
1497      if axis not in self._mean_reduce_helper_fns:
1498
1499        def mean_reduce_fn(v):
1500          return self.run(mean_reduce_helper, args=(v,))
1501
1502        self._mean_reduce_helper_fns[axis] = def_function.function(
1503            mean_reduce_fn)
1504      numer, denom = self._mean_reduce_helper_fns[axis](value)
1505    else:
1506      numer, denom = self.run(mean_reduce_helper, args=(value,))
1507
1508    # TODO(josh11b): Should batch reduce here instead of doing two.
1509    numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer)  # pylint: disable=protected-access
1510    denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom)  # pylint: disable=protected-access
1511    denom = math_ops.cast(denom, numer.dtype)
1512    return math_ops.truediv(numer, denom)
1513
1514  @doc_controls.do_not_doc_inheritable  # DEPRECATED
1515  @deprecated(None, "use `experimental_local_results` instead.")
1516  def unwrap(self, value):
1517    """Returns the list of all local per-replica values contained in `value`.
1518
1519    DEPRECATED: Please use `experimental_local_results` instead.
1520
1521    Note: This only returns values on the workers initiated by this client.
1522    When using a `tf.distribute.Strategy` like
1523    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1524    will be its own client, and this function will only return values
1525    computed on that worker.
1526
1527    Args:
1528      value: A value returned by `experimental_run()`,
1529        `extended.call_for_each_replica()`, or a variable created in `scope`.
1530
1531    Returns:
1532      A tuple of values contained in `value`. If `value` represents a single
1533      value, this returns `(value,)`.
1534    """
1535    return self._extended._local_results(value)  # pylint: disable=protected-access
1536
1537  def experimental_local_results(self, value):
1538    """Returns the list of all local per-replica values contained in `value`.
1539
1540    Note: This only returns values on the worker initiated by this client.
1541    When using a `tf.distribute.Strategy` like
1542    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1543    will be its own client, and this function will only return values
1544    computed on that worker.
1545
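    For example, a minimal sketch (assuming a two-GPU
    `tf.distribute.MirroredStrategy` is available):

    ```
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    per_replica = strategy.run(lambda: tf.constant(1.0))
    local = strategy.experimental_local_results(per_replica)
    # `local` is a tuple with one tensor per local replica.
    ```
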
1546    Args:
1547      value: A value returned by `experimental_run()`, `run()`, or a variable
1548        created in `scope`.
1549
1550    Returns:
1551      A tuple of values contained in `value`, where the i-th element
1552      corresponds to the i-th replica. If `value` represents a single value,
1553      this returns `(value,)`.
1554    """
1555    return self._extended._local_results(value)  # pylint: disable=protected-access
1556
1557  @doc_controls.do_not_doc_inheritable  # DEPRECATED: TF v1.x only
1558  def group(self, value, name=None):
1559    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
1560    return self._extended._group(value, name)  # pylint: disable=protected-access
1561
1562  @property
1563  def num_replicas_in_sync(self):
1564    """Returns number of replicas over which gradients are aggregated."""
1565    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access
1566
1567  @doc_controls.do_not_doc_inheritable  # DEPRECATED: see doc string
1568  @deprecated(None, "use `update_config_proto` instead.")
1569  def configure(self,
1570                session_config=None,
1571                cluster_spec=None,
1572                task_type=None,
1573                task_id=None):
1574    # pylint: disable=g-doc-return-or-yield,g-doc-args
1575    """DEPRECATED: use `update_config_proto` instead.
1576
1577    Configures the strategy class.
1578
1579    DEPRECATED: This method's functionality has been split into the strategy
1580    constructor and `update_config_proto`. In the future, we will allow passing
1581    cluster and config_proto to the constructor to configure the strategy. And
1582    `update_config_proto` can be used to update the config_proto based on the
1583    specific strategy.
1584    """
1585    return self._extended._configure(  # pylint: disable=protected-access
1586        session_config, cluster_spec, task_type, task_id)
1587
1588  @doc_controls.do_not_generate_docs  # DEPRECATED
1589  def update_config_proto(self, config_proto):
1590    """DEPRECATED TF 1.x ONLY."""
1591    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
1592
1593  def __deepcopy__(self, memo):
1594    # First do a regular deepcopy of `self`.
1595    cls = self.__class__
1596    result = cls.__new__(cls)
1597    memo[id(self)] = result
1598    for k, v in self.__dict__.items():
1599      setattr(result, k, copy.deepcopy(v, memo))
1600    # One little fix-up: we want `result._extended` to reference `result`
1601    # instead of `self`.
1602    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
1603    return result
1604
1605  def __copy__(self):
1606    raise RuntimeError("Must only deepcopy DistributionStrategy.")
1607
1608  @property
1609  def cluster_resolver(self):
1610    """Returns the cluster resolver associated with this strategy.
1611
1612    In general, when using a multi-worker `tf.distribute` strategy such as
1613    `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
1614    `tf.distribute.TPUStrategy`, there is a
1615    `tf.distribute.cluster_resolver.ClusterResolver` associated with the
1616    strategy used, and such an instance is returned by this property.
1617
1618    Strategies that intend to have an associated
1619    `tf.distribute.cluster_resolver.ClusterResolver` must set the
1620    relevant attribute, or override this property; otherwise, `None` is returned
1621    by default. Those strategies should also provide information regarding what
1622    is returned by this property.
1623
1624    Single-worker strategies usually do not have a
1625    `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
1626    property will return `None`.
1627
1628    The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
1629    user needs to access information such as the cluster spec, task type or task
1630    id. For example,
1631
1632    ```python
1633
1634    os.environ['TF_CONFIG'] = json.dumps({
1635      'cluster': {
1636          'worker': ["localhost:12345", "localhost:23456"],
1637          'ps': ["localhost:34567"]
1638      },
1639      'task': {'type': 'worker', 'index': 0}
1640    })
1641
1642    # This implicitly uses TF_CONFIG for the cluster and current task info.
1643    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1644
1645    ...
1646
1647    if strategy.cluster_resolver.task_type == 'worker':
1648      # Perform something that's only applicable on workers. Since we set this
1649      # as a worker above, this block will run on this particular instance.
1650    elif strategy.cluster_resolver.task_type == 'ps':
1651      # Perform something that's only applicable on parameter servers. Since we
1652      # set this as a worker above, this block will not run on this particular
1653      # instance.
1654    ```
1655
1656    For more information, please see
1657    `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
1658
1659    Returns:
1660      The cluster resolver associated with this strategy. Returns `None` if a
1661      cluster resolver is not applicable or available in this strategy.
1662    """
1663    if hasattr(self.extended, "_cluster_resolver"):
1664      return self.extended._cluster_resolver  # pylint: disable=protected-access
1665    return None
1666
1667
1668@tf_export("distribute.Strategy", v1=[])  # pylint: disable=g-missing-docstring
1669class Strategy(StrategyBase):
1670
1671  __doc__ = StrategyBase.__doc__
1672
1673  def experimental_distribute_values_from_function(self, value_fn):
1674    """Generates `tf.distribute.DistributedValues` from `value_fn`.
1675
1676    This function is to generate `tf.distribute.DistributedValues` to pass
1677    into `run`, `reduce`, or other methods that take
1678    distributed values when not using datasets.
1679
1680    Args:
1681      value_fn: The function to run to generate values. It is called for
1682        each replica with `tf.distribute.ValueContext` as the sole argument. It
1683        must return a Tensor or a type that can be converted to a Tensor.
1684    Returns:
1685      A `tf.distribute.DistributedValues` containing a value for each replica.
1686
1687    Example usage:
1688
1689    1. Return constant value per replica:
1690
1691    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1692    >>> def value_fn(ctx):
1693    ...   return tf.constant(1.)
1694    >>> distributed_values = (
1695    ...      strategy.experimental_distribute_values_from_function(
1696    ...        value_fn))
1697    >>> local_result = strategy.experimental_local_results(distributed_values)
1698    >>> local_result
1699    (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
1700     <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
1701
1702    2. Distribute values in array based on replica_id:
1703
1704    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1705    >>> array_value = np.array([3., 2., 1.])
1706    >>> def value_fn(ctx):
1707    ...   return array_value[ctx.replica_id_in_sync_group]
1708    >>> distributed_values = (
1709    ...      strategy.experimental_distribute_values_from_function(
1710    ...        value_fn))
1711    >>> local_result = strategy.experimental_local_results(distributed_values)
1712    >>> local_result
1713    (3.0, 2.0)
1714
1715    3. Specify values using num_replicas_in_sync:
1716
1717    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1718    >>> def value_fn(ctx):
1719    ...   return ctx.num_replicas_in_sync
1720    >>> distributed_values = (
1721    ...      strategy.experimental_distribute_values_from_function(
1722    ...        value_fn))
1723    >>> local_result = strategy.experimental_local_results(distributed_values)
1724    >>> local_result
1725    (2, 2)
1726
1727    4. Place values on devices and distribute:
1728
1729    ```
1730    strategy = tf.distribute.TPUStrategy()
1731    worker_devices = strategy.extended.worker_devices
1732    multiple_values = []
1733    for i in range(strategy.num_replicas_in_sync):
1734      with tf.device(worker_devices[i]):
1735        multiple_values.append(tf.constant(1.0))
1736
1737    def value_fn(ctx):
1738      return multiple_values[ctx.replica_id_in_sync_group]
1739
1740    distributed_values = (
1741        strategy.experimental_distribute_values_from_function(
1742            value_fn))
1743    ```
1744
1745    """
1746    return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
1747        value_fn)
1748
1749  def gather(self, value, axis):
1750    # pylint: disable=line-too-long, protected-access
1751    """Gather `value` across replicas along `axis` to the current device.
1752
1753    Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
1754    object `value`, this API gathers and concatenates `value` across replicas
1755    along the `axis`-th dimension. The result is copied to the "current" device,
1756    which would typically be the CPU of the worker on which the program is
1757    running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
1758    multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
1759    each worker.
1760
1761    This API can only be called in the cross-replica context. For a counterpart
1762    in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
1763
1764    Note: For all strategies except `tf.distribute.TPUStrategy`, the input
1765    `value` on different replicas must have the same rank, and their shapes must
1766    be the same in all dimensions except the `axis`-th dimension. In other
1767    words, their shapes cannot be different in a dimension `d` where `d` is
1768    not equal to the `axis` argument. For example, given a
1769    `tf.distribute.DistributedValues` with component tensors of shape
1770    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
1771    `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
1772    `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
1773    all tensors must have exactly the same rank and same shape.
1774
1775    Note: Given a `tf.distribute.DistributedValues` `value`, its component
1776    tensors must have a non-zero rank. Otherwise, consider using
1777    `tf.expand_dims` before gathering them.
1778
1779    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1780    >>> # A DistributedValues with component tensor of shape (2, 1) on each replica
1781    ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
1782    >>> @tf.function
1783    ... def run():
1784    ...   return strategy.gather(distributed_values, axis=0)
1785    >>> run()
1786    <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
1787    array([[1],
1788           [2],
1789           [1],
1790           [2]], dtype=int32)>
1791
1792
1793    Consider the following example for more combinations:
1794
1795    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
1796    >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
1797    >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
1798    >>> @tf.function
1799    ... def run(axis):
1800    ...   return strategy.gather(distributed_values, axis=axis)
1801    >>> axis=0
1802    >>> run(axis)
1803    <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
1804    array([[[0, 1, 2],
1805            [3, 4, 5]],
1806           [[0, 1, 2],
1807            [3, 4, 5]],
1808           [[0, 1, 2],
1809            [3, 4, 5]],
1810           [[0, 1, 2],
1811            [3, 4, 5]]], dtype=int32)>
1812    >>> axis=1
1813    >>> run(axis)
1814    <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
1815    array([[[0, 1, 2],
1816            [3, 4, 5],
1817            [0, 1, 2],
1818            [3, 4, 5],
1819            [0, 1, 2],
1820            [3, 4, 5],
1821            [0, 1, 2],
1822            [3, 4, 5]]], dtype=int32)>
1823    >>> axis=2
1824    >>> run(axis)
1825    <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
1826    array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
1827            [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
1828
1829
1830    Args:
1831      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1832        `Strategy.run`, to be combined into a single tensor. It can also be a
1833        regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
1834        default strategy. The tensors that constitute the DistributedValues
1835        can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
1836      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
1837        range [0, rank(value)).
1838
1839    Returns:
1840       A `Tensor` that's the concatenation of `value` across replicas along
1841       the `axis` dimension.
1842    """
1843    # pylint: enable=line-too-long
1844    error_message = ("tf.distribute.Strategy.gather method requires "
1845                     "cross-replica context, use "
1846                     "get_replica_context().all_gather() instead")
1847    _require_cross_replica_or_default_context_extended(self._extended,
1848                                                       error_message)
1849    dst = device_util.current(
1850    ) or self._extended._default_device or "/device:CPU:0"
1851    if isinstance(value, ops.IndexedSlices):
1852      raise NotImplementedError("gather does not support IndexedSlices")
1853    return self._extended._local_results(
1854        self._extended._gather_to(value, dst, axis))[0]
1855
1856
1857# TF v1.x version has additional deprecated APIs
1858@tf_export(v1=["distribute.Strategy"])
1859class StrategyV1(StrategyBase):
1860  """A list of devices with a state & compute distribution policy.
1861
1862  See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
1863  for overview and examples.
1864
1865  Note: Not all `tf.distribute.Strategy` implementations currently support
1866  TensorFlow's partitioned variables (where a single variable is split across
1867  multiple devices).
1868  """
1869
1870  def make_dataset_iterator(self, dataset):
1871    """Makes an iterator for input provided via `dataset`.
1872
1873    DEPRECATED: This method is not available in TF 2.x.
1874
1875    Data from the given dataset will be distributed evenly across all the
1876    compute replicas. We will assume that the input dataset is batched by the
1877    global batch size. With this assumption, we will make a best effort to
1878    divide each batch across all the replicas (one or more workers).
1879    If this effort fails, an error will be raised, and the user should instead
1880    use `make_input_fn_iterator`, which provides more control and
1881    does not try to divide a batch across replicas.
1882
1883    The user could also use `make_input_fn_iterator` if they want to
1884    customize which input is fed to which replica/worker etc.
1885
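    A minimal TF 1.x-style sketch (`global_batch_size` and `replica_fn` are
    placeholder names for this example):

    ```
    dataset = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(
        global_batch_size)
    iterator = strategy.make_dataset_iterator(dataset)
    iterator.initialize()
    replica_results = strategy.experimental_run(replica_fn, iterator)
    ```
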
1886    Args:
1887      dataset: `tf.data.Dataset` that will be distributed evenly across all
1888        replicas.
1889
1890    Returns:
1891      A `tf.distribute.InputIterator` which returns inputs for each step of the
1892      computation. The user should call `initialize` on the returned iterator.
1893    """
1894    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
1895
1896  def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
1897                             input_fn,
1898                             replication_mode=InputReplicationMode.PER_WORKER):
1899    """Returns an iterator split across replicas created from an input function.
1900
1901    DEPRECATED: This method is not available in TF 2.x.
1902
1903    The `input_fn` should take a `tf.distribute.InputContext` object where
1904    information about batching and input sharding can be accessed:
1905
1906    ```
1907    def input_fn(input_context):
1908      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
1909      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
1910      return d.shard(input_context.num_input_pipelines,
1911                     input_context.input_pipeline_id)
1912    with strategy.scope():
1913      iterator = strategy.make_input_fn_iterator(input_fn)
1914      replica_results = strategy.experimental_run(replica_fn, iterator)
1915    ```
1916
1917    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
1918    batch size, which may be computed using
1919    `input_context.get_per_replica_batch_size`.
1920
1921    Args:
1922      input_fn: A function taking a `tf.distribute.InputContext` object and
1923        returning a `tf.data.Dataset`.
1924      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
1925        Only `PER_WORKER` is supported currently, which means there will be
1926        a single call to `input_fn` per worker. Replicas will dequeue from the
1927        local `tf.data.Dataset` on their worker.
1928
1929    Returns:
1930      An iterator object that should first be `.initialize()`-ed. It may then
1931      either be passed to `strategy.experimental_run()` or you can call
1932      `iterator.get_next()` to get the next value to pass to
1933      `strategy.extended.call_for_each_replica()`.
1934    """
1935    return super(StrategyV1, self).make_input_fn_iterator(
1936        input_fn, replication_mode)
1937
1938  def experimental_make_numpy_dataset(self, numpy_input, session=None):
1939    """Makes a tf.data.Dataset for input provided via a numpy array.
1940
1941    This avoids adding `numpy_input` as a large constant in the graph,
1942    and copies the data to the machine or machines that will be processing
1943    the input.
1944
1945    Note that you will likely need to use
1946    `tf.distribute.Strategy.experimental_distribute_dataset`
1947    with the returned dataset to further distribute it with the strategy.
1948
1949    Example:
1950    ```
1951    numpy_input = np.ones([10], dtype=np.float32)
1952    dataset = strategy.experimental_make_numpy_dataset(numpy_input)
1953    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1954    ```
1955
1956    Args:
1957      numpy_input: A nest of NumPy input arrays that will be converted into a
1958        dataset. Note that lists of NumPy arrays are stacked, as that is normal
1959        `tf.data.Dataset` behavior.
1960      session: (TensorFlow v1.x graph execution only) A session used for
1961        initialization.
1962
1963    Returns:
1964      A `tf.data.Dataset` representing `numpy_input`.
1965    """
1966    return self.extended.experimental_make_numpy_dataset(
1967        numpy_input, session=session)
1968
1969  @deprecated(
1970      None,
1971      "This method is not available in TF 2.x. Please switch to using `run` instead."
1972  )
1973  def experimental_run(self, fn, input_iterator=None):  # pylint: disable=useless-super-delegation
1974    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
1975
1976    DEPRECATED: This method is not available in TF 2.x. Please switch
1977    to using `run` instead.
1978
1979    When eager execution is enabled, executes ops specified by `fn` on each
1980    replica. Otherwise, builds a graph to execute the ops on each replica.
1981
1982    Each replica will take a single, different input from the inputs provided by
1983    one `get_next` call on the input iterator.
1984
1985    `fn` may call `tf.distribute.get_replica_context()` to access members such
1986    as `replica_id_in_sync_group`.
1987
1988    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
1989    used, and whether eager execution is enabled, `fn` may be called one or more
1990    times (once for each replica).
1991
1992    Args:
1993      fn: The function to run. The inputs to the function must match the outputs
1994        of `input_iterator.get_next()`. The output must be a `tf.nest` of
1995        `Tensor`s.
1996      input_iterator: (Optional) input iterator from which the inputs are taken.
1997
1998    Returns:
1999      Merged return value of `fn` across replicas. The structure of the return
2000      value is the same as the return value from `fn`. Each element in the
2001      structure can either be `PerReplica` (if the values are unsynchronized),
2002      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
2003      single replica).
2004    """
2005    return super(StrategyV1, self).experimental_run(
2006        fn, input_iterator)
2007
2008  def reduce(self, reduce_op, value, axis=None):
2009    return super(StrategyV1, self).reduce(reduce_op, value, axis)
2010
2011  reduce.__doc__ = StrategyBase.reduce.__doc__
2012
2013  def update_config_proto(self, config_proto):
2014    """Returns a copy of `config_proto` modified for use with this strategy.
2015
2016    DEPRECATED: This method is not available in TF 2.x.
2017
2018    The updated config contains settings needed to run the strategy, e.g.
2019    configuration to run collective ops, or device filters to improve
2020    distributed training performance.
2021
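    A minimal TF 1.x sketch of how the updated proto might be used (the
    session setup here is illustrative, not prescriptive):

    ```
    config_proto = tf.compat.v1.ConfigProto()
    config_proto = strategy.update_config_proto(config_proto)
    session = tf.compat.v1.Session(config=config_proto)
    ```
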
2022    Args:
2023      config_proto: a `tf.ConfigProto` object.
2024
2025    Returns:
2026      The updated copy of the `config_proto`.
2027    """
2028    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
2029
2030
2031# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
2032# instead descend from StrategyExtendedV1.
2033@tf_export("distribute.StrategyExtended", v1=[])
2034class StrategyExtendedV2(object):
2035  """Additional APIs for algorithms that need to be distribution-aware.
2036
2037  Note: For most usage of `tf.distribute.Strategy`, there should be no need to
2038  call these methods, since TensorFlow libraries (such as optimizers) already
2039  call these methods when needed on your behalf.
2040
2041
2042  Some common use cases of functions on this page:
2043
2044  * _Locality_
2045
2046  `tf.distribute.DistributedValues` can have the same _locality_ as a
2047  _distributed variable_, which leads to a mirrored value residing on the same
2048  devices as the variable (as opposed to the compute devices). Such values may
2049  be passed to a call to `tf.distribute.StrategyExtended.update` to update the
2050  value of a variable. You may use
2051  `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
2052  same locality as another variable. You may convert a "PerReplica" value to a
2053  variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
2054  `tf.distribute.StrategyExtended.batch_reduce_to`.
2055
2056  * _How to update a distributed variable_
2057
2058  A distributed variable is a variable created on multiple devices. As discussed
2059  in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
2060  mirrored variables and SyncOnRead variables are two examples. The standard
2061  pattern for updating distributed variables is to:
2062
2063  1. In your function passed to `tf.distribute.Strategy.run`,
2064     compute a list of (update, variable) pairs. For example, the update might
2065     be a gradient of the loss with respect to the variable.
2066  2. Switch to cross-replica mode by calling
2067     `tf.distribute.get_replica_context().merge_call()` with the updates and
2068     variables as arguments.
2069  3. Call
2070     `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)`
2071     (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
2072     (for a list of variables) to sum the updates.
2073  4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
2074     its value.
2075
2076  Steps 2 through 4 are done automatically by class
2077  `tf.keras.optimizers.Optimizer` if you call its
2078  `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.
2079
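  The steps above might look like the following sketch (assuming a two-GPU
  `tf.distribute.MirroredStrategy` and a constant update standing in for a
  real gradient):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  with strategy.scope():
    v = tf.Variable(1.0)

  @tf.function
  def step_fn():
    def merge_fn(strategy, per_replica_update, var):
      # Step 3: sum the per-replica updates to the variable's devices.
      reduced = strategy.extended.reduce_to(
          tf.distribute.ReduceOp.SUM, per_replica_update, destinations=var)
      # Step 4: apply the reduced update on each device the variable lives on.
      strategy.extended.update(
          var, lambda var, u: var.assign_sub(u), args=(reduced,))

    # Step 1: compute the (update, variable) pair in replica context.
    update = tf.identity(0.1)
    # Step 2: switch to cross-replica mode.
    tf.distribute.get_replica_context().merge_call(merge_fn, args=(update, v))

  strategy.run(step_fn)
  ```
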
2080  In fact, a higher-level way to update a distributed variable is to call
2081  `assign` on the variable, as you would with a regular `tf.Variable`.
2082  You can call the method in both _replica context_ and _cross-replica context_.
2083  For a _mirrored variable_, calling `assign` in _replica context_ requires you
2084  to specify the `aggregation` type in the variable constructor. In that case,
2085  the context switching and sync described in steps 2 through 4 are handled for
2086  you. If you call `assign` on _mirrored variable_ in _cross-replica context_,
2087  you can only assign a single value or assign values from another mirrored
2088  variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead
2089  variable_, in _replica context_, you can simply call `assign` on it and no
2090  aggregation happens under the hood. In _cross-replica context_, you can only
2091  assign a single value to a SyncOnRead variable. One example case is restoring
2092  from a checkpoint: if the `aggregation` type of the variable is
2093  `tf.VariableAggregation.SUM`, it is assumed that replica values were added
2094  before checkpointing, so at the time of restoring, the value is divided by
2095  the number of replicas and then assigned to each replica; if the `aggregation`
2096  type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica
2097  directly.
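
  As a sketch of the `assign`-based approach in replica context (reusing the
  `strategy` from the sketch above; the `aggregation` argument is required for
  a mirrored variable assigned in replica context):

  ```
  with strategy.scope():
    acc = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

  @tf.function
  def step_fn():
    ctx = tf.distribute.get_replica_context()
    # Each replica assigns its own value; the MEAN aggregation keeps the
    # mirrored components in sync.
    acc.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32))

  strategy.run(step_fn)
  ```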
2098
2099  """
2100
2101  def __init__(self, container_strategy):
2102    self._container_strategy_weakref = weakref.ref(container_strategy)
2103    self._default_device = None
2104    # This property is used to determine if we should set drop_remainder=True
2105    # when creating Datasets from numpy array inputs.
2106    self._require_static_shapes = False
2107
2108  def _resource_creator_scope(self):
2109    """Returns one or more ops.resource_creator_scope for some Strategy."""
2110    return None
2111
2112  def _container_strategy(self):
2113    """Get the containing `tf.distribute.Strategy`.
2114
2115    This should not generally be needed except when creating a new
2116    `ReplicaContext` and to validate that the caller is in the correct
2117    `scope()`.
2118
2119    Returns:
2120      The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
2121    """
2122    container_strategy = self._container_strategy_weakref()
2123    assert container_strategy is not None
2124    return container_strategy
2125
2126  def _scope(self, strategy):
2127    """Implementation of tf.distribute.Strategy.scope()."""
2128
2129    def creator_with_resource_vars(next_creator, **kwargs):
2130      """Variable creator to use in `_CurrentDistributionContext`."""
2131      _require_strategy_scope_extended(self)
2132      kwargs["use_resource"] = True
2133      kwargs["distribute_strategy"] = strategy
2134
2135      # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
2136      # dereferencing a `Tensor` that is without a `name`. We still need to
2137      # propagate the metadata it's holding.
2138      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
2139        checkpoint_restore_uid = kwargs[
2140            "initial_value"].checkpoint_position.restore_uid
2141        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
2142      elif isinstance(kwargs["initial_value"],
2143                      trackable.CheckpointInitialValueCallable):
2144        checkpoint_restore_uid = kwargs[
2145            "initial_value"].checkpoint_position.restore_uid
2146      elif (isinstance(kwargs["initial_value"], functools.partial) and
2147            isinstance(kwargs["initial_value"].func,
2148                       trackable.CheckpointInitialValueCallable)):
2149        # Some libraries (e.g., Keras) create a partial function out of the
2150        # initializer to bind shape/dtype, for example:
2151        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
2152        # Therefore to get the restore_uid we need to examine the "func" of
2153        # the partial function.
2154        checkpoint_restore_uid = kwargs[
2155            "initial_value"].func.checkpoint_position.restore_uid
2156      else:
2157        checkpoint_restore_uid = None
2158
2159      created = self._create_variable(next_creator, **kwargs)
2160
2161      if checkpoint_restore_uid is not None:
2162        # pylint: disable=protected-access
2163        # Let the checkpointing infrastructure know that the variable was
2164        # already restored so it doesn't waste memory loading the value again.
2165        # In the case of CheckpointInitialValueCallable, this may already be
2166        # done by the final variable creator, but it doesn't hurt to do it
2167        # again.
2168        created._maybe_initialize_trackable()
2169        created._update_uid = checkpoint_restore_uid
2170        # pylint: enable=protected-access
2171      return created
2172
2173    def distributed_getter(getter, *args, **kwargs):
2174      if not self._allow_variable_partition():
2175        if kwargs.pop("partitioner", None) is not None:
2176          tf_logging.log_first_n(
2177              tf_logging.WARN, "Partitioned variables are disabled when using "
2178              "current tf.distribute.Strategy.", 1)
2179      return getter(*args, **kwargs)
2180
2181    return _CurrentDistributionContext(
2182        strategy,
2183        variable_scope.variable_creator_scope(creator_with_resource_vars),
2184        variable_scope.variable_scope(
2185            variable_scope.get_variable_scope(),
2186            custom_getter=distributed_getter),
2187        strategy.extended._resource_creator_scope(),  # pylint: disable=protected-access
2188        self._default_device)
2189
2190  def _allow_variable_partition(self):
2191    return False
2192
2193  def _create_variable(self, next_creator, **kwargs):
2194    # Note: should support "colocate_with" argument.
2195    raise NotImplementedError("must be implemented in descendants")
2196
2197  def variable_created_in_scope(self, v):
2198    """Tests whether `v` was created while this strategy scope was active.
2199
2200    Variables created inside the strategy scope are "owned" by it:
2201
2202    >>> strategy = tf.distribute.MirroredStrategy()
2203    >>> with strategy.scope():
2204    ...   v = tf.Variable(1.)
2205    >>> strategy.extended.variable_created_in_scope(v)
2206    True
2207
2208    Variables created outside the strategy are not owned by it:
2209
2210    >>> strategy = tf.distribute.MirroredStrategy()
2211    >>> v = tf.Variable(1.)
2212    >>> strategy.extended.variable_created_in_scope(v)
2213    False
2214
2215    Args:
2216      v: A `tf.Variable` instance.
2217
2218    Returns:
2219      True if `v` was created inside the scope, False if not.
2220    """
2221    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access
2222
2223  def colocate_vars_with(self, colocate_with_variable):
2224    """Scope that controls which devices variables will be created on.
2225
2226    No operations should be added to the graph inside this scope; it
2227    should only be used when creating variables (some implementations
2228    work by changing variable creation, others work by using a
2229    `tf.compat.v1.colocate_with()` scope).
2230
2231    This may only be used inside `self.scope()`.
2232
2233    Example usage:
2234
2235    ```
2236    with strategy.scope():
2237      var1 = tf.Variable(...)
2238      with strategy.extended.colocate_vars_with(var1):
2239        # var2 and var3 will be created on the same device(s) as var1
2240        var2 = tf.Variable(...)
2241        var3 = tf.Variable(...)
2242
2243      def fn(v1, v2, v3):
2244        # operates on v1 from var1, v2 from var2, and v3 from var3
2245
2246      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
2247      # too.
2248      strategy.extended.update(var1, fn, args=(var2, var3))
2249    ```
2250
2251    Args:
2252      colocate_with_variable: A variable created in this strategy's `scope()`.
2253        Variables created while in the returned context manager will be on the
2254        same set of devices as `colocate_with_variable`.
2255
2256    Returns:
2257      A context manager.
2258    """
2259
2260    def create_colocated_variable(next_creator, **kwargs):
2261      _require_strategy_scope_extended(self)
2262      kwargs["use_resource"] = True
2263      kwargs["colocate_with"] = colocate_with_variable
2264      return next_creator(**kwargs)
2265
2266    _require_strategy_scope_extended(self)
2267    self._validate_colocate_with_variable(colocate_with_variable)
2268    return variable_scope.variable_creator_scope(create_colocated_variable)
2269
2270  def _validate_colocate_with_variable(self, colocate_with_variable):
2271    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
2272    pass
2273
2274  def _make_dataset_iterator(self, dataset):
2275    raise NotImplementedError("must be implemented in descendants")
2276
2277  def _make_input_fn_iterator(self, input_fn, replication_mode):
2278    raise NotImplementedError("must be implemented in descendants")
2279
2280  def _experimental_distribute_dataset(self, dataset, options):
2281    raise NotImplementedError("must be implemented in descendants")
2282
2283  def _distribute_datasets_from_function(self, dataset_fn, options):
2284    raise NotImplementedError("must be implemented in descendants")
2285
2286  def _experimental_distribute_values_from_function(self, value_fn):
2287    raise NotImplementedError("must be implemented in descendants")
2288
2289  def _reduce(self, reduce_op, value):
2290    # Default implementation until we have an implementation for each strategy.
2291    dst = device_util.current() or self._default_device or "/device:CPU:0"
2292    return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
2293
2294  def reduce_to(self, reduce_op, value, destinations, options=None):
2295    """Combine (via e.g. sum or mean) values across replicas.
2296
2297    `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
2298    variables. It supports both dense values and `tf.IndexedSlices`.
2299
2300    This API currently can only be called in cross-replica context. Other
2301    variants to reduce values across replicas are:
2302    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
2303      this API.
2304    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2305      in replica context. It supports both batched and non-batched all-reduce.
2306    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2307      to the host in cross-replica context.
2308
2309    `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
2310    also pass in a `Tensor`, and the destinations will be the device of that
2311    tensor. For all-reduce, pass the same to `value` and `destinations`.
2312
2313    It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
2314    that works for all `tf.distribute.Strategy` implementations.
2315
2316    >>> @tf.function
2317    ... def step_fn(var):
2318    ...
2319    ...   def merge_fn(strategy, value, var):
2320    ...     # All-reduce the value. Note that `value` here is a
2321    ...     # `tf.distribute.DistributedValues`.
2322    ...     reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
2323    ...         value, destinations=var)
2324    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2325    ...         args=(reduced,))
2326    ...
2327    ...   value = tf.identity(1.)
2328    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2329    ...     args=(value, var))
2330    >>>
2331    >>> def run(strategy):
2332    ...   with strategy.scope():
2333    ...     v = tf.Variable(0.)
2334    ...     strategy.run(step_fn, args=(v,))
2335    ...     return v
2336    >>>
2337    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2338    MirroredVariable:{
2339      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2340      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2341    }
2342    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2343    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2344    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2345    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2346    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2347
2348    Args:
2349      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2350        be combined. Allows using string representation of the enum such as
2351        "SUM", "MEAN".
2352      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
2353      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2354        `tf.Tensor` alike object, or a device string. It specifies the devices
2355        to reduce to. To perform an all-reduce, pass the same to `value` and
2356        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
2357        to the devices of that variable, and this method doesn't update the
2358        variable.
2359      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2360        perform collective operations. This overrides the default options if the
2361        `tf.distribute.Strategy` takes one in the constructor. See
2362        `tf.distribute.experimental.CommunicationOptions` for details of the
2363        options.
2364
2365    Returns:
2366      A tensor or value reduced to `destinations`.
2367    """
2368    if options is None:
2369      options = collective_util.Options()
2370    _require_cross_replica_or_default_context_extended(self)
2371    assert not isinstance(destinations, (list, tuple))
2372    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2373    if isinstance(reduce_op, six.string_types):
2374      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2375    assert (reduce_op == reduce_util.ReduceOp.SUM or
2376            reduce_op == reduce_util.ReduceOp.MEAN)
2377    return self._reduce_to(reduce_op, value, destinations, options)
2378
2379  def _reduce_to(self, reduce_op, value, destinations, options):
2380    raise NotImplementedError("must be implemented in descendants")
2381
2382  def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
2383    """Combine multiple `reduce_to` calls into one for faster execution.
2384
2385    Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
2386    It's more efficient than reducing each value separately.
2387
2388    This API currently can only be called in cross-replica context. Other
2389    variants to reduce values across replicas are:
2390    * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
2391      this API.
2392    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2393      in replica context. It supports both batched and non-batched all-reduce.
2394    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2395      to the host in cross-replica context.
2396
2397    See `reduce_to` for more information.
2398
2399    >>> @tf.function
2400    ... def step_fn(var):
2401    ...
2402    ...   def merge_fn(strategy, value, var):
2403    ...     # All-reduce the value. Note that `value` here is a
2404    ...     # `tf.distribute.DistributedValues`.
2405    ...     reduced = strategy.extended.batch_reduce_to(
2406    ...         tf.distribute.ReduceOp.SUM, [(value, var)])[0]
2407    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2408    ...         args=(reduced,))
2409    ...
2410    ...   value = tf.identity(1.)
2411    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2412    ...     args=(value, var))
2413    >>>
2414    >>> def run(strategy):
2415    ...   with strategy.scope():
2416    ...     v = tf.Variable(0.)
2417    ...     strategy.run(step_fn, args=(v,))
2418    ...     return v
2419    >>>
2420    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2421    MirroredVariable:{
2422      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2423      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2424    }
2425    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2426    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2427    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2428    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2429    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2430
2431    Args:
2432      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2433        be combined. Allows using string representation of the enum such as
2434        "SUM", "MEAN".
2435      value_destination_pairs: a sequence of (value, destinations) pairs. See
2436        `tf.distribute.StrategyExtended.reduce_to` for descriptions.
2437      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2438        perform collective operations. This overrides the default options if the
2439        `tf.distribute.Strategy` takes one in the constructor. See
2440        `tf.distribute.experimental.CommunicationOptions` for details of the
2441        options.
2442
2443    Returns:
2444      A list of reduced values, one per pair in `value_destination_pairs`.
2445    """
2446    if options is None:
2447      options = collective_util.Options()
2448    _require_cross_replica_or_default_context_extended(self)
2449    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2450    if isinstance(reduce_op, six.string_types):
2451      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2452    return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
2453
2454  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
2455    return [
2456        self.reduce_to(reduce_op, t, destinations=v, options=options)
2457        for t, v in value_destination_pairs
2458    ]
2459
2460  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
2461    """All-reduce `value` across all replicas so that all get the final result.
2462
2463    If `value` is a nested structure of tensors, all-reduces of these tensors
2464    will be batched when possible. `options` can be set to hint the batching
2465    behavior.
2466
2467    This API must be called in a replica context.
2468
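    For illustration, a minimal sketch (assuming a `tf.distribute.MirroredStrategy`
    and calling this private helper directly from a replica context) might look
    like:

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      # A nested structure of tensors; the two all-reduces may be batched.
      value = {"a": tf.identity(1.), "b": tf.identity(2.)}
      return strategy.extended._replica_ctx_all_reduce(
          tf.distribute.ReduceOp.SUM, value)

    per_replica_result = strategy.run(step_fn)
    ```
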
2469    Args:
2470      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
2471        be combined.
2472      value: Value to be reduced. A tensor or a nested structure of tensors.
2473      options: A `tf.distribute.experimental.CommunicationOptions`. Options to
2474        perform collective operations. This overrides the default options if the
2475        `tf.distribute.Strategy` takes one in the constructor.
2476
2477    Returns:
2478      A tensor or a nested structure of tensors with the reduced values. The
2479      structure is the same as `value`.
2480    """
2481    if options is None:
2482      options = collective_util.Options()
2483    replica_context = distribution_strategy_context.get_replica_context()
2484    assert replica_context, (
2485        "`StrategyExtended._replica_ctx_all_reduce` must be called in"
2486        " a replica context")
2487
2488    def merge_fn(_, flat_value):
2489      return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
2490                                  options)
2491
2492    reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
2493    return nest.pack_sequence_as(value, reduced)
2494
2495  def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True):
2496    """Run `fn` with `args` and `kwargs` to update `var`."""
2497    # This method is called by ReplicaContext.update. Strategies who'd like to
2498    # remove merge_call in this path should override this method.
2499    replica_context = distribution_strategy_context.get_replica_context()
2500    if not replica_context:
2501      raise ValueError("`StrategyExtended._replica_ctx_update` must be called "
2502                       "in a replica context.")
2503
2504    def merge_fn(_, *merged_args, **merged_kwargs):
2505      return self.update(var, fn, merged_args, merged_kwargs, group=group)
2506
2507    return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
2508
2509  def _gather_to(self, value, destinations, axis, options=None):
2510    """Gather `value` across replicas along axis-th dimension to `destinations`.
2511
2512    `gather_to` gathers a `tf.distribute.DistributedValues` or a `tf.Tensor`-like
2513    object along the `axis`-th dimension. It supports only dense tensors, not
2514    sparse tensors. This API can only be called in cross-replica context.
2515
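    For illustration, a minimal sketch (assuming a `tf.distribute.MirroredStrategy`
    and a per-replica value produced by `strategy.run`) might look like:

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      return tf.constant([1, 2, 3])

    per_replica = strategy.run(step_fn)
    # In cross-replica context, gather the per-replica tensors onto one device,
    # concatenating them along axis 0.
    gathered = strategy.extended._gather_to(per_replica, "GPU:0", axis=0)
    ```
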
2516    Args:
2517      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
2518      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2519        `tf.Tensor`-like object, or a device string. It specifies the devices
2520        to gather to. To perform an all-gather, pass the same to `value` and
2521        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
2522        to the devices of that variable, and this method doesn't update the
2523        variable.
2524      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
2525        range [0, rank(value)).
2526      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2527        perform collective operations. This overrides the default options if the
2528        `tf.distribute.Strategy` takes one in the constructor. See
2529        `tf.distribute.experimental.CommunicationOptions` for details of the
2530        options.
2531
2532    Returns:
2533      A tensor or value gathered to `destinations`.
2534    """
2535    _require_cross_replica_or_default_context_extended(self)
2536    assert not isinstance(destinations, (list, tuple))
2537    if options is None:
2538      options = collective_util.Options()
2539    return self._gather_to_implementation(value, destinations, axis, options)
2540
2541  def _gather_to_implementation(self, value, destinations, axis, options):
2542    raise NotImplementedError("_gather_to must be implemented in descendants")
2543
2544  def _batch_gather_to(self, value_destination_pairs, axis, options=None):
2545    _require_cross_replica_or_default_context_extended(self)
2546    if options is None:
2547      options = collective_util.Options()
2548    return [
2549        self._gather_to(t, destinations=v, axis=axis, options=options)
2550        for t, v in value_destination_pairs
2551    ]
2552
2553  def update(self, var, fn, args=(), kwargs=None, group=True):
2554    """Run `fn` to update `var` using inputs mirrored to the same devices.
2555
2556    `tf.distribute.StrategyExtended.update` takes a distributed variable `var`
2557    to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
2558    applies `fn` to each component variable of `var` and passes corresponding
2559    values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
2560    per-replica values. If they contain mirrored values, they will be unwrapped
2561    before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
2562    a mirrored DistributedValues where each component contains the value to be
2563    added to this mirrored variable `var`. Calling `update` will call
2564    `assign_add` on each component variable of `var` with the corresponding
2565    tensor value on that device.
2566
2567    Example usage:
2568
2569    ```python
2570    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
2572    with strategy.scope():
2573      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
2574    def update_fn(v):
2575      return v.assign(1.0)
2576    result = strategy.extended.update(v, update_fn)
2577    # result is
2578    # Mirrored:{
2579    #  0: tf.Tensor(1.0, shape=(), dtype=float32),
2580    #  1: tf.Tensor(1.0, shape=(), dtype=float32)
2581    # }
2582    ```
2583
2584    If `var` is mirrored across multiple devices, then this method implements
2585    the following logic:
2586
2587    ```python
2588    results = {}
2589    for device, v in var:
2590      with tf.device(device):
2591        # args and kwargs will be unwrapped if they are mirrored.
2592        results[device] = fn(v, *args, **kwargs)
2593    return merged(results)
2594    ```
2595
2596    Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
2597    `var`.
2598
2599    Args:
2600      var: Variable, possibly mirrored to multiple devices, to operate on.
2601      fn: Function to call. Should take the variable as the first argument.
2602      args: Tuple or list. Additional positional arguments to pass to `fn()`.
2603      kwargs: Dict with keyword arguments to pass to `fn()`.
2604      group: Boolean. Defaults to True. If False, the return value will be
2605        unwrapped.
2606
2607    Returns:
2608      By default, the merged return value of `fn` across all replicas.  The
2609      merged result has dependencies to make sure that if it is evaluated at
2610      all, the side effects (updates) will happen on every replica. If instead
2611      "group=False" is specified, this function will return a nest of lists
2612      where each list has an element per replica, and the caller is responsible
2613      for ensuring all elements are executed.
2614    """
2615    # TODO(b/178944108): Update the documentation to reflect the fact that
2616    # `update` can be called in a replica context.
2617    if kwargs is None:
2618      kwargs = {}
2619    replica_context = distribution_strategy_context.get_replica_context()
2620    # pylint: disable=protected-access
2621    if (replica_context is None or replica_context is
2622        distribution_strategy_context._get_default_replica_context()):
2623      fn = autograph.tf_convert(
2624          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2625      with self._container_strategy().scope():
2626        return self._update(var, fn, args, kwargs, group)
2627    else:
2628      return self._replica_ctx_update(
2629          var, fn, args=args, kwargs=kwargs, group=group)
2630
2631  def _update(self, var, fn, args, kwargs, group):
2632    raise NotImplementedError("must be implemented in descendants")
2633
2634  def _local_results(self, val):
2635    """Returns local results per replica as a tuple."""
2636    if isinstance(val, values.DistributedValues):
2637      return val._values  # pylint: disable=protected-access
2638
2639    if nest.is_nested(val):
2640      replica_values = []
2641
2642      def get_values(x, index):
2643        if isinstance(x, values.DistributedValues):
2644          return x._values[index]  # pylint: disable=protected-access
2645        return x
2646
2647      for i in range(len(self.worker_devices)):
2648        replica_values.append(
2649            nest.map_structure(
2650                lambda x: get_values(x, i),  # pylint: disable=cell-var-from-loop
2651                val))
2652      return tuple(replica_values)
2653    return (val,)
2654
2655  def value_container(self, value):
2656    """Returns the container that this per-replica `value` belongs to.
2657
2658    Args:
2659      value: A value returned by `run()` or a variable created in `scope()`.
2660
2661    Returns:
2662      A container that `value` belongs to.
2663      If value does not belong to any container (including the case of
2664      If `value` does not belong to any container (including the case where
2665      the container has been destroyed), returns the value itself.
2666      always be true.
2667    """
2668    raise NotImplementedError("must be implemented in descendants")
2669
2670  def _group(self, value, name=None):
2671    """Implementation of `group`."""
2672    value = nest.flatten(self._local_results(value))
2673
2674    if len(value) != 1 or name is not None:
2675      return control_flow_ops.group(value, name=name)
2676    # Special handling for the common case of one op.
2677    v, = value
2678    if hasattr(v, "op"):
2679      v = v.op
2680    return v
2681
2682  @property
2683  def experimental_require_static_shapes(self):
2684    """Returns `True` if static shape is required; `False` otherwise."""
2685    return self._require_static_shapes
2686
2687  @property
2688  def _num_replicas_in_sync(self):
2689    """Returns number of replicas over which gradients are aggregated."""
2690    raise NotImplementedError("must be implemented in descendants")
2691
2692  @property
2693  def worker_devices(self):
2694    """Returns the tuple of all devices used for compute replica execution.
2695    """
2696    # TODO(josh11b): More docstring
2697    raise NotImplementedError("must be implemented in descendants")
2698
2699  @property
2700  def parameter_devices(self):
2701    """Returns the tuple of all devices used to place variables."""
2702    # TODO(josh11b): More docstring
2703    raise NotImplementedError("must be implemented in descendants")
2704
2705  def _configure(self,
2706                 session_config=None,
2707                 cluster_spec=None,
2708                 task_type=None,
2709                 task_id=None):
2710    """Configures the strategy class."""
2711    del session_config, cluster_spec, task_type, task_id
2712
2713  def _update_config_proto(self, config_proto):
2714    return copy.deepcopy(config_proto)
2715
2716  def _in_multi_worker_mode(self):
2717    """Whether this strategy indicates working in multi-worker settings.
2718
2719    Multi-worker training refers to the setup where the training is
2720    distributed across multiple workers, as opposed to the case where
2721    only a local process performs the training. This function is
2722    used by higher-level APIs such as Keras' `model.fit()` to infer,
2723    for example, whether or not a distribute coordinator should be run
2724    (and thus TensorFlow servers started for communication with other
2725    servers in the cluster), or whether or not saving/restoring
2726    checkpoints is relevant for preemption fault tolerance.
2727
2728    Subclasses should override this to indicate whether the strategy is
2729    currently in a multi-worker setup.
2730
2731    Experimental. Signature and implementation are subject to change.
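
    As an illustration only, a subclass override might resemble the following
    sketch (assuming the cluster is described via the `TF_CONFIG` environment
    variable):

    ```python
    def _in_multi_worker_mode(self):
      resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
      workers = resolver.cluster_spec().as_dict().get("worker", [])
      # More than one worker task means multi-worker training.
      return len(workers) > 1
    ```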
2732    """
2733    raise NotImplementedError("must be implemented in descendants")
2734
2735
2736@tf_export(v1=["distribute.StrategyExtended"])  # pylint: disable=missing-docstring
2737class StrategyExtendedV1(StrategyExtendedV2):
2738
2739  __doc__ = StrategyExtendedV2.__doc__
2740
2741  def experimental_make_numpy_dataset(self, numpy_input, session=None):
2742    """Makes a dataset for input provided via a numpy array.
2743
2744    This avoids adding `numpy_input` as a large constant in the graph,
2745    and copies the data to the machine or machines that will be processing
2746    the input.
2747
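    For illustration, a minimal sketch (TF 1.x graph execution; `strategy` is
    assumed to be created elsewhere and the array below is made up) might look
    like:

    ```python
    import numpy as np

    features = np.arange(100, dtype=np.float32).reshape(25, 4)
    with tf.compat.v1.Session() as sess:
      dataset = strategy.extended.experimental_make_numpy_dataset(
          features, session=sess)
      dataset = dataset.batch(5)
    ```
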
2748    Args:
2749      numpy_input: A nest of NumPy input arrays that will be distributed evenly
2750        across all replicas. Note that lists of Numpy arrays are stacked, as
2751        that is normal `tf.data.Dataset` behavior.
2752      session: (TensorFlow v1.x graph execution only) A session used for
2753        initialization.
2754
2755    Returns:
2756      A `tf.data.Dataset` representing `numpy_input`.
2757    """
2758    _require_cross_replica_or_default_context_extended(self)
2759    return self._experimental_make_numpy_dataset(numpy_input, session=session)
2760
2761  def _experimental_make_numpy_dataset(self, numpy_input, session):
2762    raise NotImplementedError("must be implemented in descendants")
2763
2764  def broadcast_to(self, tensor, destinations):
2765    """Mirror a tensor on one device to all worker devices.
2766
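    For illustration, a minimal sketch (assuming `strategy` is a
    `tf.distribute.MirroredStrategy` created elsewhere and `v` is a mirrored
    variable created under `strategy.scope()`) might look like:

    ```python
    with strategy.scope():
      v = tf.Variable(0.)
    # Copy a tensor to every device holding a component of `v`.
    mirrored_value = strategy.extended.broadcast_to(tf.constant(3.0),
                                                    destinations=v)
    ```
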
2767    Args:
2768      tensor: A Tensor value to broadcast.
2769      destinations: A mirrored variable or device string specifying the
2770        destination devices to copy `tensor` to.
2771
2772    Returns:
2773      A value mirrored to `destinations` devices.
2774    """
2775    assert destinations is not None  # from old strategy.broadcast()
2776    # TODO(josh11b): More docstring
2777    _require_cross_replica_or_default_context_extended(self)
2778    assert not isinstance(destinations, (list, tuple))
2779    return self._broadcast_to(tensor, destinations)
2780
2781  def _broadcast_to(self, tensor, destinations):
2782    raise NotImplementedError("must be implemented in descendants")
2783
2784  @deprecated(None, "please use `run` instead.")
2785  def experimental_run_steps_on_iterator(self,
2786                                         fn,
2787                                         iterator,
2788                                         iterations=1,
2789                                         initial_loop_values=None):
2790    """DEPRECATED: please use `run` instead.
2791
2792    Run `fn` `iterations` times with input from `iterator`.
2793
2794    This method can be used to run a step function for training a number of
2795    times using input from a dataset.
2796
2797    Args:
2798      fn: function to run using this distribution strategy. The function must
2799        have the following signature: `def fn(context, inputs)`. `context` is an
2800          instance of `MultiStepContext` that will be passed when `fn` is run.
2801          `context` can be used to specify the outputs to be returned from `fn`
2802          by calling `context.set_last_step_output`. It can also be used to
2803          capture non tensor outputs by `context.set_non_tensor_output`. See
2804          `MultiStepContext` documentation for more information. `inputs` will
2805          have same type/structure as `iterator.get_next()`. Typically, `fn`
2806          will use `call_for_each_replica` method of the strategy to distribute
2807          the computation over multiple replicas.
2808      iterator: Iterator of a dataset that represents the input for `fn`. The
2809        caller is responsible for initializing the iterator as needed.
2810      iterations: (Optional) Number of iterations that `fn` should be run.
2811        Defaults to 1.
2812      initial_loop_values: (Optional) Initial values to be passed into the
2813        loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
2814          initial_loop_values argument when we have a mechanism to infer the
2815          outputs of `fn`.
2816
2817    Returns:
2818      Returns the `MultiStepContext` object which has the following properties,
2819      among other things:
2820        - run_op: An op that runs `fn` `iterations` times.
2821        - last_step_outputs: A dictionary containing tensors set using
2822        `context.set_last_step_output`. Evaluating this returns the value of
2823        the tensors after the last iteration.
2824        - non_tensor_outputs: A dictionary containing anything that was set by
2825          `fn` by calling `context.set_non_tensor_output`.
2826    """
2827    _require_cross_replica_or_default_context_extended(self)
2828    with self._container_strategy().scope():
2829      return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
2830                                                      initial_loop_values)
2831
2832  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
2833                                          initial_loop_values):
2834    raise NotImplementedError("must be implemented in descendants")
2835
2836  def call_for_each_replica(self, fn, args=(), kwargs=None):
2837    """Run `fn` once per replica.
2838
2839    `fn` may call `tf.get_replica_context()` to access methods such as
2840    `replica_id_in_sync_group` and `merge_call()`.
2841
2842    `merge_call()` is used to communicate between the replicas and
2843    re-enter the cross-replica context. All replicas pause their execution
2844    when they encounter a `merge_call()`. After that, the
2845    `merge_fn` function is executed. Its results are then unwrapped and
2846    given back to each replica call, and execution resumes until
2847    `fn` is complete or encounters another `merge_call()`.  Example:
2848
2849    ```python
2850    # Called once in "cross-replica" context.
2851    def merge_fn(distribution, three_plus_replica_id):
2852      # sum the values across replicas
2853      return sum(distribution.experimental_local_results(three_plus_replica_id))
2854
2855    # Called once per replica in `distribution`, in a "replica" context.
2856    def fn(three):
2857      replica_ctx = tf.get_replica_context()
2858      v = three + replica_ctx.replica_id_in_sync_group
2859      # Computes the sum of the `v` values across all replicas.
2860      s = replica_ctx.merge_call(merge_fn, args=(v,))
2861      return s + v
2862
2863    with distribution.scope():
2864      # in "cross-replica" context
2865      ...
2866      merged_results = distribution.run(fn, args=[3])
2867      # merged_results has the values from every replica execution of `fn`.
2868      # This statement prints a list:
2869      print(distribution.experimental_local_results(merged_results))
2870    ```
2871
2872    Args:
2873      fn: function to run (will be run once per replica).
2874      args: Tuple or list with positional arguments for `fn`.
2875      kwargs: Dict with keyword arguments for `fn`.
2876
2877    Returns:
2878      Merged return value of `fn` across all replicas.
2879    """
2880    _require_cross_replica_or_default_context_extended(self)
2881    if kwargs is None:
2882      kwargs = {}
2883    with self._container_strategy().scope():
2884      return self._call_for_each_replica(fn, args, kwargs)
2885
2886  def _call_for_each_replica(self, fn, args, kwargs):
2887    raise NotImplementedError("must be implemented in descendants")
2888
2889  def read_var(self, v):
2890    """Reads the value of a variable.
2891
2892    Returns the aggregate value of a replica-local variable, or the
2893    (read-only) value of any other variable.
2894
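    For illustration, a minimal sketch (assuming `strategy` is created elsewhere
    and `v` is a sync-on-read variable created under `strategy.scope()`) might
    look like:

    ```python
    with strategy.scope():
      v = tf.Variable(
          0.,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)
    # Returns the value of `v` aggregated (here, summed) across replicas.
    total = strategy.extended.read_var(v)
    ```
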
2895    Args:
2896      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
2897
2898    Returns:
2899      A tensor representing the value of `v`, aggregated across replicas if
2900      necessary.
2901    """
2902    raise NotImplementedError("must be implemented in descendants")
2903
2904  def update_non_slot(
2905      self, colocate_with, fn, args=(), kwargs=None, group=True):
2906    """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
2907
2908    Used to update non-slot variables.
2909
2910    DEPRECATED: TF 1.x ONLY.
2911
2912    Args:
2913      colocate_with: Devices returned by `non_slot_devices()`.
2914      fn: Function to execute.
2915      args: Tuple or list. Positional arguments to pass to `fn()`.
2916      kwargs: Dict with keyword arguments to pass to `fn()`.
2917      group: Boolean. Defaults to True. If False, the return value will be
2918        unwrapped.
2919
2920    Returns:
2921      Return value of `fn`, possibly merged across devices.
2922    """
2923    _require_cross_replica_or_default_context_extended(self)
2924    if kwargs is None:
2925      kwargs = {}
2926    fn = autograph.tf_convert(
2927        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2928    with self._container_strategy().scope():
2929      return self._update_non_slot(colocate_with, fn, args, kwargs, group)
2930
2931  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
2932    raise NotImplementedError("must be implemented in descendants")
2933
2934  def non_slot_devices(self, var_list):
2935    """Device(s) for non-slot variables.
2936
2937    DEPRECATED: TF 1.x ONLY.
2938
2939    This method returns non-slot devices where non-slot variables are placed.
2940    Users can create non-slot variables on these devices by using a block:
2941
2942    ```python
2943    with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)):
2944      ...
2945    ```
2946
2947    Args:
2948      var_list: The list of variables being optimized, needed with the
2949        default `tf.distribute.Strategy`.
2950    Returns:
2951      A sequence of devices for non-slot variables.
2952    """
2953    raise NotImplementedError("must be implemented in descendants")
2954
2955  def _use_merge_call(self):
2956    """Whether to use merge-calls inside the distributed strategy."""
2957    return True
2958
2959  @property
2960  def experimental_between_graph(self):
2961    """Whether the strategy uses between-graph replication or not.
2962
2963      This is expected to return a constant value that will not be changed
2964      throughout its life cycle.
2965    """
2966    raise NotImplementedError("must be implemented in descendants")
2967
2968  @property
2969  def experimental_should_init(self):
2970    """Whether initialization is needed."""
2971    raise NotImplementedError("must be implemented in descendants")
2972
2973  @property
2974  def should_checkpoint(self):
2975    """Whether checkpointing is needed."""
2976    raise NotImplementedError("must be implemented in descendants")
2977
2978  @property
2979  def should_save_summary(self):
2980    """Whether saving summaries is needed."""
2981    raise NotImplementedError("must be implemented in descendants")
2982
2983
2984# A note about the difference between the context managers
2985# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
2986# (defined above) used by `tf.distribute.Strategy.scope()`:
2987#
2988# * a ReplicaContext is only present during a `run()`
2989    #   call (except during a `merge_call`) and in such a scope it
2990#   will be returned by calls to `get_replica_context()`.  Implementers of new
2991#   Strategy descendants will frequently also need to
2992#   define a descendant of ReplicaContext, and are responsible for
2993#   entering and exiting this context.
2994#
2995# * Strategy.scope() sets up a variable_creator scope that
2996#   changes variable creation calls (e.g. to make mirrored
2997#   variables). This is intended as an outer scope that users enter once
2998#   around their model creation and graph definition. There is no
2999#   anticipated need to define descendants of _CurrentDistributionContext.
3000#   It sets the current Strategy for purposes of
3001#   `get_strategy()` and `has_strategy()`
3002#   and switches the thread mode to a "cross-replica context".
3003class ReplicaContextBase(object):
3004  """A class with a collection of APIs that can be called in a replica context.
3005
3006  You can use `tf.distribute.get_replica_context` to get an instance of
3007  `ReplicaContext`, but only inside the function passed to
3008  `tf.distribute.Strategy.run`.
3009
3010  >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
3011  >>> def func():
3012  ...   replica_context = tf.distribute.get_replica_context()
3013  ...   return replica_context.replica_id_in_sync_group
3014  >>> strategy.run(func)
3015  PerReplica:{
3016    0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
3017    1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
3018  }
3019  """
3020
3021  def __init__(self, strategy, replica_id_in_sync_group):
3022    """Creates a ReplicaContext.
3023
3024    Args:
3025      strategy: A `tf.distribute.Strategy`.
3026      replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
3027        integer whenever possible to avoid issues with nested `tf.function`. It
3028        accepts a `Tensor` only to be compatible with `tpu.replicate`.
3029    """
3030    self._strategy = strategy
3031    self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
3032        self)
3033    if not (replica_id_in_sync_group is None or
3034            tensor_util.is_tf_type(replica_id_in_sync_group) or
3035            isinstance(replica_id_in_sync_group, int)):
3036      raise ValueError(
3037          "replica_id_in_sync_group can only be an integer, a Tensor or None.")
3038    self._replica_id_in_sync_group = replica_id_in_sync_group
3039    # We need this check because TPUContext extends from ReplicaContext and
3040    # does not pass a strategy object since it is used by TPUEstimator.
3041    if strategy:
3042      self._local_replica_id = strategy.extended._get_local_replica_id(
3043          replica_id_in_sync_group)
3044    self._summary_recording_distribution_strategy = None
3045
3046  @doc_controls.do_not_generate_docs
3047  def __enter__(self):
3048    _push_per_thread_mode(self._thread_context)
3049
3050    def replica_id_is_zero():
3051      return math_ops.equal(self.replica_id_in_sync_group,
3052                            constant_op.constant(0))
3053
3054    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
3055    self._summary_recording_distribution_strategy = (
3056        summary_state.is_recording_distribution_strategy)
3057    summary_state.is_recording_distribution_strategy = replica_id_is_zero
3058
3059  @doc_controls.do_not_generate_docs
3060  def __exit__(self, exception_type, exception_value, traceback):
3061    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
3062    summary_state.is_recording_distribution_strategy = (
3063        self._summary_recording_distribution_strategy)
3064    _pop_per_thread_mode()
3065
3066  def merge_call(self, merge_fn, args=(), kwargs=None):
3067    """Merge args across replicas and run `merge_fn` in a cross-replica context.
3068
3069    This allows communication and coordination when there are multiple calls
3070    to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
3071
3072    See `tf.distribute.Strategy.run` for an explanation.
3073
3074    If not inside a distributed scope, this is equivalent to:
3075
3076    ```
3077    strategy = tf.distribute.get_strategy()
3078    with cross-replica-context(strategy):
3079      return merge_fn(strategy, *args, **kwargs)
3080    ```
3081
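    For illustration, a minimal sketch (assuming a `tf.distribute.MirroredStrategy`)
    that sums a per-replica value by re-entering the cross-replica context might
    look like:

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      ctx = tf.distribute.get_replica_context()
      value = tf.identity(1.)
      # `merge_fn` runs once in cross-replica context; the per-replica values
      # of `value` are merged into a single distributed value.
      return ctx.merge_call(
          lambda strategy, v: strategy.reduce("SUM", v, axis=None),
          args=(value,))

    per_replica_sums = strategy.run(step_fn)
    ```
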
3082    Args:
3083      merge_fn: Function that joins arguments from threads that are given as
3084        PerReplica. It accepts a `tf.distribute.Strategy` object as
3085        the first argument.
3086      args: List or tuple with positional per-thread arguments for `merge_fn`.
3087      kwargs: Dict with keyword per-thread arguments for `merge_fn`.
3088
3089    Returns:
3090      The return value of `merge_fn`, except for `PerReplica` values which are
3091      unpacked.
3092    """
3093    require_replica_context(self)
3094    if kwargs is None:
3095      kwargs = {}
3096
3097    merge_fn = autograph.tf_convert(
3098        merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
3099    return self._merge_call(merge_fn, args, kwargs)
3100
3101  def _merge_call(self, merge_fn, args, kwargs):
3102    """Default implementation for single replica."""
3103    _push_per_thread_mode(  # thread-local, so not needed with multiple threads
3104        distribution_strategy_context._CrossReplicaThreadMode(self._strategy))  # pylint: disable=protected-access
3105    try:
3106      return merge_fn(self._strategy, *args, **kwargs)
3107    finally:
3108      _pop_per_thread_mode()
3109
3110  @property
3111  def num_replicas_in_sync(self):
3112    """Returns number of replicas that are kept in sync."""
3113    return self._strategy.num_replicas_in_sync
3114
3115  @property
3116  def replica_id_in_sync_group(self):
3117    """Returns the id of the replica.
3118
3119    This identifies the replica among all replicas that are kept in sync. The
3120    value of the replica id can range from 0 to
3121    `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
3122
3123    NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
3124    for low-level operations such as collective_permute.
3125
3126    Returns:
3127      a `Tensor`.
3128    """
3129    # It's important to prefer making the Tensor at call time whenever possible.
3130    # Keeping Tensors in global states doesn't work well with nested
3131    # tf.function, since it's possible that the tensor is generated in one func
3132    # graph, and gets captured by another, which will result in a subtle "An op
3133    # outside of the function building code is being passed a Graph tensor"
3134    # error. Making the tensor at call time ensures it is in the same graph where
3135    # it's used. However, to be compatible with tpu.replicate(),
3136    # self._replica_id_in_sync_group can also be a Tensor.
3137    if tensor_util.is_tf_type(self._replica_id_in_sync_group):
3138      return self._replica_id_in_sync_group
3139    return constant_op.constant(
3140        self._replica_id_in_sync_group,
3141        dtypes.int32,
3142        name="replica_id_in_sync_group")
3143
3144  @property
3145  def _replica_id(self):
3146    """This is the local replica id in a given sync group."""
3147    return self._local_replica_id
3148
3149  @property
3150  def strategy(self):
3151    """The current `tf.distribute.Strategy` object."""
3152    return self._strategy
3153
3154  @property
3155  @deprecation.deprecated(None, "Please avoid relying on devices property.")
3156  def devices(self):
3157    """Returns the devices this replica is to be executed on, as a tuple of strings.
3158
3159    NOTE: For `tf.distribute.MirroredStrategy` and
3160    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
3161    nested list of device strings, e.g., [["GPU:0"]].
3163    """
3164    require_replica_context(self)
3165    return (device_util.current(),)
3166
3167  def all_reduce(self, reduce_op, value, options=None):
3168    """All-reduces `value` across all replicas.
3169
3170    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3171    >>> def step_fn():
3172    ...   ctx = tf.distribute.get_replica_context()
3173    ...   value = tf.identity(1.)
3174    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
3175    >>> strategy.experimental_local_results(strategy.run(step_fn))
3176    (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3177     <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
3178
3179    It supports batched operations. You can pass a list of values and it
3180    attempts to batch them when possible. You can also specify `options`
3181    to indicate the desired batching behavior, e.g. batch the values into
3182    multiple packs so that they can better overlap with computations.
3183
3184    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3185    >>> def step_fn():
3186    ...   ctx = tf.distribute.get_replica_context()
3187    ...   value1 = tf.identity(1.)
3188    ...   value2 = tf.identity(2.)
3189    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
3190    >>> strategy.experimental_local_results(strategy.run(step_fn))
3191    ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3192    <tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
3193    [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3194    <tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
3195
3196    Note that all replicas need to participate in the all-reduce, otherwise this
3197    operation hangs. Note that if there are multiple all-reduces, they need to
3198    execute in the same order on all replicas. Dispatching all-reduce based on
3199    conditions is usually error-prone.
3200
3201    Known limitation: if `value` contains `tf.IndexedSlices`, attempting to
3202    compute the gradient w.r.t. `value` would result in an error.
3203
3204    This API currently can only be called in the replica context. Other
3205    variants to reduce values across replicas are:
3206    * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
3207      in the cross-replica context.
3208    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
3209      all-reduce API in the cross-replica context.
3210    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
3211      to the host in cross-replica context.
3212
3213    Args:
3214      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
3215        be combined. Allows using string representation of the enum such as
3216        "SUM", "MEAN".
3217      value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which
3218        `tf.nest.flatten` accepts. The structure and the shapes of `value` need to
3219        be the same on all replicas.
3220      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3221        perform collective operations. This overrides the default options if the
3222        `tf.distribute.Strategy` takes one in the constructor. See
3223        `tf.distribute.experimental.CommunicationOptions` for details of the
3224        options.
3225
3226    Returns:
3227       A nested structure of `tf.Tensor` with the reduced values. The structure
3228       is the same as `value`.
3229    """
3230    flattened_value = nest.flatten(value)
3231    has_indexed_slices = False
3232
3233    for v in flattened_value:
3234      if isinstance(v, ops.IndexedSlices):
3235        has_indexed_slices = True
3236
3237    if isinstance(reduce_op, six.string_types):
3238      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
3239    if options is None:
3240      options = collective_util.Options()
3241
3242    def batch_all_reduce(strategy, *value_flat):
3243      return strategy.extended.batch_reduce_to(
3244          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
3245          options)
3246
3247    # Due to the use of `capture_call_time_value` in collective ops, we have
3248    # to maintain two branches: one w/ merge_call and one w/o. Details can be
3249    # found in b/184009754.
3250    if self._strategy.extended._use_merge_call():  # pylint: disable=protected-access
3251      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
3252      if has_indexed_slices:
3253        return nest.pack_sequence_as(
3254            value,
3255            self.merge_call(batch_all_reduce, args=flattened_value))
3256
3257      @custom_gradient.custom_gradient
3258      def grad_wrapper(*xs):
3259        ys = self.merge_call(batch_all_reduce, args=xs)
3260        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3261        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3262      return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3263    else:
3264      if has_indexed_slices:
3265        return nest.pack_sequence_as(
3266            value,
3267            self._strategy.extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
3268                reduce_op, flattened_value, options))
3269
3270      @custom_gradient.custom_gradient
3271      def grad_wrapper(*xs):
3272        ys = self._strategy.extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
3273            reduce_op, xs, options)
3274        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3275        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3276
3277      return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3278
3279  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
3280  # all-reduce. It would return a function returning the result of reducing `t`
3281  # across all replicas. The caller would wait to call this function until they
3282  # needed the reduce result, allowing an efficient implementation:
3283  # * With eager execution, the reduction could be performed asynchronously
3284  #   in the background, not blocking until the result was needed.
3285  # * When constructing a graph, it could batch up all reduction requests up
3286  #   to that point that the first result is needed. Most likely this can be
3287  #   implemented in terms of `merge_call()` and `batch_reduce_to()`.
3288
3289
3290@tf_export("distribute.ReplicaContext", v1=[])
3291class ReplicaContext(ReplicaContextBase):
3292
3293  __doc__ = ReplicaContextBase.__doc__
3294
3295  def all_gather(self, value, axis, options=None):
3296    """All-gathers `value` across all replicas along `axis`.
3297
3298    Note: An `all_gather` method can only be called in replica context. For
3299    a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
3300    All replicas need to participate in the all-gather, otherwise this
3301    operation hangs. So if `all_gather` is called in any replica, it must be
3302    called in all replicas.
3303
3304    Note: If there are multiple `all_gather` calls, they need to be executed in
3305    the same order on all replicas. Dispatching `all_gather` based on conditions
3306    is usually error-prone.
3307
3308    For all strategies except `tf.distribute.TPUStrategy`, the input
3309    `value` on different replicas must have the same rank, and their shapes must
3310    be the same in all dimensions except the `axis`-th dimension. In other
3311    words, their shapes cannot be different in a dimension `d` where `d` is
3312    not equal to the `axis` argument. For example, given a
3313    `tf.distribute.DistributedValues` with component tensors of shape
3314    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
3315    `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
3316    or `all_gather(..., axis=2, ...)`. However, with
3317    `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
3318    same shape.
3319
3320    Note: The input `value` must have a non-zero rank. Otherwise, consider using
3321    `tf.expand_dims` before gathering it.
3322
3323    You can pass in a single tensor to all-gather:
3324
3325    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3326    >>> @tf.function
3327    ... def gather_value():
3328    ...   ctx = tf.distribute.get_replica_context()
3329    ...   local_value = tf.constant([1, 2, 3])
3330    ...   return ctx.all_gather(local_value, axis=0)
3331    >>> result = strategy.run(gather_value)
3332    >>> result
3333    PerReplica:{
3334      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3335      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3336    }
3337    >>> strategy.experimental_local_results(result)
3338    (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3339    dtype=int32)>,
3340    <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3341    dtype=int32)>)
3342
3343
3344    You can also pass in a nested structure of tensors to all-gather, say, a
3345    list:
3346
3347    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3348    >>> @tf.function
3349    ... def gather_nest():
3350    ...   ctx = tf.distribute.get_replica_context()
3351    ...   value_1 = tf.constant([1, 2, 3])
3352    ...   value_2 = tf.constant([[1, 2], [3, 4]])
3353    ...   # all_gather a nest of `tf.distribute.DistributedValues`
3354    ...   return ctx.all_gather([value_1, value_2], axis=0)
3355    >>> result = strategy.run(gather_nest)
3356    >>> result
3357    [PerReplica:{
3358      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3359      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3360    }, PerReplica:{
3361      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3362    array([[1, 2],
3363           [3, 4],
3364           [1, 2],
3365           [3, 4]], dtype=int32)>,
3366      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3367    array([[1, 2],
3368           [3, 4],
3369           [1, 2],
3370           [3, 4]], dtype=int32)>
3371    }]
3372    >>> strategy.experimental_local_results(result)
3373    ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3374    <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3375    array([[1, 2],
3376           [3, 4],
3377           [1, 2],
3378           [3, 4]], dtype=int32)>],
3379           [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3380           <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3381    array([[1, 2],
3382           [3, 4],
3383           [1, 2],
3384           [3, 4]], dtype=int32)>])
3385
3386
3387    What if you are all-gathering tensors with different shapes on different
3388    replicas? Consider the following example with two replicas, where you have
3389    `value` as a nested structure consisting of two items to all-gather, `a` and
3390    `b`.
3391
3392      On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
3393
3394      On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
3395
3396      Result for `all_gather` with `axis`=0 (on each of the replicas) is:
3397
3398      ```{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}```
3399
3400    Args:
3401      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
3402        or a `tf.distribute.DistributedValues` instance. The structure of the
3403        `tf.Tensor` needs to be the same on all replicas. The underlying tensor
3404        constructs can only be dense tensors with non-zero rank, NOT
3405        `tf.IndexedSlices`.
3406      axis: 0-D int32 Tensor. Dimension along which to gather.
3407      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3408        perform collective operations. This overrides the default options if the
3409        `tf.distribute.Strategy` takes one in the constructor. See
3410        `tf.distribute.experimental.CommunicationOptions` for details of the
3411        options.
3412
3413    Returns:
3414       A nested structure of `tf.Tensor` with the gathered values. The structure
3415       is the same as `value`.
3416    """
3417    for v in nest.flatten(value):
3418      if isinstance(v, ops.IndexedSlices):
3419        raise NotImplementedError("all_gather does not support IndexedSlices")
3420
3421    if options is None:
3422      options = collective_util.Options()
3423
3424    def batch_all_gather(strategy, *value_flat):
3425      return strategy.extended._batch_gather_to(  # pylint: disable=protected-access
3426          [(v, _batch_reduce_destination(v)) for v in value_flat], axis,
3427          options)
3428
3429    @custom_gradient.custom_gradient
3430    def grad_wrapper(*xs):
3431      ys = self.merge_call(batch_all_gather, args=xs)
3432
3433      def grad(*dy_s):
3434        grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
3435        new_grads = []
3436        for i, grad in enumerate(grads):
3437          input_shape = array_ops.shape(xs[i])
3438          axis_dim = array_ops.reshape(input_shape[axis], [1])
3439          with ops.control_dependencies([array_ops.identity(grads)]):
3440            d = self.all_gather(axis_dim, axis=0)
3441            begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
3442            end_dim = begin_dim + array_ops.shape(xs[i])[axis]
3443            new_grad = array_ops.gather(
3444                grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
3445            new_grads.append(new_grad)
3446        return new_grads
3447
3448      return ys, grad
3449
3450    return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
3451
3452  def _update(self, var, fn, args=(), kwargs=None, group=True):
3453    """Run `fn` to update `var` with `args` and `kwargs` in replica context.
3454
3455    `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var`
3456    to be updated, an update function `fn`, and `args` and `kwargs` for `fn`.
3457    `fn` is applied to each component variable of `var` with corresponding input
3458    values from `args` and `kwargs`.
3459
3460    Example usage:
3461
3462    >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3463    >>> with strategy.scope():
3464    ...   distributed_variable = tf.Variable(5.0)
3465    >>> distributed_variable
3466    MirroredVariable:{
3467      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
3468      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0>
3469    }
3470    >>> def replica_fn(v):
3471    ...   value = tf.identity(1.0)
3472    ...   replica_context = tf.distribute.get_replica_context()
3473    ...   update_fn = lambda var, value: var.assign(value)
3474    ...   replica_context._update(v, update_fn, args=(value,))
3475    >>> strategy.run(replica_fn, args=(distributed_variable,))
3476    >>> distributed_variable
3477    MirroredVariable:{
3478      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
3479      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
3480    }
3481
3482    This API must be called in a replica context.
3483
3484    Note that if `var` is a MirroredVariable (i.e., the type of variable created
3485    under the scope of a synchronous strategy, synchronized on-write; see
3486    `tf.VariableSynchronization` for more information) and `args`/`kwargs`
3487    contain different values for different replicas, `var` will be dangerously
3488    out of synchronization. We therefore recommend using `variable.assign(value)`
3489    whenever you can, since under the hood it aggregates the updates and
3490    guarantees the synchronization. The case where you actually want this API
3491    instead of `variable.assign(value)` is when, before assigning `value` to the
3492    `variable`, you'd like to run some pre-`assign` computation colocated with
3493    the variable devices (i.e. where the variables reside; for MirroredStrategy
3494    these are the same as the compute devices, for ParameterServerStrategy they
3495    are the parameter servers). E.g.,
3496
3497    ```python
3498    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3499    with strategy.scope():
3500      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
3501    def replica_fn(inputs):
3502      value = computation(inputs)
3503      replica_context = tf.distribute.get_replica_context()
3504      reduced_value = replica_context.all_reduce(value)
3505
3506      def update_fn(var, value):
3507        # this computation will colocate with `var`'s device
3508        updated_value = post_reduce_pre_update_computation(value)
3509        var.assign(updated_value)
3510
3511      replica_context._update(v, update_fn, args=(reduced_value,))
3512
3513    strategy.run(replica_fn, args=(inputs,))
3514    ```
3515
3516    This code snippet is consistent across all strategies. If you directly
3517    compute and use `assign` in the replica context instead of wrapping it with
3518    `update`, for strategies with fewer variable devices than compute devices
3519    (e.g., parameter server strategy, usually), the
3520    `post_reduce_pre_update_computation` will run N times, where N is the
3521    number of compute devices, which is less performant.
3522
3523
3524    Args:
3525      var: Variable, possibly distributed to multiple devices, to operate on.
3526      fn: Function to call. Should take the variable as the first argument.
3527      args: Tuple or list. Additional positional arguments to pass to `fn()`.
3528      kwargs: Dict with keyword arguments to pass to `fn()`.
3529      group: Boolean. Defaults to True. Most strategies enter a merge_call to
3530        conduct the update in cross-replica context, and group=True guarantees
3531        that updates on all replicas are executed.
3532
3533    Returns:
3534      The return value of `fn` for the local replica.
3535    """
3536    if kwargs is None:
3537      kwargs = {}
3538    return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group)  # pylint: disable=protected-access
3539
3540
3541@tf_export(v1=["distribute.ReplicaContext"])
3542class ReplicaContextV1(ReplicaContextBase):
3543  __doc__ = ReplicaContextBase.__doc__
3544
3545
3546def _batch_reduce_destination(x):
3547  """Returns the destinations for batch all-reduce."""
3548  if isinstance(x, ops.Tensor):
3549    # If this is a one device strategy.
3550    return x.device
3551  else:
3552    return x
3553
3554
3555# ------------------------------------------------------------------------------
3556
3557
3558_creating_default_strategy_singleton = False
3559
3560
3561class _DefaultDistributionStrategyV1(StrategyV1):
3562  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3563
3564  def __init__(self):
3565    if not _creating_default_strategy_singleton:
3566      raise RuntimeError("Should only create a single instance of "
3567                         "_DefaultDistributionStrategy")
3568    super(_DefaultDistributionStrategyV1,
3569          self).__init__(_DefaultDistributionExtended(self))
3570
3571  def __deepcopy__(self, memo):
3572    del memo
3573    raise RuntimeError("Should only create a single instance of "
3574                       "_DefaultDistributionStrategy")
3575
3576
3577class _DefaultDistributionStrategy(Strategy):
3578  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3579
3580  def __init__(self):
3581    if not _creating_default_strategy_singleton:
3582      raise RuntimeError("Should only create a single instance of "
3583                         "_DefaultDistributionStrategy")
3584    super(_DefaultDistributionStrategy, self).__init__(
3585        _DefaultDistributionExtended(self))
3586
3587  def __deepcopy__(self, memo):
3588    del memo
3589    raise RuntimeError("Should only create a single instance of "
3590                       "_DefaultDistributionStrategy")


class _DefaultDistributionContext(object):
  """Context manager setting the default `tf.distribute.Strategy`."""

  __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]

  def __init__(self, strategy):

    def creator(next_creator, **kwargs):
      _require_strategy_scope_strategy(strategy)
      return next_creator(**kwargs)

    self._var_creator_scope = variable_scope.variable_creator_scope(creator)
    self._strategy = strategy
    self._nested_count = 0

  def __enter__(self):
    # Re-entering is allowed while only this default strategy is in scope
    # (`has_strategy()` is False for the default strategy); nesting inside an
    # explicit `tf.distribute.Strategy` scope is an error.
    if distribution_strategy_context.has_strategy():
      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
    if self._nested_count == 0:
      self._var_creator_scope.__enter__()
    self._nested_count += 1
    return self._strategy

  def __exit__(self, exception_type, exception_value, traceback):
    self._nested_count -= 1
    if self._nested_count == 0:
      try:
        self._var_creator_scope.__exit__(
            exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable creator scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)
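
  # Behavior sketch (inferred from the code above, for illustration only):
  # `_nested_count` makes this context manager reentrant. The variable creator
  # scope is entered on the first `__enter__` and exited on the matching final
  # `__exit__`:
  #
  #   ctx = _DefaultDistributionContext(strategy)
  #   with ctx:          # enters the variable creator scope
  #     with ctx:        # only increments `_nested_count`
  #       ...
  #   # the creator scope has been exited exactly once here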


class _DefaultDistributionExtended(StrategyExtendedV1):
  """Implementation of _DefaultDistributionStrategy."""

  def __init__(self, container_strategy):
    super(_DefaultDistributionExtended, self).__init__(container_strategy)
    self._retrace_functions_for_each_device = False

  def _scope(self, strategy):
    """Context manager setting a variable creator and `self` as current."""
    return _DefaultDistributionContext(strategy)

  def colocate_vars_with(self, colocate_with_variable):
    """Does not require `self.scope`."""
    _require_strategy_scope_extended(self)
    return ops.colocate_with(colocate_with_variable)

  def variable_created_in_scope(self, v):
    return v._distribute_strategy is None  # pylint: disable=protected-access

  def _experimental_distribute_dataset(self, dataset, options):
    return dataset

  def _distribute_datasets_from_function(self, dataset_fn, options):
    return dataset_fn(InputContext())

  def _experimental_distribute_values_from_function(self, value_fn):
    return value_fn(ValueContext())

  def _make_dataset_iterator(self, dataset):
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _make_input_fn_iterator(self,
                              input_fn,
                              replication_mode=InputReplicationMode.PER_WORKER):
    dataset = input_fn(InputContext())
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    numpy_flat = nest.flatten(numpy_input)
    vars_flat = tuple(
        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                trainable=False, use_resource=True)
        for i in numpy_flat
    )
    for v, i in zip(vars_flat, numpy_flat):
      numpy_dataset.init_var_from_numpy(v, i, session)
    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
    return dataset_ops.Dataset.from_tensor_slices(vars_nested)
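
  # Illustrative usage sketch (assumes the TF1 API
  # `tf.compat.v1.distribute.Strategy.experimental_make_numpy_dataset`, which
  # routes here for the default strategy): numpy arrays are copied into
  # variables once and a dataset is built from those variables.
  #
  #   import numpy as np
  #   import tensorflow.compat.v1 as tf
  #
  #   strategy = tf.distribute.get_strategy()
  #   features = np.random.rand(64, 4).astype(np.float32)
  #   with tf.Session() as sess:
  #     dataset = strategy.experimental_make_numpy_dataset(features,
  #                                                        session=sess)
  #     dataset = dataset.batch(8)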

  def _broadcast_to(self, tensor, destinations):
    if destinations is None:
      return tensor
    else:
      raise NotImplementedError("TODO")

  def _call_for_each_replica(self, fn, args, kwargs):
    with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    # TODO(josh11b): Use destinations?
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
    # once that value is used for something.
    with UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    return array_ops.identity(replica_local_var)

  def _local_results(self, distributed_value):
    return (distributed_value,)

  def value_container(self, value):
    return value

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    raise RuntimeError("worker_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  @property
  def parameter_devices(self):
    raise RuntimeError("parameter_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # Default strategy doesn't indicate multi-worker training.
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id

  # TODO(priyag): This should inherit from `InputIterator`, once dependency
  # issues have been resolved.
  class DefaultInputIterator(object):
    """Default implementation of `InputIterator` for default strategy."""

    def __init__(self, dataset):
      self._dataset = dataset
      if eager_context.executing_eagerly():
        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
      else:
        self._iterator = dataset_ops.make_initializable_iterator(dataset)

    def get_next(self):
      return self._iterator.get_next()

    def get_next_as_optional(self):
      return self._iterator.get_next_as_optional()

    @deprecated(None, "Use the iterator's `initializer` property instead.")
    def initialize(self):
      """Initialize underlying iterators.

      Returns:
        A list of any initializer ops that should be run.
      """
      if eager_context.executing_eagerly():
        self._iterator = self._dataset.make_one_shot_iterator()
        return []
      else:
        return [self._iterator.initializer]

    @property
    def initializer(self):
      """Returns a list of ops that initialize the iterator."""
      return self.initialize()
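
    # Illustrative usage sketch (TF1 graph-mode flow; `strategy` is assumed to
    # be obtained from `tf.compat.v1.distribute.get_strategy()`):
    #
    #   dataset = tf.data.Dataset.range(10).batch(2)
    #   iterator = strategy.make_dataset_iterator(dataset)
    #   with tf.compat.v1.Session() as sess:
    #     sess.run(iterator.initializer)   # list of initializer ops
    #     first_batch = sess.run(iterator.get_next())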

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for this strategy."""
    return True


class _DefaultReplicaContext(ReplicaContext):
  """ReplicaContext for _DefaultDistributionStrategy."""

  @property
  def replica_id_in_sync_group(self):
    # Return 0 instead of a constant tensor to avoid creating a new node for
    # users who don't use distribution strategy.
    return 0
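
  # Illustrative sketch: outside any strategy scope,
  # `tf.distribute.get_replica_context()` returns this default replica
  # context, so the replica id is a plain Python int rather than a tensor:
  #
  #   import tensorflow as tf
  #
  #   ctx = tf.distribute.get_replica_context()
  #   assert ctx.replica_id_in_sync_group == 0   # no graph node created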


# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = resource_variable_ops._from_proto_fn


def _from_proto_fn(v, import_scope=None):
  if distribution_strategy_context.has_strategy():
    raise NotImplementedError(
        "Deserialization of variables is not yet supported when using a "
        "tf.distribute.Strategy.")
  else:
    return _original_from_proto(v, import_scope=import_scope)

resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
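
# Illustrative sketch (hypothetical checkpoint path): the guard above fires
# when a graph that contains variables is deserialized inside a strategy
# scope, for example via `tf.compat.v1.train.import_meta_graph`:
#
#   import tensorflow.compat.v1 as tf
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     # Raises NotImplementedError: deserializing variables under a
#     # tf.distribute.Strategy is not yet supported.
#     tf.train.import_meta_graph("/tmp/model.meta")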


# ------------------------------------------------------------------------------
# Shorthand for some methods from distribution_strategy_context.
_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
_get_default_replica_mode = (
    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access


# ------------------------------------------------------------------------------
# Metrics to track which distribution strategy is being used.
distribution_strategy_gauge = monitoring.StringGauge(
    "/tensorflow/api/distribution_strategy",
    "Gauge to track the type of distribution strategy used.", "TFVersion")
distribution_strategy_replica_gauge = monitoring.IntGauge(
    "/tensorflow/api/distribution_strategy/replica",
    "Gauge to track the number of replicas used by each distribution "
    "strategy.", "CountType")
distribution_strategy_input_api_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/input_api",
    "Counter to track the usage of the input APIs.", "strategy", "api")
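
# Illustrative sketch (not how any particular strategy records its metrics;
# the cell labels below are example values): strategy implementations report
# to these metrics roughly as follows.
#
#   distribution_strategy_gauge.get_cell("V2").set("MirroredStrategy")
#   distribution_strategy_replica_gauge.get_cell("num_replicas").set(2)
#   distribution_strategy_input_api_counter.get_cell(
#       "MirroredStrategy", "distribute_dataset").increase_by(1)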
3852