1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=line-too-long
16"""Library for running a computation across multiple devices.
17
18The intent of this library is that you can write an algorithm in a stylized way
19and it will be usable with a variety of different `tf.distribute.Strategy`
20implementations. Each descendant will implement a different strategy for
21distributing the algorithm across multiple devices/machines.  Furthermore, these
22changes can be hidden inside the specific layers and other library classes that
23need special treatment to run in a distributed setting, so that most users'
24model definition code can run unchanged. The `tf.distribute.Strategy` API works
25the same way with eager and graph execution.
26
27*Guides*
28
29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
31
32*Tutorials*
33
34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
35
36  The tutorials cover how to use `tf.distribute.Strategy` to do distributed
37  training with native Keras APIs, custom training loops,
  and Estimator APIs. They also cover how to save and load models when using
  `tf.distribute.Strategy`.
40
41*Glossary*
42
43* _Data parallelism_ is where we run multiple copies of the model
44  on different slices of the input data. This is in contrast to
45  _model parallelism_ where we divide up a single copy of a model
46  across multiple devices.
47  Note: we only support data parallelism for now, but
48  hope to add support for model parallelism in the future.
49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
50  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
51  devices on a single machine, or be connected to devices on multiple
52  machines. Devices used to run computations are called _worker devices_.
53  Devices used to store variables are _parameter devices_. For some strategies,
54  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
55  will be the same (see mirrored variables below). For others they will be
56  different. For example, `tf.distribute.experimental.CentralStorageStrategy`
57  puts the variables on a single device (which may be a worker device or may be
58  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
59  variables on separate machines called _parameter servers_ (see below).
60* A _replica_ is one copy of the model, running on one slice of the
61  input data. Right now each replica is executed on its own
62  worker device, but once we add support for model parallelism
63  a replica may span multiple worker devices.
64* A _host_ is the CPU device on a machine with worker devices, typically
65  used for running input pipelines.
66* A _worker_ is defined to be the physical machine(s) containing the physical
67  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
68  worker may contain one or more replicas, but contains at least one
69  replica. Typically one worker will correspond to one machine, but in the case
70  of very large models with model parallelism, one worker may span multiple
71  machines. We typically run one input pipeline per worker, feeding all the
72  replicas on that worker.
73* _Synchronous_, or more commonly _sync_, training is where the updates from
74  each replica are aggregated together before updating the model variables. This
75  is in contrast to _asynchronous_, or _async_ training, where each replica
76  updates the model variables independently. You may also have replicas
77  partitioned into groups which are in sync within each group but async between
78  groups.
79* _Parameter servers_: These are machines that hold a single copy of
80  parameters/variables, used by some strategies (right now just
81  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
82  to operate on a variable retrieve it at the beginning of a step and send an
83  update to be applied at the end of the step. These can in principle support
84  either sync or async training, but right now we only have support for async
85  training with parameter servers. Compare to
86  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
87  on a single device on the same machine (and does sync training), and
88  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
89  (see below).
90
* _Replica context_ vs. _Cross-replica context_ vs. _Update context_

  A _replica context_ applies
  when you execute the computation function that was called with `strategy.run`.
  Conceptually, you're in replica context when executing the computation
  function that is being replicated.

  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
  call.

  A _cross-replica context_ is entered when you enter a `strategy.scope`. This
  is useful for calling `tf.distribute.Strategy` methods which operate across
  the replicas (like `reduce_to()`). By default you start in a _replica context_
  (the "default single _replica context_") and then some methods can switch you
  back and forth. See the sketch at the end of this glossary for an
  illustration of replica and cross-replica contexts.
106
107* _Distributed value_: Distributed value is represented by the base class
108  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
109  to represent values on multiple devices, and it contains a map from replica id
110  to values. Two representative types of `tf.distribute.DistributedValues`
111  are `tf.types.experimental.PerReplica` and `tf.types.experimental.Mirrored`
112  values.
113
114  `PerReplica` values exist on the worker devices, with a different value for
115  each replica. They are produced by iterating through a distributed dataset
116  returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
117  `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the
118  typical result returned by `tf.distribute.Strategy.run`.
119
  `Mirrored` values are like `PerReplica` values, except we know that the
  values on all replicas are the same. `Mirrored` values are kept synchronized
  by the distribution strategy in use, while `PerReplica` values are left
  unsynchronized. `Mirrored` values typically represent model weights. We can
  safely read a `Mirrored` value in a cross-replica context by using the value
  on any replica, while `PerReplica` values can only be read within a replica
  context.
127
128* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
129  replicas, like `strategy.run(fn, args=[w])` with an
130  argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
131  have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
132  `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
133  device `d0`, `fn(w1)` on device `d1`, etc.  It then merges the return
134  values from `fn()`, which leads to one common object if the returned values
135  are the same object from every replica, or a `DistributedValues` object
136  otherwise.
137
138* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
139  multiple values into one value, like "sum" or "mean". If a strategy is doing
140  sync training, we will perform a reduction on the gradients to a parameter
141  from all replicas before applying the update. _All-reduce_ is an algorithm for
142  performing a reduction on values from multiple devices and making the result
143  available on all of those devices.
144
145* _Mirrored variables_: These are variables that are created on multiple
146  devices, where we keep the variables in sync by applying the same
147  updates to every copy. Mirrored variables are created with
148  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
149  Normally they are only used in synchronous training.
150
151* _SyncOnRead variables_
152
153  _SyncOnRead variables_ are created by
154  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
155  they are created on multiple devices. In replica context, each
156  component variable on the local replica can perform reads and writes without
157  synchronization with each other. When the
158  _SyncOnRead variable_ is read in cross-replica context, the values from
159  component variables are aggregated and returned.
160
161  _SyncOnRead variables_ bring a lot of custom configuration difficulty to the
162  underlying logic, so we do not encourage users to instantiate and use
163  _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead
164  variables_ for use cases such as batch norm and metrics. For performance
165  reasons, we often don't need to keep these statistics in sync every step and
166  they can be accumulated on each replica independently. The only time we want
167  to sync them is reporting or checkpointing, which typically happens in
168  cross-replica context. _SyncOnRead variables_ are also often used by advanced
169  users who want to control when variable values are aggregated. For example,
170  users sometimes want to maintain gradients independently on each replica for a
171  couple of steps without aggregation.
172
173* _Distribute-aware layers_
174
  Layers are generally called in a replica context, except when defining a
  Keras functional model. You can use `tf.distribute.in_cross_replica_context`
  to determine which case you are in. The `tf.distribute.get_replica_context`
  function returns the default replica context outside a strategy scope, `None`
  when in a cross-replica context (e.g. directly within a strategy scope), and
  a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
  `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
  `all_reduce` method for aggregating across all replicas.
183
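The following minimal sketch ties several of the terms above together: variables
are created in cross-replica context (inside `strategy.scope`), the replicated
function runs in replica context and uses `all_reduce`, and `strategy.reduce`
turns per-replica results into an ordinary tensor. It assumes two local GPUs
are available; with different hardware the device list and the number of
replicas will differ.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

# Cross-replica context: entered by `strategy.scope()`. Variables created
# here are mirrored across replicas.
with strategy.scope():
  v = tf.Variable(1.0)

def replica_fn(x):
  # Replica context: this body runs once per replica. `all_reduce` sums the
  # per-replica values across all replicas.
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x * v)

per_replica = strategy.run(replica_fn, args=(tf.constant(2.0),))
# Back in cross-replica context: reduce the per-replica results.
total = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)
```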
184
Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope; it provides the same API with
reasonable default behavior.
188"""
189# pylint: enable=line-too-long
190
191import collections
192import copy
193import enum  # pylint: disable=g-bad-import-order
194import functools
195import threading
196import weakref
197
198import six
199
200from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
201from tensorflow.python.autograph.impl import api as autograph
202from tensorflow.python.data.ops import dataset_ops
203from tensorflow.python.distribute import collective_util
204from tensorflow.python.distribute import device_util
205from tensorflow.python.distribute import distribution_strategy_context
206from tensorflow.python.distribute import numpy_dataset
207from tensorflow.python.distribute import reduce_util
208from tensorflow.python.distribute import values
209from tensorflow.python.eager import context as eager_context
210from tensorflow.python.eager import def_function
211from tensorflow.python.eager import monitoring
212from tensorflow.python.framework import constant_op
213from tensorflow.python.framework import dtypes
214from tensorflow.python.framework import indexed_slices
215from tensorflow.python.framework import ops
216from tensorflow.python.framework import tensor_shape
217from tensorflow.python.framework import tensor_util
218from tensorflow.python.ops import array_ops
219from tensorflow.python.ops import control_flow_ops
220from tensorflow.python.ops import custom_gradient
221from tensorflow.python.ops import math_ops
222from tensorflow.python.ops import resource_variable_ops
223from tensorflow.python.ops import summary_ops_v2
224from tensorflow.python.ops import variable_scope
225from tensorflow.python.ops.losses import losses_impl
226from tensorflow.python.platform import tf_logging
227from tensorflow.python.trackable import base as trackable
228from tensorflow.python.util import deprecation
229from tensorflow.python.util import nest
230from tensorflow.python.util import tf_contextlib
231from tensorflow.python.util.deprecation import deprecated
232from tensorflow.python.util.tf_export import tf_export
233from tensorflow.tools.docs import doc_controls
234
235# ------------------------------------------------------------------------------
236# Context tracking whether in a strategy.update() or .update_non_slot() call.
237
238
239_update_replica_id = threading.local()
240
241
242def get_update_replica_id():
243  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
244  try:
245    return _update_replica_id.current
246  except AttributeError:
247    return None
248
249
250class UpdateContext(object):
251  """Context manager when you are in `update()` or `update_non_slot()`."""
252
253  __slots__ = ["_replica_id", "_old_replica_id"]
254
255  def __init__(self, replica_id):
256    self._replica_id = replica_id
257    self._old_replica_id = None
258
259  def __enter__(self):
260    self._old_replica_id = get_update_replica_id()
261    _update_replica_id.current = self._replica_id
262
263  def __exit__(self, exception_type, exception_value, traceback):
264    del exception_type, exception_value, traceback
265    _update_replica_id.current = self._old_replica_id
266
267
268# ------------------------------------------------------------------------------
269# Public utility functions.
270
271
272@tf_export(v1=["distribute.get_loss_reduction"])
273def get_loss_reduction():
274  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.
275
  This is used to decide whether the loss should be scaled in the optimizer
  (used only for the Estimator + v1 optimizer use case).
278
279  Returns:
280    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
281    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
282  """
283  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
284    # If we are not in Estimator context then return 'SUM'. We do not need to
285    # scale loss in the optimizer.
286    return reduce_util.ReduceOp.SUM
287  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
288  if (last_reduction == losses_impl.Reduction.SUM or
289      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
290    return reduce_util.ReduceOp.SUM
291  return reduce_util.ReduceOp.MEAN
292
293
294# ------------------------------------------------------------------------------
295# Internal API for validating the current thread mode
296
297
298def _require_cross_replica_or_default_context_extended(extended,
299                                                       error_message=None):
300  """Verify in cross-replica context."""
301  context = _get_per_thread_mode()
302  cross_replica = context.cross_replica_context
303  if cross_replica is not None and cross_replica.extended is extended:
304    return
305  if context is _get_default_replica_mode():
306    return
307  strategy = extended._container_strategy()  # pylint: disable=protected-access
308  # We have an error to report, figure out the right message.
309  if context.strategy is not strategy:
310    _wrong_strategy_scope(strategy, context)
311  assert cross_replica is None
312  if not error_message:
313    error_message = ("Method requires being in cross-replica context, use "
314                     "get_replica_context().merge_call()")
315  raise RuntimeError(error_message)
316
317
318def _wrong_strategy_scope(strategy, context):
319  # Figure out the right error message.
320  if not distribution_strategy_context.has_strategy():
321    raise RuntimeError(
322        'Need to be inside "with strategy.scope()" for %s' %
323        (strategy,))
324  else:
325    raise RuntimeError(
326        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
327        (context.strategy, strategy))
328
329
330def require_replica_context(replica_ctx):
331  """Verify in `replica_ctx` replica context."""
332  context = _get_per_thread_mode()
333  if context.replica_context is replica_ctx: return
334  # We have an error to report, figure out the right message.
335  if context.replica_context is None:
336    raise RuntimeError("Need to be inside `call_for_each_replica()`")
337  if context.strategy is replica_ctx.strategy:
338    # Two different ReplicaContexts with the same tf.distribute.Strategy.
339    raise RuntimeError("Mismatching ReplicaContext.")
340  raise RuntimeError(
341      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
342      (context.strategy, replica_ctx.strategy))
343
344
345def _require_strategy_scope_strategy(strategy):
346  """Verify in a `strategy.scope()` in this thread."""
347  context = _get_per_thread_mode()
348  if context.strategy is strategy: return
349  _wrong_strategy_scope(strategy, context)
350
351
352def _require_strategy_scope_extended(extended):
353  """Verify in a `distribution_strategy.scope()` in this thread."""
354  context = _get_per_thread_mode()
355  if context.strategy.extended is extended: return
356  # Report error.
357  strategy = extended._container_strategy()  # pylint: disable=protected-access
358  _wrong_strategy_scope(strategy, context)
359
360
361# ------------------------------------------------------------------------------
362# Internal context managers used to implement the DistributionStrategy
363# base class
364
365
366class _CurrentDistributionContext(object):
367  """Context manager setting the current `tf.distribute.Strategy`.
368
369  Also: overrides the variable creator and optionally the current device.
370  """
371
372  def __init__(self,
373               strategy,
374               var_creator_scope,
375               var_scope=None,
376               resource_creator_scope=None,
377               default_device=None):
378    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
379        strategy)
380    self._var_creator_scope = var_creator_scope
381    self._var_scope = var_scope
382    self._resource_creator_scope = resource_creator_scope
383    if default_device:
384      self._device_scope = ops.device(default_device)
385    else:
386      self._device_scope = None
387    self._same_scope_again_count = 0
388
389  def __enter__(self):
390    # Allow this scope to be entered if this strategy is already in scope.
391    if distribution_strategy_context.has_strategy():
392      _require_cross_replica_or_default_context_extended(
393          self._context.strategy.extended)
394      self._same_scope_again_count += 1
395    else:
396      _push_per_thread_mode(self._context)
397      if self._var_scope:
398        self._var_scope.__enter__()
399      self._var_creator_scope.__enter__()
400      if self._resource_creator_scope:
401        nest.map_structure(lambda scope: scope.__enter__(),
402                           self._resource_creator_scope)
403      if self._device_scope:
404        self._device_scope.__enter__()
405    return self._context.strategy
406
407  def __exit__(self, exception_type, exception_value, traceback):
408    if self._same_scope_again_count > 0:
409      self._same_scope_again_count -= 1
410      return
411    if self._device_scope:
412      try:
413        self._device_scope.__exit__(exception_type, exception_value, traceback)
414      except RuntimeError as e:
415        six.raise_from(
416            RuntimeError("Device scope nesting error: move call to "
417                         "tf.distribute.set_strategy() out of `with` scope."),
418            e)
419
420    try:
421      self._var_creator_scope.__exit__(
422          exception_type, exception_value, traceback)
423    except RuntimeError as e:
424      six.raise_from(
425          RuntimeError("Variable creator scope nesting error: move call to "
426                       "tf.distribute.set_strategy() out of `with` scope."),
427          e)
428
429    if self._resource_creator_scope:
430      try:
431        if isinstance(self._resource_creator_scope, list):
432          reversed_resource_creator_scope = self._resource_creator_scope[::-1]
433          nest.map_structure(
434              lambda scope: scope.__exit__(exception_type, exception_value,  # pylint:disable=g-long-lambda
435                                           traceback),
436              reversed_resource_creator_scope)
437
438        else:
439          self._resource_creator_scope.__exit__(exception_type, exception_value,
440                                                traceback)
441      except RuntimeError as e:
442        six.raise_from(
443            RuntimeError("Resource creator scope nesting error: move call "
444                         "to tf.distribute.set_strategy() out of `with` "
445                         "scope."), e)
446
447    if self._var_scope:
448      try:
449        self._var_scope.__exit__(exception_type, exception_value, traceback)
450      except RuntimeError as e:
451        six.raise_from(
452            RuntimeError("Variable scope nesting error: move call to "
453                         "tf.distribute.set_strategy() out of `with` scope."),
454            e)
455    _pop_per_thread_mode()
456
457
458# TODO(yuefengz): add more replication modes.
459@tf_export("distribute.InputReplicationMode")
460class InputReplicationMode(enum.Enum):
461  """Replication mode for input function.
462
463  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as there are workers.
465    Replicas will dequeue from the local Dataset on their worker.
466    `tf.distribute.Strategy` doesn't manage any state sharing between such
467    separate input pipelines.
468  * `PER_REPLICA`: The input function will be called on each replica separately.
469    `tf.distribute.Strategy` doesn't manage any state sharing between such
470    separate input pipelines.
471  """
472  PER_WORKER = "PER_WORKER"
473  PER_REPLICA = "PER_REPLICA"
474
475
476@tf_export("distribute.InputContext")
477class InputContext(object):
478  """A class wrapping information needed by an input function.
479
480  This is a context class that is passed to the user's input function and
481  contains information about the compute replicas and input pipelines. The
482  number of compute replicas (in sync training) helps compute the local batch
483  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (e.g. shard the input pipeline, use a different input source, etc.).
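
  A typical place to use an `InputContext` is a `dataset_fn` passed to
  `tf.distribute.Strategy.distribute_datasets_from_function`. A minimal sketch,
  assuming `strategy` is an existing `tf.distribute.Strategy` and using a
  placeholder range dataset and global batch size:

  ```python
  def dataset_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.range(1024)
    # Shard by input pipeline so each worker reads a disjoint subset.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)

  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
  ```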
487  """
488
489  __slots__ = [
490      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
491  ]
492
493  def __init__(self,
494               num_input_pipelines=1,
495               input_pipeline_id=0,
496               num_replicas_in_sync=1):
497    """Initializes an InputContext object.
498
499    Args:
500      num_input_pipelines: the number of input pipelines in a cluster.
501      input_pipeline_id: the current input pipeline id, should be an int in
502        [0,`num_input_pipelines`).
503      num_replicas_in_sync: the number of replicas that are in sync.
504    """
505    self._num_input_pipelines = num_input_pipelines
506    self._input_pipeline_id = input_pipeline_id
507    self._num_replicas_in_sync = num_replicas_in_sync
508
509  @property
510  def num_replicas_in_sync(self):
511    """Returns the number of compute replicas in sync."""
512    return self._num_replicas_in_sync
513
514  @property
515  def input_pipeline_id(self):
516    """Returns the input pipeline ID."""
517    return self._input_pipeline_id
518
519  @property
520  def num_input_pipelines(self):
521    """Returns the number of input pipelines."""
522    return self._num_input_pipelines
523
524  def get_per_replica_batch_size(self, global_batch_size):
525    """Returns the per-replica batch size.
526
527    Args:
528      global_batch_size: the global batch size which should be divisible by
529        `num_replicas_in_sync`.
530
531    Returns:
532      the per-replica batch size.
533
534    Raises:
535      ValueError: if `global_batch_size` not divisible by
536        `num_replicas_in_sync`.
537    """
538    if global_batch_size % self._num_replicas_in_sync != 0:
539      raise ValueError("The `global_batch_size` %r is not divisible by "
540                       "`num_replicas_in_sync` %r " %
541                       (global_batch_size, self._num_replicas_in_sync))
542    return global_batch_size // self._num_replicas_in_sync
543
544  def __str__(self):
545    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
546        self.input_pipeline_id, self.num_input_pipelines)
547
548
549@tf_export("distribute.experimental.ValueContext", v1=[])
550class ValueContext(object):
551  """A class wrapping information needed by a distribute function.
552
553  This is a context class that is passed to the `value_fn` in
554  `strategy.experimental_distribute_values_from_function` and contains
555  information about the compute replicas. The `num_replicas_in_sync` and
556  `replica_id` can be used to customize the value on each replica.
557
558  Example usage:
559
560  1. Directly constructed.
561
562  >>> def value_fn(context):
563  ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
564  >>> context = tf.distribute.experimental.ValueContext(
565  ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
566  >>> per_replica_value = value_fn(context)
567  >>> per_replica_value
568  0.5
569
570  2. Passed in by `experimental_distribute_values_from_function`.
571
572  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
573  >>> def value_fn(value_context):
574  ...   return value_context.num_replicas_in_sync
575  >>> distributed_values = (
576  ...      strategy.experimental_distribute_values_from_function(
577  ...        value_fn))
578  >>> local_result = strategy.experimental_local_results(distributed_values)
579  >>> local_result
580  (2, 2)
581
582  """
583
584  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
585
586  def __init__(self,
587               replica_id_in_sync_group=0,
588               num_replicas_in_sync=1):
    """Initializes a ValueContext object.
590
591    Args:
592      replica_id_in_sync_group: the current replica_id, should be an int in
593        [0,`num_replicas_in_sync`).
594      num_replicas_in_sync: the number of replicas that are in sync.
595    """
596    self._replica_id_in_sync_group = replica_id_in_sync_group
597    self._num_replicas_in_sync = num_replicas_in_sync
598
599  @property
600  def num_replicas_in_sync(self):
601    """Returns the number of compute replicas in sync."""
602    return self._num_replicas_in_sync
603
604  @property
605  def replica_id_in_sync_group(self):
606    """Returns the replica ID."""
607    return self._replica_id_in_sync_group
608
609  def __str__(self):
610    return (("tf.distribute.ValueContext(replica id {}, "
611             " total replicas in sync: ""{})")
612            .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
613
614
615@tf_export("distribute.RunOptions")
616class RunOptions(
617    collections.namedtuple("RunOptions", [
618        "experimental_enable_dynamic_batch_size",
619        "experimental_bucketizing_dynamic_shape",
620        "experimental_xla_options",
621    ])):
622  """Run options for `strategy.run`.
623
624  This can be used to hold some strategy specific configs.
625
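  For example, with a `tf.distribute.TPUStrategy` named `strategy` (a sketch;
  TPU setup is omitted and `replica_fn` and `x` are placeholders), the options
  are passed to `strategy.run`:

  ```python
  options = tf.distribute.RunOptions(
      experimental_enable_dynamic_batch_size=False)
  strategy.run(replica_fn, args=(x,), options=options)
  ```
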
626  Attributes:
    experimental_enable_dynamic_batch_size: Boolean. Only applies to
      TPUStrategy. Defaults to True. If True, TPUStrategy will enable the
      dynamic padder to support dynamic batch sizes for the inputs. Otherwise
      only static-shape inputs are allowed.
    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
      TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
      bucketize inputs passed into `run` if the input shape is dynamic. This is
      a performance optimization to reduce XLA recompilation, which should not
      have an impact on correctness.
    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
      TPUStrategy. Controls the XLA compilation options on TPUs. Defaults to
      None.
638  """
639
640  def __new__(cls,
641              experimental_enable_dynamic_batch_size=True,
642              experimental_bucketizing_dynamic_shape=False,
643              experimental_xla_options=None):
644    return super(RunOptions,
645                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
646                              experimental_bucketizing_dynamic_shape,
647                              experimental_xla_options)
648
649
650@tf_export("distribute.InputOptions", v1=[])
651class InputOptions(
652    collections.namedtuple("InputOptions", [
653        "experimental_fetch_to_device",
654        "experimental_replication_mode",
655        "experimental_place_dataset_on_device",
656        "experimental_per_replica_buffer_size",
657    ])):
658  """Run options for `experimental_distribute_dataset(s_from_function)`.
659
660  This can be used to hold some strategy specific configs.
661
662  ```python
663  # Setup TPUStrategy
664  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
665  tf.config.experimental_connect_to_cluster(resolver)
666  tf.tpu.experimental.initialize_tpu_system(resolver)
667  strategy = tf.distribute.TPUStrategy(resolver)
668
669  dataset = tf.data.Dataset.range(16)
670  distributed_dataset_on_host = (
671      strategy.experimental_distribute_dataset(
672          dataset,
          tf.distribute.InputOptions(
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1)))
678  ```
679
680  Attributes:
681    experimental_fetch_to_device: Boolean. If True, dataset
682      elements will be prefetched to accelerator device memory. When False,
683      dataset elements are prefetched to host device memory. Must be False when
684      using TPUEmbedding API. experimental_fetch_to_device can only be used
685      with experimental_replication_mode=PER_WORKER. Default behavior is same as
686      setting it to True.
    experimental_replication_mode: Replication mode for the input function.
      Currently, `InputReplicationMode.PER_REPLICA` is only supported with
      `tf.distribute.MirroredStrategy.experimental_distribute_datasets_from_function`.
      The default value is `InputReplicationMode.PER_WORKER`.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When
      True, the dataset will be placed on the device, otherwise it will remain
      on the host. experimental_place_dataset_on_device=True can only be used
      with experimental_replication_mode=PER_REPLICA.
    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
      prefetch buffer size in the replica device memory. Users can set it to 0
      to completely disable prefetching behavior, or a number greater than 1 to
      enable a larger buffer size. Note that this option is still valid with
      `experimental_fetch_to_device=False`.
701  """
702
703  def __new__(cls,
704              experimental_fetch_to_device=None,
705              experimental_replication_mode=InputReplicationMode.PER_WORKER,
706              experimental_place_dataset_on_device=False,
707              experimental_per_replica_buffer_size=1):
708    if experimental_fetch_to_device is None:
709      experimental_fetch_to_device = True
710
711    return super(InputOptions,
712                 cls).__new__(cls, experimental_fetch_to_device,
713                              experimental_replication_mode,
714                              experimental_place_dataset_on_device,
715                              experimental_per_replica_buffer_size)
716
717# ------------------------------------------------------------------------------
718# Base classes for all distribution strategies.
719
720
721# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
722# v1/v2 Strategy, add to implementing classes of StrategyBase.
723# pylint: disable=line-too-long
724class StrategyBase(object):
725  """A state & compute distribution policy on a list of devices.
726
727  See [the guide](https://www.tensorflow.org/guide/distributed_training)
728  for overview and examples. See `tf.distribute.StrategyExtended` and
729  [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
730  for a glossary of concepts mentioned on this page such as "per-replica",
731  _replica_, and _reduce_.
732
733  In short:
734
735  * To use it with Keras `compile`/`fit`,
736    [please
737    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
  * You may pass a descendant of `tf.distribute.Strategy` to
    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
    should distribute its computation. See the
    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
    strategy should be used when building and executing your model.
    (This puts you in the "cross-replica context" for this strategy, which
    means the strategy is put in control of things like variable placement.)
746  * If you are writing a custom training loop, you will need to call a few more
747    methods,
748    [see the
749    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
750
751      * Start by creating a `tf.data.Dataset` normally.
752      * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
753        a `tf.data.Dataset` to something that produces "per-replica" values.
754        If you want to manually specify how the dataset should be partitioned
755        across replicas, use
756        `tf.distribute.Strategy.distribute_datasets_from_function`
757        instead.
758      * Use `tf.distribute.Strategy.run` to run a function
759        once per replica, taking values that may be "per-replica" (e.g.
760        from a `tf.distribute.DistributedDataset` object) and returning
761        "per-replica" values.
762        This function is executed in "replica context", which means each
763        operation is performed separately on each replica.
764      * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
765        convert the resulting "per-replica" values into ordinary `Tensor`s.
766
767  A custom training loop can be as simple as:
768
769  ```
770  with my_strategy.scope():
771    @tf.function
772    def distribute_train_epoch(dataset):
773      def replica_fn(input):
774        # process input and return result
775        return result
776
777      total_result = 0
778      for x in dataset:
779        per_replica_result = my_strategy.run(replica_fn, args=(x,))
780        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
781                                           per_replica_result, axis=None)
782      return total_result
783
784    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
785    for _ in range(EPOCHS):
786      train_result = distribute_train_epoch(dist_dataset)
787  ```
788
789  This takes an ordinary `dataset` and `replica_fn` and runs it
790  distributed using a particular `tf.distribute.Strategy` named
791  `my_strategy` above. Any variables created in `replica_fn` are created
792  using `my_strategy`'s policy, and library functions called by
793  `replica_fn` can use the `get_replica_context()` API to implement
794  distributed-specific behavior.
795
796  You can use the `reduce` API to aggregate results across replicas and use
797  this as a return value from one iteration over a
798  `tf.distribute.DistributedDataset`. Or
799  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
800  accumulate metrics across steps in a given epoch.
801
802  See the
803  [custom training loop
804  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
805  for a more detailed example.
806
  Note: `tf.distribute.Strategy` currently does not support TensorFlow's
  partitioned variables (where a single variable is split across multiple
  devices).
810  """
811  # pylint: enable=line-too-long
812
813  # TODO(josh11b): Partitioned computations, state; sharding
814  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
815
816  def __init__(self, extended):
817    self._extended = extended
818
819    # Flag that is used to indicate whether distribution strategy is used with
820    # Estimator. This is required for backward compatibility of loss scaling
821    # when using v1 optimizer with estimator.
822    self._scale_loss_for_estimator = False
823
824    if not hasattr(extended, "_retrace_functions_for_each_device"):
825      # pylint: disable=protected-access
826      # `extended._retrace_functions_for_each_device` dictates
827      # whether the same function will be retraced when it is called on
828      # different devices.
829      try:
830        extended._retrace_functions_for_each_device = (
831            len(extended.worker_devices) > 1)
832        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
833            self.num_replicas_in_sync)
834      except:  # pylint: disable=bare-except
835        # Default for the case where extended.worker_devices can't return
836        # a sensible value.
837        extended._retrace_functions_for_each_device = True
838
839    # Below are the dicts of axis(int) -> `tf.function`.
840    self._mean_reduce_helper_fns = {}
841    self._reduce_sum_fns = {}
842
843    # Whether this strategy is designed to work with `ClusterCoordinator`.
844    self._should_use_with_coordinator = False
845
846  @property
847  def extended(self):
848    """`tf.distribute.StrategyExtended` with additional methods."""
849    return self._extended
850
851  @tf_contextlib.contextmanager
852  def _scale_loss_for_estimator_enabled(self):
853    """Scope which sets a flag used for scaling losses in optimizer.
854
855    Yields:
856      `_scale_loss_for_estimator_enabled` is a context manager with a
857      side effect, but doesn't return a value.
858    """
859    self._scale_loss_for_estimator = True
860    try:
861      yield
862    finally:
863      self._scale_loss_for_estimator = False
864
865  # pylint: disable=line-too-long
866  def scope(self):
867    """Context manager to make the strategy current and distribute variables.
868
869    This method returns a context manager, and is used as follows:
870
871    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
872    >>> # Variable created inside scope:
873    >>> with strategy.scope():
874    ...   mirrored_variable = tf.Variable(1.)
875    >>> mirrored_variable
876    MirroredVariable:{
877      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
878      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
879    }
880    >>> # Variable created outside scope:
881    >>> regular_variable = tf.Variable(1.)
882    >>> regular_variable
883    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
884
885    _What happens when Strategy.scope is entered?_
886
887    * `strategy` is installed in the global context as the "current" strategy.
888      Inside this scope, `tf.distribute.get_strategy()` will now return this
889      strategy. Outside this scope, it returns the default no-op strategy.
890    * Entering the scope also enters the "cross-replica context". See
891      `tf.distribute.StrategyExtended` for an explanation on cross-replica and
892      replica contexts.
893    * Variable creation inside `scope` is intercepted by the strategy. Each
      strategy defines how it wants to affect the variable creation. Sync
      strategies like `MirroredStrategy`, `TPUStrategy` and
      `MultiWorkerMirroredStrategy` create variables replicated on each replica,
      whereas `ParameterServerStrategy` creates variables on the parameter
      servers. This is done using a custom `tf.variable_creator_scope`.
    * In some strategies, a default device scope may also be entered: in
      `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
      entered on each worker.
902
    Note: Entering a scope does not automatically distribute a computation,
      except in the case of high-level training frameworks like Keras
      `model.fit`. If you're not using `model.fit`, you need to use the
      `strategy.run` API to explicitly distribute that computation.
      See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
908
909
910    _What should be in scope and what should be outside?_
911
912    There are a number of requirements on what needs to happen inside the scope.
913    However, in places where we have information about which strategy is in use,
914    we often enter the scope for the user, so they don't have to do it
915    explicitly (i.e. calling those either inside or outside the scope is OK).
916
917    * Anything that creates variables that should be distributed variables
918      must be called in a `strategy.scope`. This can be accomplished either by
919      directly calling the variable creating function within the scope context,
920      or by relying on another API like `strategy.run` or `keras.Model.fit` to
921      automatically enter it for you. Any variable that is created outside scope
922      will not be distributed and may have performance implications. Some common
923      objects that create variables in TF are Models, Optimizers, Metrics. Such
924      objects should always be initialized in the scope, and any functions
925      that may lazily create variables (e.g., `Model.__call__()`, tracing a
926      `tf.function`, etc.) should similarly be called within scope. Another
927      source of variable creation can be a checkpoint restore - when variables
928      are created lazily. Note that any variable created inside a strategy
929      captures the strategy information. So reading and writing to these
930      variables outside the `strategy.scope` can also work seamlessly, without
931      the user having to enter the scope.
932    * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
933      require to be in a strategy's scope, enter the scope automatically, which
934      means when using those APIs you don't need to explicitly enter the scope
935      yourself.
936    * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
937      object captures the scope information. When high level training framework
938      methods such as `model.compile`, `model.fit`, etc. are then called, the
939      captured scope will be automatically entered, and the associated strategy
940      will be used to distribute the training etc. See a detailed example in
941      [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
      WARNING: Simply calling `model(...)` does not automatically enter the
      captured scope -- only high-level training framework APIs support this
      behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
      and `model.save` can all be called inside or outside the scope.
946    * The following can be either inside or outside the scope:
947        * Creating the input datasets
948        * Defining `tf.function`s that represent your training step
949        * Saving APIs such as `tf.saved_model.save`. Loading creates variables,
950          so that should go inside the scope if you want to train the model in a
951          distributed way.
952        * Checkpoint saving. As mentioned above - `checkpoint.restore` may
953          sometimes need to be inside scope if it creates variables.
954
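    As a rough sketch of the guidance above (assuming `x` and `y` are existing
    tensors of training data):

    ```python
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Variable-creating objects go inside the scope.
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      optimizer = tf.keras.optimizers.SGD()
    # Input pipelines can be created outside the scope.
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
    model.compile(optimizer=optimizer, loss="mse")
    model.fit(dataset)  # `fit` re-enters the captured scope automatically.
    ```
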
955    Returns:
956      A context manager.
957    """
958    return self._extended._scope(self)  # pylint: disable=protected-access
959  # pylint: enable=line-too-long
960
961  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
962  @deprecated(None, "use extended.colocate_vars_with() instead.")
963  def colocate_vars_with(self, colocate_with_variable):
964    """DEPRECATED: use extended.colocate_vars_with() instead."""
965    return self._extended.colocate_vars_with(colocate_with_variable)
966
967  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
968  def make_dataset_iterator(self, dataset):
969    """DEPRECATED TF 1.x ONLY."""
970    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
971
972  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
973  def make_input_fn_iterator(self,
974                             input_fn,
975                             replication_mode=InputReplicationMode.PER_WORKER):
976    """DEPRECATED TF 1.x ONLY."""
977    if replication_mode != InputReplicationMode.PER_WORKER:
978      raise ValueError(
979          "Input replication mode not supported: %r" % replication_mode)
980    with self.scope():
981      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
982          input_fn, replication_mode=replication_mode)
983
984  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
985  @deprecated(None, "use run() instead")
986  def experimental_run(self, fn, input_iterator=None):
987    """DEPRECATED TF 1.x ONLY."""
988    with self.scope():
989      args = (input_iterator.get_next(),) if input_iterator is not None else ()
990    return self.run(fn, args=args)
991
992  def experimental_distribute_dataset(self, dataset, options=None):
993    # pylint: disable=line-too-long
994    """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
995
996    The returned `tf.distribute.DistributedDataset` can be iterated over
997    similar to regular datasets.
998    NOTE: The user cannot add any more transformations to a
999    `tf.distribute.DistributedDataset`. You can only create an iterator or
1000    examine the `tf.TypeSpec` of the data generated by it. See API docs of
1001    `tf.distribute.DistributedDataset` to learn more.
1002
1003    The following is an example:
1004
1005    >>> global_batch_size = 2
1006    >>> # Passing the devices is optional.
1007    ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1008    >>> # Create a dataset
1009    ... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
1010    >>> # Distribute that dataset
1011    ... dist_dataset = strategy.experimental_distribute_dataset(dataset)
1012    >>> @tf.function
1013    ... def replica_fn(input):
1014    ...   return input*2
1015    >>> result = []
1016    >>> # Iterate over the `tf.distribute.DistributedDataset`
1017    ... for x in dist_dataset:
1018    ...   # process dataset elements
1019    ...   result.append(strategy.run(replica_fn, args=(x,)))
1020    >>> print(result)
1021    [PerReplica:{
1022      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
1023      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
1024    }, PerReplica:{
1025      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
1026      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
1027    }]
1028
1029
1030    Three key actions happening under the hood of this method are batching,
1031    sharding, and prefetching.
1032
1033    In the code snippet above, `dataset` is batched by `global_batch_size`, and
1034    calling `experimental_distribute_dataset` on it rebatches `dataset` to a
1035    new batch size that is equal to the global batch size divided by the number
1036    of replicas in sync. We iterate through it using a Pythonic for loop.
1037    `x` is a `tf.distribute.DistributedValues` containing data for all replicas,
1038    and each replica gets data of the new batch size.
1039    `tf.distribute.Strategy.run` will take care of feeding the right per-replica
1040    data in `x` to the right `replica_fn` executed on each replica.
1041
    Sharding covers autosharding across multiple workers and within every
    worker. First, in multi-worker distributed training (i.e. when you use
1044    `tf.distribute.experimental.MultiWorkerMirroredStrategy`
1045    or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
1046    workers means that each worker is assigned a subset of the entire dataset
1047    (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
1048    ensure that at each step, a global batch size of non-overlapping dataset
1049    elements will be processed by each worker. Autosharding has a couple of
1050    different options that can be specified using
1051    `tf.data.experimental.DistributeOptions`. Then, sharding within each worker
    means the method will split the data among all the worker devices (if more
    than one is present). This will happen regardless of multi-worker
1054    autosharding.
1055
1056    Note: for autosharding across multiple workers, the default mode is
1057    `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
1058    will attempt to shard the input dataset by files if the dataset is
1059    being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
1060    `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
1061    where each of the workers will read the entire dataset and only process the
1062    shard assigned to it. However, if you have less than one input file per
1063    worker, we suggest that you disable dataset autosharding across workers by
1064    setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
1065    `tf.data.experimental.AutoShardPolicy.OFF`.
1066
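    For instance, autosharding can be turned off with something like the
    following sketch before passing the dataset to this method:

    ```python
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    ```
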
1067    By default, this method adds a prefetch transformation at the end of the
1068    user provided `tf.data.Dataset` instance. The argument to the prefetch
1069    transformation which is `buffer_size` is equal to the number of replicas in
1070    sync.
1071
1072    If the above batch splitting and dataset sharding logic is undesirable,
1073    please use
1074    `tf.distribute.Strategy.distribute_datasets_from_function`
1075    instead, which does not do any automatic batching or sharding for you.
1076
1077    Note: If you are using TPUStrategy, the order in which the data is processed
1078    by the workers when using
1079    `tf.distribute.Strategy.experimental_distribute_dataset` or
1080    `tf.distribute.Strategy.distribute_datasets_from_function` is
1081    not guaranteed. This is typically required if you are using
1082    `tf.distribute` to scale prediction. You can however insert an index for
1083    each element in the batch and order outputs accordingly. Refer to [this
1084    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1085    for an example of how to order outputs.
1086
1087    Note: Stateful dataset transformations are currently not supported with
1088    `tf.distribute.experimental_distribute_dataset` or
1089    `tf.distribute.distribute_datasets_from_function`. Any stateful
1090    ops that the dataset may have are currently ignored. For example, if your
1091    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
    then you have a dataset graph that depends on state (i.e. the random seed)
    on the local machine where the Python process is being executed.
1094
1095    For a tutorial on more usage and properties of this method, refer to the
1096    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
1097    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1098
1099    Args:
1100      dataset: `tf.data.Dataset` that will be sharded across all replicas using
1101        the rules stated above.
1102      options: `tf.distribute.InputOptions` used to control options on how this
1103        dataset is distributed.
1104
1105    Returns:
1106      A `tf.distribute.DistributedDataset`.
1107    """
1108    distribution_strategy_input_api_counter.get_cell(
1109        self.__class__.__name__, "distribute_dataset").increase_by(1)
1110    # pylint: enable=line-too-long
1111    return self._extended._experimental_distribute_dataset(dataset, options)  # pylint: disable=protected-access
1112
1113  def distribute_datasets_from_function(self, dataset_fn, options=None):
1114    # pylint: disable=line-too-long
1115    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
1116
1117    The argument `dataset_fn` that users pass in is an input function that has a
1118    `tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
1119    instance. It is expected that the returned dataset from `dataset_fn` is
1120    already batched by per-replica batch size (i.e. global batch size divided by
1121    the number of replicas in sync) and sharded.
1122    `tf.distribute.Strategy.distribute_datasets_from_function` does
1123    not batch or shard the `tf.data.Dataset` instance
1124    returned from the input function. `dataset_fn` will be called on the CPU
1125    device of each of the workers and each generates a dataset where every
1126    replica on that worker will dequeue one batch of inputs (i.e. if a worker
1127    has two replicas, two batches will be dequeued from the `Dataset` every
1128    step).
1129
1130    This method can be used for several purposes. First, it allows you to
1131    specify your own batching and sharding logic. (In contrast,
1132    `tf.distribute.experimental_distribute_dataset` does batching and sharding
1133    for you.) For example, where
1134    `experimental_distribute_dataset` is unable to shard the input files, this
1135    method might be used to manually shard the dataset (avoiding the slow
1136    fallback behavior in `experimental_distribute_dataset`). In cases where the
1137    dataset is infinite, this sharding can be done by creating dataset replicas
1138    that differ only in their random seed.
1139
    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
1141    information about batching and input replication can be accessed.
1142
1143    You can use `element_spec` property of the
1144    `tf.distribute.DistributedDataset` returned by this API to query the
1145    `tf.TypeSpec` of the elements returned by the iterator. This can be used to
1146    set the `input_signature` property of a `tf.function`. Follow
1147    `tf.distribute.DistributedDataset.element_spec` to see an example.
1148
1149    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
1150    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
1151    the global batch size. This may be computed using
1152    `input_context.get_per_replica_batch_size`.
1153
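    A minimal sketch of such a `dataset_fn`, here sharding an infinite dataset
    by giving each input pipeline a different shuffle seed (`GLOBAL_BATCH_SIZE`,
    `features`, and `strategy` are placeholders):

    ```python
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
      d = tf.data.Dataset.from_tensor_slices(features)
      # Replicas differ only in shuffle order, as suggested above for
      # infinite datasets.
      d = d.shuffle(1000, seed=input_context.input_pipeline_id).repeat()
      return d.batch(batch_size)

    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
    ```
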
1154    Note: If you are using TPUStrategy, the order in which the data is processed
1155    by the workers when using
1156    `tf.distribute.Strategy.experimental_distribute_dataset` or
1157    `tf.distribute.Strategy.distribute_datasets_from_function` is
1158    not guaranteed. This is typically required if you are using
1159    `tf.distribute` to scale prediction. You can however insert an index for
1160    each element in the batch and order outputs accordingly. Refer to [this
1161    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1162    for an example of how to order outputs.
1163
1164    Note: Stateful dataset transformations are currently not supported with
1165    `tf.distribute.Strategy.experimental_distribute_dataset` or
1166    `tf.distribute.Strategy.distribute_datasets_from_function`. Any stateful
1167    ops that the dataset may have are currently ignored. For example, if your
1168    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
1169    then you have a dataset graph that depends on state (i.e. the random seed) on
1170    the local machine where the Python process is being executed.
1171
1172    For more on the usage and properties of this method, refer to the
1173    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
1174    If you are interested in handling the last partial batch, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1175
1176    Args:
1177      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
1178        returning a `tf.data.Dataset`.
1179      options: `tf.distribute.InputOptions` used to control options on how this
1180        dataset is distributed.
1181
1182    Returns:
1183      A `tf.distribute.DistributedDataset`.
1184    """
1185    distribution_strategy_input_api_counter.get_cell(
1186        self.__class__.__name__,
1187        "distribute_datasets_from_function").increase_by(1)
1188    # pylint: enable=line-too-long
1189    return self._extended._distribute_datasets_from_function(  # pylint: disable=protected-access
1190        dataset_fn, options)
1191
1192  # TODO(b/162776748): Remove deprecated symbol.
1193  @doc_controls.do_not_doc_inheritable
1194  @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
1195  def experimental_distribute_datasets_from_function(self,
1196                                                     dataset_fn,
1197                                                     options=None):
1198    return self.distribute_datasets_from_function(dataset_fn, options)
1199
1200  def run(self, fn, args=(), kwargs=None, options=None):
1201    """Invokes `fn` on each replica, with the given arguments.
1202
1203    This method is the primary way to distribute your computation with a
1204    `tf.distribute` object. It invokes `fn` on each replica. If `args` or `kwargs`
1205    have `tf.distribute.DistributedValues`, such as those produced by a
1206    `tf.distribute.DistributedDataset` from
1207    `tf.distribute.Strategy.experimental_distribute_dataset` or
1208    `tf.distribute.Strategy.distribute_datasets_from_function`,
1209    when `fn` is executed on a particular replica, it will be executed with the
1210    component of `tf.distribute.DistributedValues` that corresponds to that
1211    replica.
1212
1213    `fn` is invoked under a replica context. `fn` may call
1214    `tf.distribute.get_replica_context()` to access members such as
1215    `all_reduce`. Please see the module-level docstring of tf.distribute for the
1216    concept of replica context.
1217
1218    Each argument in `args` or `kwargs` can be a nested structure of tensors,
1219    e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
1220    the `fn` invoked on each replica. Alternatively, `args` or `kwargs` can be
1221    `tf.distribute.DistributedValues` containing tensors or composite tensors
1222    (e.g. `tf.sparse.SparseTensor` or `tf.RaggedTensor`), in which case each `fn`
1223    call will get the component of a `tf.distribute.DistributedValues`
1224    corresponding to its replica. Note that arbitrary Python values that are not
1225    of the types above are not supported.
1226
1227    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
1228    whether eager execution is enabled, `fn` may be called one or more times. If
1229    `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
1230    called inside a `tf.function` (eager execution is disabled inside a
1231    `tf.function` by default), `fn` is called once per replica to generate a
1232    TensorFlow graph, which will then be reused for execution with new inputs.
1233    Otherwise, if eager execution is enabled, `fn` will be called once per
1234    replica every step just like regular Python code.
1235
1236    Example usage:
1237
1238    1. Constant tensor input.
1239
1240    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1241    >>> tensor_input = tf.constant(3.0)
1242    >>> @tf.function
1243    ... def replica_fn(input):
1244    ...   return input*2.0
1245    >>> result = strategy.run(replica_fn, args=(tensor_input,))
1246    >>> result
1247    PerReplica:{
1248      0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1249      1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
1250    }
1251
1252    2. DistributedValues input.
1253
1254    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1255    >>> @tf.function
1256    ... def run():
1257    ...   def value_fn(value_context):
1258    ...     return value_context.num_replicas_in_sync
1259    ...   distributed_values = (
1260    ...     strategy.experimental_distribute_values_from_function(
1261    ...       value_fn))
1262    ...   def replica_fn2(input):
1263    ...     return input*2
1264    ...   return strategy.run(replica_fn2, args=(distributed_values,))
1265    >>> result = run()
1266    >>> result
1267    <tf.Tensor: shape=(), dtype=int32, numpy=4>
1268
1269    3. Use `tf.distribute.ReplicaContext` to allreduce values.
1270
1271    >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
1272    >>> @tf.function
1273    ... def run():
1274    ...    def value_fn(value_context):
1275    ...      return tf.constant(value_context.replica_id_in_sync_group)
1276    ...    distributed_values = (
1277    ...        strategy.experimental_distribute_values_from_function(
1278    ...            value_fn))
1279    ...    def replica_fn(input):
1280    ...      return tf.distribute.get_replica_context().all_reduce("sum", input)
1281    ...    return strategy.run(replica_fn, args=(distributed_values,))
1282    >>> result = run()
1283    >>> result
1284    PerReplica:{
1285      0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1286      1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
1287    }
1288
1289    Args:
1290      fn: The function to run on each replica.
1291      args: Optional positional arguments to `fn`. Each element can be a tensor,
1292        a nested structure of tensors, or a `tf.distribute.DistributedValues`.
1293      kwargs: Optional keyword arguments to `fn`. Each element can be a tensor,
1294        a nested structure of tensors, or a `tf.distribute.DistributedValues`.
1295      options: An optional instance of `tf.distribute.RunOptions` specifying
1296        the options to run `fn`.
1297
1298    Returns:
1299      Merged return value of `fn` across replicas. The structure of the return
1300      value is the same as the return value from `fn`. Each element in the
1301      structure can either be a `tf.distribute.DistributedValues` or a `Tensor`
1302      (for example, if running on a single replica).
1303    """
1304    del options
1305
1306    if not isinstance(args, (list, tuple)):
1307      raise ValueError(
1308          "positional args must be a list or tuple, got {}".format(type(args)))
1309
1310    with self.scope():
1311      # tf.distribute supports Eager functions, so AutoGraph should not be
1312      # applied when the caller is also in Eager mode.
1313      fn = autograph.tf_convert(
1314          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1315      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
1316
1317  def reduce(self, reduce_op, value, axis):
1318    """Reduce `value` across replicas and return result on current device.
1319
1320    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1321    >>> def step_fn():
1322    ...   i = tf.distribute.get_replica_context().replica_id_in_sync_group
1323    ...   return tf.identity(i)
1324    >>>
1325    >>> per_replica_result = strategy.run(step_fn)
1326    >>> total = strategy.reduce("SUM", per_replica_result, axis=None)
1327    >>> total
1328    <tf.Tensor: shape=(), dtype=int32, numpy=1>
1329
1330    To see how this would look with multiple replicas, consider the same
1331    example with MirroredStrategy with 2 GPUs:
1332
1333    ```python
1334    strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1335    def step_fn():
1336      i = tf.distribute.get_replica_context().replica_id_in_sync_group
1337      return tf.identity(i)
1338
1339    per_replica_result = strategy.run(step_fn)
1340    # Check the devices on which the per-replica result is:
1341    strategy.experimental_local_results(per_replica_result)[0].device
1342    # /job:localhost/replica:0/task:0/device:GPU:0
1343    strategy.experimental_local_results(per_replica_result)[1].device
1344    # /job:localhost/replica:0/task:0/device:GPU:1
1345
1346    total = strategy.reduce("SUM", per_replica_result, axis=None)
1347    # Check the device on which the reduced result is:
1348    total.device
1349    # /job:localhost/replica:0/task:0/device:CPU:0
1350
1351    ```
1352
1353    This API is typically used for aggregating the results returned from
1354    different replicas, for reporting etc. For example, loss computed from
1355    different replicas can be averaged using this API before printing.
1356
1357    Note: The result is copied to the "current" device - which would typically
1358    be the CPU of the worker on which the program is running. For `TPUStrategy`,
1359    it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
1360    this is the CPU of each worker.
1361
1362    There are a number of different tf.distribute APIs for reducing values
1363    across replicas:
1364    * `tf.distribute.ReplicaContext.all_reduce`: This differs from
1365    `Strategy.reduce` in that it is for replica context and does
1366    not copy the results to the host device. `all_reduce` should be typically
1367    used for reductions inside the training step such as gradients.
1368    * `tf.distribute.StrategyExtended.reduce_to` and
1369    `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
1370    advanced versions of `Strategy.reduce` as they allow customizing the
1371    destination of the result. They are also called in cross-replica context.
1372
1373    _What should axis be?_
1374
1375    Given a per-replica value returned by `run`, say a
1376    per-example loss, the batch will be divided across all the replicas.  This
1377    function allows you to aggregate across replicas and optionally also across
1378    batch elements by specifying the axis parameter accordingly.
1379
1380    For example, if you have a global batch size of 8 and 2
1381    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
1382    `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
1383    aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
1384    This is useful when each replica is computing a scalar or some other value
1385    that doesn't have a "batch" dimension (like a gradient or loss).
1386    ```
1387    strategy.reduce("sum", per_replica_result, axis=None)
1388    ```
1389
1390    Sometimes, you will want to aggregate across both the global batch _and_
1391    all replicas. You can get this behavior by specifying the batch
1392    dimension as the `axis`, typically `axis=0`. In this case it would return a
1393    scalar `0+1+2+3+4+5+6+7`.
1394    ```
1395    strategy.reduce("sum", per_replica_result, axis=0)
1396    ```
1397
1398    If there is a last partial batch, you will need to specify an axis so
1399    that the resulting shape is consistent across replicas. So if the last
1400    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
1401    would get a shape mismatch unless you specify `axis=0`. If you specify
1402    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
1403    denominator of 6. Contrast this with first computing `reduce_mean` to get a
1404    scalar value on each replica and then using this function to average those
1405    means, which would weight some values at `1/8` and others at `1/4`.
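
    A sketch of that case (here `per_replica_result` holds the uneven
    per-replica batches `[0, 1, 2, 3]` and `[4, 5]` described above):

    ```
    # With axis=0 the per-replica sums are combined and MEAN divides by the
    # true element count of 6, so the result is (0+1+2+3+4+5) / 6 = 2.5.
    strategy.reduce("mean", per_replica_result, axis=0)
    ```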
1406
1407    Args:
1408      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
1409        be combined. Allows using string representation of the enum such as
1410        "SUM", "MEAN".
1411      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1412        `Strategy.run`, to be combined into a single tensor. It can also be a
1413        regular tensor when used with `OneDeviceStrategy` or default strategy.
1414      axis: specifies the dimension to reduce along within each
1415        replica's tensor. Should typically be set to the batch dimension, or
1416        `None` to only reduce across replicas (e.g. if the tensor has no batch
1417        dimension).
1418
1419    Returns:
1420      A `Tensor`.
1421    """
1422    # TODO(josh11b): support `value` being a nest.
1423    _require_cross_replica_or_default_context_extended(self._extended)
1424    if isinstance(reduce_op, six.string_types):
1425      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1426    if axis is None:
1427      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1428    if reduce_op == reduce_util.ReduceOp.SUM:
1429
1430      def reduce_sum(v):
1431        return math_ops.reduce_sum(v, axis=axis)
1432
1433      if eager_context.executing_eagerly():
1434        # As some strategies (e.g. TPUStrategy) don't support pure eager
1435        # execution, wrap `reduce_sum_fn` in a `tf.function` so it can be
1436        # run from eager mode. Cache the tf.function by `axis` to avoid
1437        # retracing the same function.
1438        if axis not in self._reduce_sum_fns:
1439
1440          def reduce_sum_fn(v):
1441            return self.run(reduce_sum, args=(v,))
1442
1443          self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn)
1444        value = self._reduce_sum_fns[axis](value)
1445      else:
1446        value = self.run(reduce_sum, args=(value,))
1447
1448      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1449    if reduce_op != reduce_util.ReduceOp.MEAN:
1450      raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
1451                      "not: %r" % reduce_op)
1452
1453    def mean_reduce_helper(v, axes=axis):
1454      """Computes the numerator and denominator on each replica."""
1455      numer = math_ops.reduce_sum(v, axis=axes)
1456      def dimension(axis):
1457        if v.shape.rank is not None:
1458          # Note(joshl): We support axis < 0 to be consistent with the
1459          # tf.math.reduce_* operations.
1460          if axis < 0:
1461            if axis + v.shape.rank < 0:
1462              raise ValueError(
1463                  "`axis` = %r out of range for `value` with rank %d" %
1464                  (axis, v.shape.rank))
1465            axis += v.shape.rank
1466          elif axis >= v.shape.rank:
1467            raise ValueError(
1468                "`axis` = %r out of range for `value` with rank %d" %
1469                (axis, v.shape.rank))
1470          # TF v2 returns `None` for unknown dimensions and an integer for
1471          # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
1472          # or tensor_shape.Dimension(integer). `dimension_value` hides this
1473          # difference, always returning `None` or an integer.
1474          dim = tensor_shape.dimension_value(v.shape[axis])
1475          if dim is not None:
1476            # By returning a python value in the static shape case, we can
1477            # maybe get a fast path for reducing the denominator.
1478            # TODO(b/151871486): Remove array_ops.identity after we fallback to
1479            # simple reduction if inputs are all on CPU.
1480            return array_ops.identity(
1481                constant_op.constant(dim, dtype=dtypes.int64))
1482        elif axis < 0:
1483          axis = axis + array_ops.rank(v)
1484        # TODO(b/151871486): Remove array_ops.identity after we fallback to
1485        # simple reduction if inputs are all on CPU.
1486        return array_ops.identity(
1487            array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
1488      if isinstance(axis, six.integer_types):
1489        denom = dimension(axis)
1490      elif isinstance(axis, (tuple, list)):
1491        denom = math_ops.reduce_prod([dimension(a) for a in axes])
1492      else:
1493        raise TypeError(
1494            "Expected `axis` to be an integer, tuple or list not: %r" % axis)
1495      # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
1496      # reduce is complete?
1497      return numer, denom
1498
1499    if eager_context.executing_eagerly():
1500      # As some strategies (e.g. TPUStrategy) don't support pure eager
1501      # execution, wrap `mean_reduce_helper` in a `tf.function` so it can
1502      # be run from eager mode. Cache the tf.function by `axis` to avoid
1503      # retracing the same function.
1504      if axis not in self._mean_reduce_helper_fns:
1505
1506        def mean_reduce_fn(v):
1507          return self.run(mean_reduce_helper, args=(v,))
1508
1509        self._mean_reduce_helper_fns[axis] = def_function.function(
1510            mean_reduce_fn)
1511      numer, denom = self._mean_reduce_helper_fns[axis](value)
1512    else:
1513      numer, denom = self.run(mean_reduce_helper, args=(value,))
1514
1515    # TODO(josh11b): Should batch reduce here instead of doing two.
1516    numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer)  # pylint: disable=protected-access
1517    denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom)  # pylint: disable=protected-access
1518    denom = math_ops.cast(denom, numer.dtype)
1519    return math_ops.truediv(numer, denom)
1520
1521  @doc_controls.do_not_doc_inheritable  # DEPRECATED
1522  @deprecated(None, "use `experimental_local_results` instead.")
1523  def unwrap(self, value):
1524    """Returns the list of all local per-replica values contained in `value`.
1525
1526    DEPRECATED: Please use `experimental_local_results` instead.
1527
1528    Note: This only returns values on the worker initiated by this client.
1529    When using a `tf.distribute.Strategy` like
1530    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1531    will be its own client, and this function will only return values
1532    computed on that worker.
1533
1534    Args:
1535      value: A value returned by `experimental_run()`,
1536        `extended.call_for_each_replica()`, or a variable created in `scope`.
1537
1538    Returns:
1539      A tuple of values contained in `value`. If `value` represents a single
1540      value, this returns `(value,)`.
1541    """
1542    return self._extended._local_results(value)  # pylint: disable=protected-access
1543
1544  def experimental_local_results(self, value):
1545    """Returns the list of all local per-replica values contained in `value`.
1546
1547    Note: This only returns values on the worker initiated by this client.
1548    When using a `tf.distribute.Strategy` like
1549    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1550    will be its own client, and this function will only return values
1551    computed on that worker.
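
    For example, to inspect the per-replica outputs of `Strategy.run` (a
    minimal sketch; `replica_fn` and `inputs` are placeholders):

    ```
    per_replica = strategy.run(replica_fn, args=(inputs,))
    local_values = strategy.experimental_local_results(per_replica)
    # `local_values` is a tuple with one entry per local replica.
    ```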
1552
1553    Args:
1554      value: A value returned by `experimental_run()`, `run()`, or a variable
1555        created in `scope`.
1556
1557    Returns:
1558      A tuple of values contained in `value`, where the i-th element corresponds
1559      to the i-th replica. If `value` represents a single value, this returns
1560      `(value,)`.
1561    """
1562    return self._extended._local_results(value)  # pylint: disable=protected-access
1563
1564  @doc_controls.do_not_doc_inheritable  # DEPRECATED: TF v1.x only
1565  def group(self, value, name=None):
1566    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
1567    return self._extended._group(value, name)  # pylint: disable=protected-access
1568
1569  @property
1570  def num_replicas_in_sync(self):
1571    """Returns number of replicas over which gradients are aggregated."""
1572    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access
1573
1574  @doc_controls.do_not_doc_inheritable  # DEPRECATED: see doc string
1575  @deprecated(None, "use `update_config_proto` instead.")
1576  def configure(self,
1577                session_config=None,
1578                cluster_spec=None,
1579                task_type=None,
1580                task_id=None):
1581    # pylint: disable=g-doc-return-or-yield,g-doc-args
1582    """DEPRECATED: use `update_config_proto` instead.
1583
1584    Configures the strategy class.
1585
1586    DEPRECATED: This method's functionality has been split into the strategy
1587    constructor and `update_config_proto`. In the future, we will allow passing
1588    cluster and config_proto to the constructor to configure the strategy. And
1589    `update_config_proto` can be used to update the config_proto based on the
1590    specific strategy.
1591    """
1592    return self._extended._configure(  # pylint: disable=protected-access
1593        session_config, cluster_spec, task_type, task_id)
1594
1595  @doc_controls.do_not_generate_docs  # DEPRECATED
1596  def update_config_proto(self, config_proto):
1597    """DEPRECATED TF 1.x ONLY."""
1598    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
1599
1600  def __deepcopy__(self, memo):
1601    # First do a regular deepcopy of `self`.
1602    cls = self.__class__
1603    result = cls.__new__(cls)
1604    memo[id(self)] = result
1605    for k, v in self.__dict__.items():
1606      setattr(result, k, copy.deepcopy(v, memo))
1607    # One little fix-up: we want `result._extended` to reference `result`
1608    # instead of `self`.
1609    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
1610    return result
1611
1612  def __copy__(self):
1613    raise RuntimeError("Must only deepcopy DistributionStrategy.")
1614
1615  @property
1616  def cluster_resolver(self):
1617    """Returns the cluster resolver associated with this strategy.
1618
1619    In general, when using a multi-worker `tf.distribute` strategy such as
1620    `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
1621    `tf.distribute.TPUStrategy()`, there is a
1622    `tf.distribute.cluster_resolver.ClusterResolver` associated with the
1623    strategy used, and such an instance is returned by this property.
1624
1625    Strategies that intend to have an associated
1626    `tf.distribute.cluster_resolver.ClusterResolver` must set the
1627    relevant attribute, or override this property; otherwise, `None` is returned
1628    by default. Those strategies should also provide information regarding what
1629    is returned by this property.
1630
1631    Single-worker strategies usually do not have a
1632    `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
1633    property will return `None`.
1634
1635    The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
1636    user needs to access information such as the cluster spec, task type or task
1637    id. For example,
1638
1639    ```python
1640
1641    os.environ['TF_CONFIG'] = json.dumps({
1642      'cluster': {
1643          'worker': ["localhost:12345", "localhost:23456"],
1644          'ps': ["localhost:34567"]
1645      },
1646      'task': {'type': 'worker', 'index': 0}
1647    })
1648
1649    # This implicitly uses TF_CONFIG for the cluster and current task info.
1650    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1651
1652    ...
1653
1654    if strategy.cluster_resolver.task_type == 'worker':
1655      # Perform something that's only applicable on workers. Since we set this
1656      # as a worker above, this block will run on this particular instance.
1657    elif strategy.cluster_resolver.task_type == 'ps':
1658      # Perform something that's only applicable on parameter servers. Since we
1659      # set this as a worker above, this block will not run on this particular
1660      # instance.
1661    ```
1662
1663    For more information, please see
1664    `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
1665
1666    Returns:
1667      The cluster resolver associated with this strategy. Returns `None` if a
1668      cluster resolver is not applicable or available in this strategy.
1669    """
1670    if hasattr(self.extended, "_cluster_resolver"):
1671      return self.extended._cluster_resolver  # pylint: disable=protected-access
1672    return None
1673
1674
1675@tf_export("distribute.Strategy", v1=[])  # pylint: disable=g-missing-docstring
1676class Strategy(StrategyBase):
1677
1678  __doc__ = StrategyBase.__doc__
1679
1680  def experimental_distribute_values_from_function(self, value_fn):
1681    """Generates `tf.distribute.DistributedValues` from `value_fn`.
1682
1683    This function is used to generate `tf.distribute.DistributedValues` to pass
1684    into `run`, `reduce`, or other methods that take
1685    distributed values when not using datasets.
1686
1687    Args:
1688      value_fn: The function to run to generate values. It is called for
1689        each replica with `tf.distribute.ValueContext` as the sole argument. It
1690        must return a Tensor or a type that can be converted to a Tensor.

1691    Returns:
1692      A `tf.distribute.DistributedValues` containing a value for each replica.
1693
1694    Example usage:
1695
1696    1. Return constant value per replica:
1697
1698    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1699    >>> def value_fn(ctx):
1700    ...   return tf.constant(1.)
1701    >>> distributed_values = (
1702    ...      strategy.experimental_distribute_values_from_function(
1703    ...        value_fn))
1704    >>> local_result = strategy.experimental_local_results(distributed_values)
1705    >>> local_result
1706    (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
1707     <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
1708
1709    2. Distribute values in array based on replica_id:
1710
1711    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1712    >>> array_value = np.array([3., 2., 1.])
1713    >>> def value_fn(ctx):
1714    ...   return array_value[ctx.replica_id_in_sync_group]
1715    >>> distributed_values = (
1716    ...      strategy.experimental_distribute_values_from_function(
1717    ...        value_fn))
1718    >>> local_result = strategy.experimental_local_results(distributed_values)
1719    >>> local_result
1720    (3.0, 2.0)
1721
1722    3. Specify values using num_replicas_in_sync:
1723
1724    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1725    >>> def value_fn(ctx):
1726    ...   return ctx.num_replicas_in_sync
1727    >>> distributed_values = (
1728    ...      strategy.experimental_distribute_values_from_function(
1729    ...        value_fn))
1730    >>> local_result = strategy.experimental_local_results(distributed_values)
1731    >>> local_result
1732    (2, 2)
1733
1734    4. Place values on devices and distribute:
1735
1736    ```
1737    strategy = tf.distribute.TPUStrategy()
1738    worker_devices = strategy.extended.worker_devices
1739    multiple_values = []
1740    for i in range(strategy.num_replicas_in_sync):
1741      with tf.device(worker_devices[i]):
1742        multiple_values.append(tf.constant(1.0))
1743
1744    def value_fn(ctx):
1745      return multiple_values[ctx.replica_id_in_sync_group]
1746
1747    distributed_values = (
1748        strategy.experimental_distribute_values_from_function(
1749            value_fn))
1750    ```
1751
1752    """
1753    return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
1754        value_fn)
1755
1756  def gather(self, value, axis):
1757    # pylint: disable=line-too-long, protected-access
1758    """Gather `value` across replicas along `axis` to the current device.
1759
1760    Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
1761    object `value`, this API gathers and concatenates `value` across replicas
1762    along the `axis`-th dimension. The result is copied to the "current" device,
1763    which would typically be the CPU of the worker on which the program is
1764    running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
1765    multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
1766    each worker.
1767
1768    This API can only be called in the cross-replica context. For a counterpart
1769    in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
1770
1771    Note: For all strategies except `tf.distribute.TPUStrategy`, the input
1772    `value` on different replicas must have the same rank, and their shapes must
1773    be the same in all dimensions except the `axis`-th dimension. In other
1774    words, their shapes cannot differ in any dimension `d` other than the
1775    `axis` dimension. For example, given a
1776    `tf.distribute.DistributedValues` with component tensors of shape
1777    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
1778    `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
1779    `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
1780    all tensors must have exactly the same rank and same shape.
1781
1782    Note: Given a `tf.distribute.DistributedValues` `value`, its component
1783    tensors must have a non-zero rank. Otherwise, consider using
1784    `tf.expand_dims` before gathering them.
1785
1786    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1787    >>> # A DistributedValues with component tensor of shape (2, 1) on each replica
1788    ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
1789    >>> @tf.function
1790    ... def run():
1791    ...   return strategy.gather(distributed_values, axis=0)
1792    >>> run()
1793    <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
1794    array([[1],
1795           [2],
1796           [1],
1797           [2]], dtype=int32)>
1798
1799
1800    Consider the following example for more combinations:
1801
1802    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
1803    >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
1804    >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
1805    >>> @tf.function
1806    ... def run(axis):
1807    ...   return strategy.gather(distributed_values, axis=axis)
1808    >>> axis=0
1809    >>> run(axis)
1810    <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
1811    array([[[0, 1, 2],
1812            [3, 4, 5]],
1813           [[0, 1, 2],
1814            [3, 4, 5]],
1815           [[0, 1, 2],
1816            [3, 4, 5]],
1817           [[0, 1, 2],
1818            [3, 4, 5]]], dtype=int32)>
1819    >>> axis=1
1820    >>> run(axis)
1821    <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
1822    array([[[0, 1, 2],
1823            [3, 4, 5],
1824            [0, 1, 2],
1825            [3, 4, 5],
1826            [0, 1, 2],
1827            [3, 4, 5],
1828            [0, 1, 2],
1829            [3, 4, 5]]], dtype=int32)>
1830    >>> axis=2
1831    >>> run(axis)
1832    <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
1833    array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
1834            [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
1835
1836
1837    Args:
1838      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1839        `Strategy.run`, to be combined into a single tensor. It can also be a
1840        regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
1841        default strategy. The tensors that constitute the DistributedValues
1842        can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
1843      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
1844        range [0, rank(value)).
1845
1846    Returns:
1847       A `Tensor` that's the concatenation of `value` across replicas along
1848       the `axis` dimension.
1849    """
1850    # pylint: enable=line-too-long
1851    error_message = ("tf.distribute.Strategy.gather method requires "
1852                     "cross-replica context, use "
1853                     "get_replica_context().all_gather() instead")
1854    _require_cross_replica_or_default_context_extended(self._extended,
1855                                                       error_message)
1856    dst = device_util.current(
1857    ) or self._extended._default_device or "/device:CPU:0"
1858    if isinstance(value, indexed_slices.IndexedSlices):
1859      raise NotImplementedError("gather does not support IndexedSlices")
1860    return self._extended._local_results(
1861        self._extended._gather_to(value, dst, axis))[0]
1862
1863
1864# TF v1.x version has additional deprecated APIs
1865@tf_export(v1=["distribute.Strategy"])
1866class StrategyV1(StrategyBase):
1867  """A list of devices with a state & compute distribution policy.
1868
1869  See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
1870  for overview and examples.
1871
1872  Note: Not all `tf.distribute.Strategy` implementations currently support
1873  TensorFlow's partitioned variables (where a single variable is split across
1874  multiple devices).
1875  """
1876
1877  def make_dataset_iterator(self, dataset):
1878    """Makes an iterator for input provided via `dataset`.
1879
1880    DEPRECATED: This method is not available in TF 2.x.
1881
1882    Data from the given dataset will be distributed evenly across all the
1883    compute replicas. We will assume that the input dataset is batched by the
1884    global batch size. With this assumption, we will make a best effort to
1885    divide each batch across all the replicas (one or more workers).
1886    If this effort fails, an error will be raised, and the user should instead
1887    use `make_input_fn_iterator`, which provides more control and
1888    does not try to divide a batch across replicas.
1889
1890    The user could also use `make_input_fn_iterator` if they want to
1891    customize which input is fed to which replica/worker etc.
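
    A minimal TF 1.x sketch (the dataset contents, `global_batch_size`, and
    `replica_fn` below are placeholders):

    ```
    dataset = tf.data.Dataset.range(100).batch(global_batch_size)
    iterator = strategy.make_dataset_iterator(dataset)
    # In graph mode, run the op(s) returned by `initialize()` in a session.
    init_ops = iterator.initialize()
    replica_results = strategy.experimental_run(replica_fn, iterator)
    ```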
1892
1893    Args:
1894      dataset: `tf.data.Dataset` that will be distributed evenly across all
1895        replicas.
1896
1897    Returns:
1898      A `tf.distribute.InputIterator` which returns inputs for each step of the
1899      computation. The user should call `initialize` on the returned iterator.
1900    """
1901    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
1902
1903  def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
1904                             input_fn,
1905                             replication_mode=InputReplicationMode.PER_WORKER):
1906    """Returns an iterator split across replicas created from an input function.
1907
1908    DEPRECATED: This method is not available in TF 2.x.
1909
1910    The `input_fn` should take a `tf.distribute.InputContext` object where
1911    information about batching and input sharding can be accessed:
1912
1913    ```
1914    def input_fn(input_context):
1915      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
1916      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
1917      return d.shard(input_context.num_input_pipelines,
1918                     input_context.input_pipeline_id)
1919    with strategy.scope():
1920      iterator = strategy.make_input_fn_iterator(input_fn)
1921      replica_results = strategy.experimental_run(replica_fn, iterator)
1922    ```
1923
1924    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
1925    batch size, which may be computed using
1926    `input_context.get_per_replica_batch_size`.
1927
1928    Args:
1929      input_fn: A function taking a `tf.distribute.InputContext` object and
1930        returning a `tf.data.Dataset`.
1931      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
1932        Only `PER_WORKER` is supported currently, which means there will be
1933        a single call to `input_fn` per worker. Replicas will dequeue from the
1934        local `tf.data.Dataset` on their worker.
1935
1936    Returns:
1937      An iterator object that should first be `.initialize()`-ed. It may then
1938      either be passed to `strategy.experimental_run()` or you can
1939      `iterator.get_next()` to get the next value to pass to
1940      `strategy.extended.call_for_each_replica()`.
1941    """
1942    return super(StrategyV1, self).make_input_fn_iterator(
1943        input_fn, replication_mode)
1944
1945  def experimental_make_numpy_dataset(self, numpy_input, session=None):
1946    """Makes a tf.data.Dataset for input provided via a numpy array.
1947
1948    This avoids adding `numpy_input` as a large constant in the graph,
1949    and copies the data to the machine or machines that will be processing
1950    the input.
1951
1952    Note that you will likely need to use
1953    `tf.distribute.Strategy.experimental_distribute_dataset`
1954    with the returned dataset to further distribute it with the strategy.
1955
1956    Example:
1957    ```
1958    numpy_input = np.ones([10], dtype=np.float32)
1959    dataset = strategy.experimental_make_numpy_dataset(numpy_input)
1960    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1961    ```
1962
1963    Args:
1964      numpy_input: A nest of NumPy input arrays that will be converted into a
1965        dataset. Note that lists of NumPy arrays are stacked, as that is normal
1966        `tf.data.Dataset` behavior.
1967      session: (TensorFlow v1.x graph execution only) A session used for
1968        initialization.
1969
1970    Returns:
1971      A `tf.data.Dataset` representing `numpy_input`.
1972    """
1973    return self.extended.experimental_make_numpy_dataset(
1974        numpy_input, session=session)
1975
1976  @deprecated(
1977      None,
1978      "This method is not available in TF 2.x. Please switch to using `run` instead."
1979  )
1980  def experimental_run(self, fn, input_iterator=None):  # pylint: disable=useless-super-delegation
1981    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
1982
1983    DEPRECATED: This method is not available in TF 2.x. Please switch
1984    to using `run` instead.
1985
1986    When eager execution is enabled, executes ops specified by `fn` on each
1987    replica. Otherwise, builds a graph to execute the ops on each replica.
1988
1989    Each replica will take a single, different input from the inputs provided by
1990    one `get_next` call on the input iterator.
1991
1992    `fn` may call `tf.distribute.get_replica_context()` to access members such
1993    as `replica_id_in_sync_group`.
1994
1995    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
1996    used, and whether eager execution is enabled, `fn` may be called one or more
1997    times (once for each replica).
1998
1999    Args:
2000      fn: The function to run. The inputs to the function must match the outputs
2001        of `input_iterator.get_next()`. The output must be a `tf.nest` of
2002        `Tensor`s.
2003      input_iterator: (Optional) input iterator from which the inputs are taken.
2004
2005    Returns:
2006      Merged return value of `fn` across replicas. The structure of the return
2007      value is the same as the return value from `fn`. Each element in the
2008      structure can either be `PerReplica` (if the values are unsynchronized),
2009      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
2010      single replica).
2011    """
2012    return super(StrategyV1, self).experimental_run(
2013        fn, input_iterator)
2014
2015  def reduce(self, reduce_op, value, axis=None):
2016    return super(StrategyV1, self).reduce(reduce_op, value, axis)
2017
2018  reduce.__doc__ = StrategyBase.reduce.__doc__
2019
2020  def update_config_proto(self, config_proto):
2021    """Returns a copy of `config_proto` modified for use with this strategy.
2022
2023    DEPRECATED: This method is not available in TF 2.x.
2024
2025    The updated config contains anything needed to run the strategy, e.g.
2026    configuration to run collective ops, or device filters to improve
2027    distributed training performance.
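
    A minimal TF 1.x usage sketch, assuming a graph-mode session setup:

    ```
    session_config = tf.compat.v1.ConfigProto()
    session_config = strategy.update_config_proto(session_config)
    sess = tf.compat.v1.Session(config=session_config)
    ```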
2028
2029    Args:
2030      config_proto: a `tf.ConfigProto` object.
2031
2032    Returns:
2033      The updated copy of the `config_proto`.
2034    """
2035    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
2036
2037
2038# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
2039# instead descend from StrategyExtendedV1.
2040@tf_export("distribute.StrategyExtended", v1=[])
2041class StrategyExtendedV2(object):
2042  """Additional APIs for algorithms that need to be distribution-aware.
2043
2044  Note: For most usage of `tf.distribute.Strategy`, there should be no need to
2045  call these methods, since TensorFlow libraries (such as optimizers) already
2046  call these methods when needed on your behalf.
2047
2049  Some common use cases of functions on this page:
2050
2051  * _Locality_
2052
2053  `tf.distribute.DistributedValues` can have the same _locality_ as a
2054  _distributed variable_, which leads to a mirrored value residing on the same
2055  devices as the variable (as opposed to the compute devices). Such values may
2056  be passed to a call to `tf.distribute.StrategyExtended.update` to update the
2057  value of a variable. You may use
2058  `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
2059  same locality as another variable. You may convert a "PerReplica" value to a
2060  variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
2061  `tf.distribute.StrategyExtended.batch_reduce_to`.
2062
2063  * _How to update a distributed variable_
2064
2065  A distributed variable is a variable created on multiple devices. As discussed
2066  in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
2067  mirrored variables and SyncOnRead variables are two examples. The standard
2068  pattern for updating distributed variables is to:
2069
2070  1. In your function passed to `tf.distribute.Strategy.run`,
2071     compute a list of (update, variable) pairs. For example, the update might
2072     be a gradient of the loss with respect to the variable.
2073  2. Switch to cross-replica mode by calling
2074     `tf.distribute.get_replica_context().merge_call()` with the updates and
2075     variables as arguments.
2076  3. Call
2077     `tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v)`
2078     (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
2079     (for a list of variables) to sum the updates.
2080  4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
2081     its value.
2082
2083  Steps 2 through 4 are done automatically by `tf.keras.optimizers.Optimizer`
2084  if you call its `tf.keras.optimizers.Optimizer.apply_gradients` method in a
2085  replica context. A sketch of the manual pattern is shown below.
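
  A minimal sketch of steps 1 through 4 in a custom training step (the
  strategy, variable, and per-replica update below are placeholders):

  ```python
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  with strategy.scope():
    v = tf.Variable(1.0)

  @tf.function
  def step_fn():
    # Step 1: compute an update in replica context (e.g. a gradient).
    update = tf.identity(0.5)

    def merge_fn(strategy, update, var):
      # Step 3: sum the per-replica updates onto the variable's devices.
      reduced = strategy.extended.reduce_to(
          tf.distribute.ReduceOp.SUM, update, destinations=var)
      # Step 4: apply the reduced update to every component of the variable.
      strategy.extended.update(
          var, lambda var, u: var.assign_sub(u), args=(reduced,))

    # Step 2: switch to cross-replica context with the updates and variables.
    tf.distribute.get_replica_context().merge_call(merge_fn, args=(update, v))

  strategy.run(step_fn)
  ```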
2086
2087  In fact, a higher-level way to update a distributed variable is to call
2088  `assign` on the variable as you would with a regular `tf.Variable`.
2089  You can call the method in both _replica context_ and _cross-replica context_.
2090  For a _mirrored variable_, calling `assign` in _replica context_ requires you
2091  to specify the `aggregation` type in the variable constructor. In that case,
2092  the context switching and sync described in steps 2 through 4 are handled for
2093  you. If you call `assign` on a _mirrored variable_ in _cross-replica context_,
2094  you can only assign a single value or assign values from another mirrored
2095  variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead
2096  variable_, in _replica context_, you can simply call `assign` on it and no
2097  aggregation happens under the hood. In _cross-replica context_, you can only
2098  assign a single value to a SyncOnRead variable. One example case is restoring
2099  from a checkpoint: if the `aggregation` type of the variable is
2100  `tf.VariableAggregation.SUM`, it is assumed that replica values were added
2101  before checkpointing, so at the time of restoring, the value is divided by
2102  the number of replicas and then assigned to each replica; if the `aggregation`
2103  type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica
2104  directly.
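
  A minimal sketch of the `assign`-based approach in replica context (the
  strategy and values below are placeholders):

  ```python
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  with strategy.scope():
    # The `aggregation` argument tells tf.distribute how to combine the
    # per-replica updates made in replica context.
    acc = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

  @tf.function
  def step_fn():
    acc.assign_add(1.0)

  strategy.run(step_fn)
  # With two replicas and SUM aggregation, `acc` now holds 2.0.
  ```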
2105
2106  """
2107
2108  def __init__(self, container_strategy):
2109    self._container_strategy_weakref = weakref.ref(container_strategy)
2110    self._default_device = None
2111    # This property is used to determine if we should set drop_remainder=True
2112    # when creating Datasets from numpy array inputs.
2113    self._require_static_shapes = False
2114
2115  def _resource_creator_scope(self):
2116    """Returns one or a list of ops.resource_creator_scope for some Strategy."""
2117    return None
2118
2119  def _container_strategy(self):
2120    """Get the containing `tf.distribute.Strategy`.
2121
2122    This should not generally be needed except when creating a new
2123    `ReplicaContext` or validating that the caller is in the correct
2124    `scope()`.
2125
2126    Returns:
2127      The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
2128    """
2129    container_strategy = self._container_strategy_weakref()
2130    assert container_strategy is not None
2131    return container_strategy
2132
2133  def _scope(self, strategy):
2134    """Implementation of tf.distribute.Strategy.scope()."""
2135
2136    def creator_with_resource_vars(next_creator, **kwargs):
2137      """Variable creator to use in `_CurrentDistributionContext`."""
2138      _require_strategy_scope_extended(self)
2139      kwargs["use_resource"] = True
2140      kwargs["distribute_strategy"] = strategy
2141
2142      # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
2143      # dereferencing a `Tensor` that has no `name`. We still need to
2144      # propagate the metadata it's holding.
2145      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
2146        checkpoint_restore_uid = kwargs[
2147            "initial_value"].checkpoint_position.restore_uid
2148        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
2149      elif isinstance(kwargs["initial_value"],
2150                      trackable.CheckpointInitialValueCallable):
2151        checkpoint_restore_uid = kwargs[
2152            "initial_value"].checkpoint_position.restore_uid
2153      elif (isinstance(kwargs["initial_value"], functools.partial) and
2154            isinstance(kwargs["initial_value"].func,
2155                       trackable.CheckpointInitialValueCallable)):
2156        # Some libraries (e.g., Keras) create a partial function out of an initializer
2157        # to bind shape/dtype, for example:
2158        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
2159        # Therefore to get the restore_uid we need to examine the "func" of
2160        # the partial function.
2161        checkpoint_restore_uid = kwargs[
2162            "initial_value"].func.checkpoint_position.restore_uid
2163      else:
2164        checkpoint_restore_uid = None
2165
2166      created = self._create_variable(next_creator, **kwargs)
2167
2168      if checkpoint_restore_uid is not None:
2169        # pylint: disable=protected-access
2170        # Let the checkpointing infrastructure know that the variable was
2171        # already restored so it doesn't waste memory loading the value again.
2172        # In the case of CheckpointInitialValueCallable this may already be
2173        # done by the final variable creator, but it doesn't hurt to do it
2174        # again.
2175        created._maybe_initialize_trackable()
2176        created._update_uid = checkpoint_restore_uid
2177        # pylint: enable=protected-access
2178      return created
2179
2180    def distributed_getter(getter, *args, **kwargs):
2181      if not self._allow_variable_partition():
2182        if kwargs.pop("partitioner", None) is not None:
2183          tf_logging.log_first_n(
2184              tf_logging.WARN, "Partitioned variables are disabled when using "
2185              "current tf.distribute.Strategy.", 1)
2186      return getter(*args, **kwargs)
2187
2188    return _CurrentDistributionContext(
2189        strategy,
2190        variable_scope.variable_creator_scope(creator_with_resource_vars),
2191        variable_scope.variable_scope(
2192            variable_scope.get_variable_scope(),
2193            custom_getter=distributed_getter),
2194        strategy.extended._resource_creator_scope(),  # pylint: disable=protected-access
2195        self._default_device)
2196
2197  def _allow_variable_partition(self):
2198    return False
2199
2200  def _create_variable(self, next_creator, **kwargs):
2201    # Note: should support "colocate_with" argument.
2202    raise NotImplementedError("must be implemented in descendants")
2203
2204  def variable_created_in_scope(self, v):
2205    """Tests whether `v` was created while this strategy scope was active.
2206
2207    Variables created inside the strategy scope are "owned" by it:
2208
2209    >>> strategy = tf.distribute.MirroredStrategy()
2210    >>> with strategy.scope():
2211    ...   v = tf.Variable(1.)
2212    >>> strategy.extended.variable_created_in_scope(v)
2213    True
2214
2215    Variables created outside the strategy are not owned by it:
2216
2217    >>> strategy = tf.distribute.MirroredStrategy()
2218    >>> v = tf.Variable(1.)
2219    >>> strategy.extended.variable_created_in_scope(v)
2220    False
2221
2222    Args:
2223      v: A `tf.Variable` instance.
2224
2225    Returns:
2226      True if `v` was created inside the scope, False if not.
2227    """
2228    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access
2229
2230  def colocate_vars_with(self, colocate_with_variable):
2231    """Scope that controls which devices variables will be created on.
2232
2233    No operations should be added to the graph inside this scope; it
2234    should only be used when creating variables (some implementations
2235    work by changing variable creation, others work by using a
2236    `tf.compat.v1.colocate_with()` scope).
2237
2238    This may only be used inside `self.scope()`.
2239
2240    Example usage:
2241
2242    ```
2243    with strategy.scope():
2244      var1 = tf.Variable(...)
2245      with strategy.extended.colocate_vars_with(var1):
2246        # var2 and var3 will be created on the same device(s) as var1
2247        var2 = tf.Variable(...)
2248        var3 = tf.Variable(...)
2249
2250      def fn(v1, v2, v3):
2251        # operates on v1 from var1, v2 from var2, and v3 from var3
2252
2253      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
2254      # too.
2255      strategy.extended.update(var1, fn, args=(var2, var3))
2256    ```
2257
2258    Args:
2259      colocate_with_variable: A variable created in this strategy's `scope()`.
2260        Variables created while in the returned context manager will be on the
2261        same set of devices as `colocate_with_variable`.
2262
2263    Returns:
2264      A context manager.
2265    """
2266
2267    def create_colocated_variable(next_creator, **kwargs):
2268      _require_strategy_scope_extended(self)
2269      kwargs["use_resource"] = True
2270      kwargs["colocate_with"] = colocate_with_variable
2271      return next_creator(**kwargs)
2272
2273    _require_strategy_scope_extended(self)
2274    self._validate_colocate_with_variable(colocate_with_variable)
2275    return variable_scope.variable_creator_scope(create_colocated_variable)
2276
2277  def _validate_colocate_with_variable(self, colocate_with_variable):
2278    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
2279    pass
2280
2281  def _make_dataset_iterator(self, dataset):
2282    raise NotImplementedError("must be implemented in descendants")
2283
2284  def _make_input_fn_iterator(self, input_fn, replication_mode):
2285    raise NotImplementedError("must be implemented in descendants")
2286
2287  def _experimental_distribute_dataset(self, dataset, options):
2288    raise NotImplementedError("must be implemented in descendants")
2289
2290  def _distribute_datasets_from_function(self, dataset_fn, options):
2291    raise NotImplementedError("must be implemented in descendants")
2292
2293  def _experimental_distribute_values_from_function(self, value_fn):
2294    raise NotImplementedError("must be implemented in descendants")
2295
2296  def _reduce(self, reduce_op, value):
2297    # Default implementation until we have an implementation for each strategy.
2298    dst = device_util.current() or self._default_device or "/device:CPU:0"
2299    return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
2300
2301  def reduce_to(self, reduce_op, value, destinations, options=None):
2302    """Combine (via e.g. sum or mean) values across replicas.
2303
2304    `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
2305    variables. It supports both dense values and `tf.IndexedSlices`.
2306
2307    This API currently can only be called in cross-replica context. Other
2308    variants to reduce values across replicas are:
2309    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
2310      this API.
2311    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2312      in replica context. It supports both batched and non-batched all-reduce.
2313    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2314      to the host in cross-replica context.
2315
2316    `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
2317    also pass in a `Tensor`, and the destinations will be the device of that
2318    tensor. For all-reduce, pass the same to `value` and `destinations`.
2319
2320    It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
2321    that works for all `tf.distribute.Strategy` implementations.
2322
2323    >>> @tf.function
2324    ... def step_fn(var):
2325    ...
2326    ...   def merge_fn(strategy, value, var):
2327    ...     # All-reduce the value. Note that `value` here is a
2328    ...     # `tf.distribute.DistributedValues`.
2329    ...     reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
2330    ...         value, destinations=var)
2331    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2332    ...         args=(reduced,))
2333    ...
2334    ...   value = tf.identity(1.)
2335    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2336    ...     args=(value, var))
2337    >>>
2338    >>> def run(strategy):
2339    ...   with strategy.scope():
2340    ...     v = tf.Variable(0.)
2341    ...     strategy.run(step_fn, args=(v,))
2342    ...     return v
2343    >>>
2344    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2345    MirroredVariable:{
2346      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2347      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2348    }
2349    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2350    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2351    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2352    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2353    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2354
2355    Args:
2356      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2357        be combined. Allows using string representation of the enum such as
2358        "SUM", "MEAN".
2359      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
2360      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2361        `tf.Tensor`-like object, or a device string. It specifies the devices
2362        to reduce to. To perform an all-reduce, pass the same to `value` and
2363        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
2364        to the devices of that variable, and this method doesn't update the
2365        variable.
2366      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2367        perform collective operations. This overrides the default options if the
2368        `tf.distribute.Strategy` takes one in the constructor. See
2369        `tf.distribute.experimental.CommunicationOptions` for details of the
2370        options.
2371
2372    Returns:
2373      A tensor or value reduced to `destinations`.
2374    """
2375    if options is None:
2376      options = collective_util.Options()
2377    _require_cross_replica_or_default_context_extended(self)
2378    assert not isinstance(destinations, (list, tuple))
2379    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2380    if isinstance(reduce_op, six.string_types):
2381      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2382    assert (reduce_op == reduce_util.ReduceOp.SUM or
2383            reduce_op == reduce_util.ReduceOp.MEAN)
2384    return self._reduce_to(reduce_op, value, destinations, options)
2385
2386  def _reduce_to(self, reduce_op, value, destinations, options):
2387    raise NotImplementedError("must be implemented in descendants")
2388
2389  def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
2390    """Combine multiple `reduce_to` calls into one for faster execution.
2391
2392    Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
    It's more efficient than reducing each value separately.
2394
2395    This API currently can only be called in cross-replica context. Other
2396    variants to reduce values across replicas are:
2397    * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
2398      this API.
2399    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2400      in replica context. It supports both batched and non-batched all-reduce.
2401    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2402      to the host in cross-replica context.
2403
2404    See `reduce_to` for more information.
2405
2406    >>> @tf.function
2407    ... def step_fn(var):
2408    ...
2409    ...   def merge_fn(strategy, value, var):
2410    ...     # All-reduce the value. Note that `value` here is a
2411    ...     # `tf.distribute.DistributedValues`.
2412    ...     reduced = strategy.extended.batch_reduce_to(
2413    ...         tf.distribute.ReduceOp.SUM, [(value, var)])[0]
2414    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2415    ...         args=(reduced,))
2416    ...
2417    ...   value = tf.identity(1.)
2418    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2419    ...     args=(value, var))
2420    >>>
2421    >>> def run(strategy):
2422    ...   with strategy.scope():
2423    ...     v = tf.Variable(0.)
2424    ...     strategy.run(step_fn, args=(v,))
2425    ...     return v
2426    >>>
2427    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2428    MirroredVariable:{
2429      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2430      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2431    }
2432    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2433    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2434    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2435    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2436    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2437
2438    Args:
2439      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2440        be combined. Allows using string representation of the enum such as
2441        "SUM", "MEAN".
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.StrategyExtended.reduce_to` for descriptions.
2444      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2445        perform collective operations. This overrides the default options if the
2446        `tf.distribute.Strategy` takes one in the constructor. See
2447        `tf.distribute.experimental.CommunicationOptions` for details of the
2448        options.
2449
2450    Returns:
2451      A list of reduced values, one per pair in `value_destination_pairs`.
2452    """
2453    if options is None:
2454      options = collective_util.Options()
2455    _require_cross_replica_or_default_context_extended(self)
2456    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2457    if isinstance(reduce_op, six.string_types):
2458      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2459    return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
2460
2461  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
2462    return [
2463        self.reduce_to(reduce_op, t, destinations=v, options=options)
2464        for t, v in value_destination_pairs
2465    ]
2466
2467  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
2468    """All-reduce `value` across all replicas so that all get the final result.
2469
2470    If `value` is a nested structure of tensors, all-reduces of these tensors
2471    will be batched when possible. `options` can be set to hint the batching
2472    behavior.
2473
2474    This API must be called in a replica context.
2475
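    For illustration, a minimal sketch (this is not a public API; the example
    is hypothetical and assumes a `MirroredStrategy` with two GPUs):

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      # Each replica contributes a nested structure; the all-reduces of the
      # flattened tensors are batched where possible.
      value = {"a": tf.identity(1.), "b": tf.identity(2.)}
      return strategy.extended._replica_ctx_all_reduce(
          tf.distribute.ReduceOp.SUM, value)

    # With two replicas, every replica receives {"a": 2.0, "b": 4.0}.
    strategy.run(step_fn)
    ```
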
2476    Args:
2477      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
2478        be combined.
2479      value: Value to be reduced. A tensor or a nested structure of tensors.
2480      options: A `tf.distribute.experimental.CommunicationOptions`. Options to
2481        perform collective operations. This overrides the default options if the
2482        `tf.distribute.Strategy` takes one in the constructor.
2483
2484    Returns:
      A tensor or a nested structure of tensors with the reduced values. The
2486      structure is the same as `value`.
2487    """
2488    if options is None:
2489      options = collective_util.Options()
2490    replica_context = distribution_strategy_context.get_replica_context()
2491    assert replica_context, (
2492        "`StrategyExtended._replica_ctx_all_reduce` must be called in"
2493        " a replica context")
2494
2495    def merge_fn(_, flat_value):
2496      return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
2497                                  options)
2498
2499    reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
2500    return nest.pack_sequence_as(value, reduced)
2501
2502  def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True):
2503    """Run `fn` with `args` and `kwargs` to update `var`."""
    # This method is called by ReplicaContext.update. Strategies that would
    # like to remove merge_call in this path should override this method.
2506    replica_context = distribution_strategy_context.get_replica_context()
2507    if not replica_context:
2508      raise ValueError("`StrategyExtended._replica_ctx_update` must be called "
2509                       "in a replica context.")
2510
2511    def merge_fn(_, *merged_args, **merged_kwargs):
2512      return self.update(var, fn, merged_args, merged_kwargs, group=group)
2513
2514    return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
2515
2516  def _gather_to(self, value, destinations, axis, options=None):
    """Gather `value` across replicas along the `axis`-th dimension to `destinations`.

    `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like
    objects along the `axis`-th dimension. It supports only dense tensors, not
    sparse tensors. This API can only be called in a cross-replica context.
2522
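    For illustration, a minimal sketch (this is not a public API; it assumes a
    `MirroredStrategy` with two GPUs, and gathers each replica's tensor onto
    the host CPU from within a `merge_call`):

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      value = tf.constant([1, 2, 3])

      def merge_fn(strategy, per_replica_value):
        # Gather the per-replica tensors along axis 0 onto the CPU.
        return strategy.extended._gather_to(
            per_replica_value, destinations="CPU:0", axis=0)

      ctx = tf.distribute.get_replica_context()
      return ctx.merge_call(merge_fn, args=(value,))

    # The gathered result is [1, 2, 3, 1, 2, 3] on the CPU.
    strategy.run(step_fn)
    ```
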
2523    Args:
      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
2531      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
2532        range [0, rank(value)).
2533      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2534        perform collective operations. This overrides the default options if the
2535        `tf.distribute.Strategy` takes one in the constructor. See
2536        `tf.distribute.experimental.CommunicationOptions` for details of the
2537        options.
2538
2539    Returns:
2540      A tensor or value gathered to `destinations`.
2541    """
2542    _require_cross_replica_or_default_context_extended(self)
2543    assert not isinstance(destinations, (list, tuple))
2544    if options is None:
2545      options = collective_util.Options()
2546    return self._gather_to_implementation(value, destinations, axis, options)
2547
2548  def _gather_to_implementation(self, value, destinations, axis, options):
2549    raise NotImplementedError("_gather_to must be implemented in descendants")
2550
2551  def _batch_gather_to(self, value_destination_pairs, axis, options=None):
2552    _require_cross_replica_or_default_context_extended(self)
2553    if options is None:
2554      options = collective_util.Options()
2555    return [
2556        self._gather_to(t, destinations=v, axis=axis, options=options)
2557        for t, v in value_destination_pairs
2558    ]
2559
2560  def update(self, var, fn, args=(), kwargs=None, group=True):
2561    """Run `fn` to update `var` using inputs mirrored to the same devices.
2562
2563    `tf.distribute.StrategyExtended.update` takes a distributed variable `var`
2564    to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
2565    applies `fn` to each component variable of `var` and passes corresponding
2566    values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
2567    per-replica values. If they contain mirrored values, they will be unwrapped
2568    before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
2569    a mirrored DistributedValues where each component contains the value to be
2570    added to this mirrored variable `var`. Calling `update` will call
2571    `assign_add` on each component variable of `var` with the corresponding
2572    tensor value on that device.
2573
2574    Example usage:
2575
2576    ```python
    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
2579    with strategy.scope():
2580      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
2581    def update_fn(v):
2582      return v.assign(1.0)
2583    result = strategy.extended.update(v, update_fn)
2584    # result is
2585    # Mirrored:{
2586    #  0: tf.Tensor(1.0, shape=(), dtype=float32),
2587    #  1: tf.Tensor(1.0, shape=(), dtype=float32)
2588    # }
2589    ```
2590
    If `var` is mirrored across multiple devices, then this method implements
    the following logic:
2593
2594    ```python
2595    results = {}
2596    for device, v in var:
2597      with tf.device(device):
2598        # args and kwargs will be unwrapped if they are mirrored.
2599        results[device] = fn(v, *args, **kwargs)
2600    return merged(results)
2601    ```
2602
2603    Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
2604    `var`.
2605
2606    Args:
2607      var: Variable, possibly mirrored to multiple devices, to operate on.
2608      fn: Function to call. Should take the variable as the first argument.
2609      args: Tuple or list. Additional positional arguments to pass to `fn()`.
2610      kwargs: Dict with keyword arguments to pass to `fn()`.
2611      group: Boolean. Defaults to True. If False, the return value will be
2612        unwrapped.
2613
2614    Returns:
2615      By default, the merged return value of `fn` across all replicas.  The
2616      merged result has dependencies to make sure that if it is evaluated at
2617      all, the side effects (updates) will happen on every replica. If instead
2618      "group=False" is specified, this function will return a nest of lists
2619      where each list has an element per replica, and the caller is responsible
2620      for ensuring all elements are executed.
2621    """
    # TODO(b/178944108): Update the documentation to reflect the fact that
2623    # `update` can be called in a replica context.
2624    if kwargs is None:
2625      kwargs = {}
2626    replica_context = distribution_strategy_context.get_replica_context()
2627    # pylint: disable=protected-access
2628    if (replica_context is None or replica_context is
2629        distribution_strategy_context._get_default_replica_context()):
2630      fn = autograph.tf_convert(
2631          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2632      with self._container_strategy().scope():
2633        return self._update(var, fn, args, kwargs, group)
2634    else:
2635      return self._replica_ctx_update(
2636          var, fn, args=args, kwargs=kwargs, group=group)
2637
2638  def _update(self, var, fn, args, kwargs, group):
2639    raise NotImplementedError("must be implemented in descendants")
2640
2641  def _local_results(self, val):
2642    """Returns local results per replica as a tuple."""
2643    if isinstance(val, values.DistributedValues):
2644      return val._values  # pylint: disable=protected-access
2645
2646    if nest.is_nested(val):
2647      replica_values = []
2648
2649      def get_values(x, index):
2650        if isinstance(x, values.DistributedValues):
2651          return x._values[index]  # pylint: disable=protected-access
2652        return x
2653
2654      for i in range(len(self.worker_devices)):
2655        replica_values.append(
2656            nest.map_structure(
2657                lambda x: get_values(x, i),  # pylint: disable=cell-var-from-loop
2658                val))
2659      return tuple(replica_values)
2660    return (val,)
2661
2662  def value_container(self, value):
2663    """Returns the container that this per-replica `value` belongs to.
2664
2665    Args:
2666      value: A value returned by `run()` or a variable created in `scope()`.
2667
2668    Returns:
2669      A container that `value` belongs to.
2670      If value does not belong to any container (including the case of
2671      container having been destroyed), returns the value itself.
2672      `value in experimental_local_results(value_container(value))` will
2673      always be true.
2674    """
2675    raise NotImplementedError("must be implemented in descendants")
2676
2677  def _group(self, value, name=None):
2678    """Implementation of `group`."""
2679    value = nest.flatten(self._local_results(value))
2680
2681    if len(value) != 1 or name is not None:
2682      return control_flow_ops.group(value, name=name)
2683    # Special handling for the common case of one op.
2684    v, = value
2685    if hasattr(v, "op"):
2686      v = v.op
2687    return v
2688
2689  @property
2690  def experimental_require_static_shapes(self):
2691    """Returns `True` if static shape is required; `False` otherwise."""
2692    return self._require_static_shapes
2693
2694  @property
2695  def _num_replicas_in_sync(self):
2696    """Returns number of replicas over which gradients are aggregated."""
2697    raise NotImplementedError("must be implemented in descendants")
2698
2699  @property
2700  def worker_devices(self):
    """Returns the tuple of all devices used for compute replica execution.
2702    """
2703    # TODO(josh11b): More docstring
2704    raise NotImplementedError("must be implemented in descendants")
2705
2706  @property
2707  def parameter_devices(self):
2708    """Returns the tuple of all devices used to place variables."""
2709    # TODO(josh11b): More docstring
2710    raise NotImplementedError("must be implemented in descendants")
2711
2712  def _configure(self,
2713                 session_config=None,
2714                 cluster_spec=None,
2715                 task_type=None,
2716                 task_id=None):
2717    """Configures the strategy class."""
2718    del session_config, cluster_spec, task_type, task_id
2719
2720  def _update_config_proto(self, config_proto):
2721    return copy.deepcopy(config_proto)
2722
2723  def _in_multi_worker_mode(self):
2724    """Whether this strategy indicates working in multi-worker settings.
2725
2726    Multi-worker training refers to the setup where the training is
2727    distributed across multiple workers, as opposed to the case where
    only a local process performs the training. This function is
    used by higher-level APIs such as Keras' `model.fit()` to infer,
    for example, whether or not a distribute coordinator should be run
    (and thus whether TensorFlow servers should be started for communication
    with other servers in the cluster), or whether or not saving/restoring
    checkpoints is relevant for preemption fault tolerance.
2734
    Subclasses should override this to indicate whether the strategy is
    currently in a multi-worker setup.
2737
2738    Experimental. Signature and implementation are subject to change.
2739    """
2740    raise NotImplementedError("must be implemented in descendants")
2741
2742
2743@tf_export(v1=["distribute.StrategyExtended"])  # pylint: disable=missing-docstring
2744class StrategyExtendedV1(StrategyExtendedV2):
2745
2746  __doc__ = StrategyExtendedV2.__doc__
2747
2748  def experimental_make_numpy_dataset(self, numpy_input, session=None):
2749    """Makes a dataset for input provided via a numpy array.
2750
2751    This avoids adding `numpy_input` as a large constant in the graph,
2752    and copies the data to the machine or machines that will be processing
2753    the input.
2754
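    For example, a minimal sketch (`np_features`, `np_labels` and `sess` are
    hypothetical; `sess` is only needed for TF 1.x graph execution, and
    `strategy` is an already constructed TF 1.x `tf.distribute.Strategy`):

    ```python
    import numpy as np

    np_features = np.random.rand(64, 10).astype(np.float32)
    np_labels = np.random.randint(0, 2, size=(64, 1)).astype(np.float32)

    # Build a dataset from in-memory NumPy arrays without embedding them as
    # large constants in the graph.
    dataset = strategy.extended.experimental_make_numpy_dataset(
        (np_features, np_labels), session=sess)
    dataset = dataset.shuffle(64).batch(8)
    iterator = strategy.make_dataset_iterator(dataset)
    ```
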
2755    Args:
2756      numpy_input: A nest of NumPy input arrays that will be distributed evenly
        across all replicas. Note that lists of NumPy arrays are stacked, as
2758        that is normal `tf.data.Dataset` behavior.
2759      session: (TensorFlow v1.x graph execution only) A session used for
2760        initialization.
2761
2762    Returns:
2763      A `tf.data.Dataset` representing `numpy_input`.
2764    """
2765    _require_cross_replica_or_default_context_extended(self)
2766    return self._experimental_make_numpy_dataset(numpy_input, session=session)
2767
2768  def _experimental_make_numpy_dataset(self, numpy_input, session):
2769    raise NotImplementedError("must be implemented in descendants")
2770
2771  def broadcast_to(self, tensor, destinations):
2772    """Mirror a tensor on one device to all worker devices.
2773
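    For example, a minimal sketch (assuming a `MirroredStrategy` with two GPUs
    and a mirrored variable `v`; the broadcast result is then used to update
    `v` via `update`):

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    with strategy.scope():
      v = tf.Variable(0.)
      # Cross-replica context: mirror a single host tensor onto the devices
      # holding the components of `v`, then assign it to each component.
      value = tf.constant(3.)
      mirrored = strategy.extended.broadcast_to(value, destinations=v)
      strategy.extended.update(v, lambda var, t: var.assign(t), args=(mirrored,))
    ```
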
2774    Args:
2775      tensor: A Tensor value to broadcast.
2776      destinations: A mirrored variable or device string specifying the
2777        destination devices to copy `tensor` to.
2778
2779    Returns:
2780      A value mirrored to `destinations` devices.
2781    """
2782    assert destinations is not None  # from old strategy.broadcast()
2783    # TODO(josh11b): More docstring
2784    _require_cross_replica_or_default_context_extended(self)
2785    assert not isinstance(destinations, (list, tuple))
2786    return self._broadcast_to(tensor, destinations)
2787
2788  def _broadcast_to(self, tensor, destinations):
2789    raise NotImplementedError("must be implemented in descendants")
2790
2791  @deprecated(None, "please use `run` instead.")
2792  def experimental_run_steps_on_iterator(self,
2793                                         fn,
2794                                         iterator,
2795                                         iterations=1,
2796                                         initial_loop_values=None):
2797    """DEPRECATED: please use `run` instead.
2798
2799    Run `fn` with input from `iterator` for `iterations` times.
2800
2801    This method can be used to run a step function for training a number of
2802    times using input from a dataset.
2803
2804    Args:
2805      fn: function to run using this distribution strategy. The function must
2806        have the following signature: `def fn(context, inputs)`. `context` is an
2807          instance of `MultiStepContext` that will be passed when `fn` is run.
2808          `context` can be used to specify the outputs to be returned from `fn`
2809          by calling `context.set_last_step_output`. It can also be used to
          capture non-tensor outputs by `context.set_non_tensor_output`. See
2811          `MultiStepContext` documentation for more information. `inputs` will
2812          have same type/structure as `iterator.get_next()`. Typically, `fn`
2813          will use `call_for_each_replica` method of the strategy to distribute
2814          the computation over multiple replicas.
2815      iterator: Iterator of a dataset that represents the input for `fn`. The
2816        caller is responsible for initializing the iterator as needed.
2817      iterations: (Optional) Number of iterations that `fn` should be run.
2818        Defaults to 1.
2819      initial_loop_values: (Optional) Initial values to be passed into the
2820        loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
2821          initial_loop_values argument when we have a mechanism to infer the
2822          outputs of `fn`.
2823
2824    Returns:
2825      Returns the `MultiStepContext` object which has the following properties,
2826      among other things:
2827        - run_op: An op that runs `fn` `iterations` times.
2828        - last_step_outputs: A dictionary containing tensors set using
2829        `context.set_last_step_output`. Evaluating this returns the value of
2830        the tensors after the last iteration.
2831        - non_tensor_outputs: A dictionary containing anything that was set by
2832          `fn` by calling `context.set_non_tensor_output`.
2833    """
2834    _require_cross_replica_or_default_context_extended(self)
2835    with self._container_strategy().scope():
2836      return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
2837                                                      initial_loop_values)
2838
2839  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
2840                                          initial_loop_values):
2841    raise NotImplementedError("must be implemented in descendants")
2842
2843  def call_for_each_replica(self, fn, args=(), kwargs=None):
2844    """Run `fn` once per replica.
2845
    `fn` may call `tf.distribute.get_replica_context()` to access methods such as
2847    `replica_id_in_sync_group` and `merge_call()`.
2848
    `merge_call()` is used to communicate between the replicas and
    re-enter the cross-replica context. All replicas pause their execution
    when they encounter a `merge_call()` call. The `merge_fn` function is then
    executed. Its results are unwrapped and given back to each replica call,
    after which execution resumes until `fn` is complete or encounters another
    `merge_call()`.  Example:
2855
2856    ```python
2857    # Called once in "cross-replica" context.
2858    def merge_fn(distribution, three_plus_replica_id):
2859      # sum the values across replicas
2860      return sum(distribution.experimental_local_results(three_plus_replica_id))
2861
2862    # Called once per replica in `distribution`, in a "replica" context.
2863    def fn(three):
      replica_ctx = tf.distribute.get_replica_context()
2865      v = three + replica_ctx.replica_id_in_sync_group
2866      # Computes the sum of the `v` values across all replicas.
2867      s = replica_ctx.merge_call(merge_fn, args=(v,))
2868      return s + v
2869
2870    with distribution.scope():
2871      # in "cross-replica" context
2872      ...
2873      merged_results = distribution.run(fn, args=[3])
2874      # merged_results has the values from every replica execution of `fn`.
2875      # This statement prints a list:
2876      print(distribution.experimental_local_results(merged_results))
2877    ```
2878
2879    Args:
2880      fn: function to run (will be run once per replica).
2881      args: Tuple or list with positional arguments for `fn`.
2882      kwargs: Dict with keyword arguments for `fn`.
2883
2884    Returns:
2885      Merged return value of `fn` across all replicas.
2886    """
2887    _require_cross_replica_or_default_context_extended(self)
2888    if kwargs is None:
2889      kwargs = {}
2890    with self._container_strategy().scope():
2891      return self._call_for_each_replica(fn, args, kwargs)
2892
2893  def _call_for_each_replica(self, fn, args, kwargs):
2894    raise NotImplementedError("must be implemented in descendants")
2895
2896  def read_var(self, v):
2897    """Reads the value of a variable.
2898
2899    Returns the aggregate value of a replica-local variable, or the
2900    (read-only) value of any other variable.
2901
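    For example, a minimal sketch (assuming a sync-on-read variable created
    under a `MirroredStrategy` with two GPUs):

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    with strategy.scope():
      v = tf.Variable(
          0.,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)

    def step_fn():
      v.assign_add(1.)  # Each replica updates its own local component.

    strategy.run(step_fn)
    # Aggregates the per-replica components (1.0 + 1.0), returning 2.0.
    total = strategy.extended.read_var(v)
    ```
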
2902    Args:
2903      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
2904
2905    Returns:
2906      A tensor representing the value of `v`, aggregated across replicas if
2907      necessary.
2908    """
2909    raise NotImplementedError("must be implemented in descendants")
2910
2911  def update_non_slot(
2912      self, colocate_with, fn, args=(), kwargs=None, group=True):
2913    """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
2914
2915    Used to update non-slot variables.
2916
2917    DEPRECATED: TF 1.x ONLY.
2918
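    For example, a minimal sketch (TF 1.x only; `global_step` is a hypothetical
    non-slot variable, and this is assumed to run in a cross-replica context,
    e.g. inside `strategy.scope()`):

    ```python
    def increment_step(step):
      return step.assign_add(1)

    # Run the update on the devices that hold the non-slot variables.
    devices = strategy.extended.non_slot_devices([global_step])
    update_op = strategy.extended.update_non_slot(
        devices, increment_step, args=(global_step,))
    ```
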
2919    Args:
2920      colocate_with: Devices returned by `non_slot_devices()`.
2921      fn: Function to execute.
2922      args: Tuple or list. Positional arguments to pass to `fn()`.
2923      kwargs: Dict with keyword arguments to pass to `fn()`.
2924      group: Boolean. Defaults to True. If False, the return value will be
2925        unwrapped.
2926
2927    Returns:
2928      Return value of `fn`, possibly merged across devices.
2929    """
2930    _require_cross_replica_or_default_context_extended(self)
2931    if kwargs is None:
2932      kwargs = {}
2933    fn = autograph.tf_convert(
2934        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2935    with self._container_strategy().scope():
2936      return self._update_non_slot(colocate_with, fn, args, kwargs, group)
2937
2938  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
2939    raise NotImplementedError("must be implemented in descendants")
2940
2941  def non_slot_devices(self, var_list):
2942    """Device(s) for non-slot variables.
2943
2944    DEPRECATED: TF 1.x ONLY.
2945
2946    This method returns non-slot devices where non-slot variables are placed.
2947    Users can create non-slot variables on these devices by using a block:
2948
2949    ```python
2950    with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)):
2951      ...
2952    ```
2953
2954    Args:
2955      var_list: The list of variables being optimized, needed with the
2956        default `tf.distribute.Strategy`.
2957    Returns:
2958      A sequence of devices for non-slot variables.
2959    """
2960    raise NotImplementedError("must be implemented in descendants")
2961
2962  def _use_merge_call(self):
2963    """Whether to use merge-calls inside the distributed strategy."""
2964    return True
2965
2966  @property
2967  def experimental_between_graph(self):
2968    """Whether the strategy uses between-graph replication or not.
2969
      This is expected to return a constant value that will not change
      throughout its life cycle.
2972    """
2973    raise NotImplementedError("must be implemented in descendants")
2974
2975  @property
2976  def experimental_should_init(self):
2977    """Whether initialization is needed."""
2978    raise NotImplementedError("must be implemented in descendants")
2979
2980  @property
2981  def should_checkpoint(self):
2982    """Whether checkpointing is needed."""
2983    raise NotImplementedError("must be implemented in descendants")
2984
2985  @property
2986  def should_save_summary(self):
2987    """Whether saving summaries is needed."""
2988    raise NotImplementedError("must be implemented in descendants")
2989
2990
2991# A note about the difference between the context managers
2992# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
2993# (defined above) used by `tf.distribute.Strategy.scope()`:
2994#
2995# * a ReplicaContext is only present during a `run()`
2996#   call (except during a `merge_run` call) and in such a scope it
2997#   will be returned by calls to `get_replica_context()`.  Implementers of new
2998#   Strategy descendants will frequently also need to
2999#   define a descendant of ReplicaContext, and are responsible for
3000#   entering and exiting this context.
3001#
3002# * Strategy.scope() sets up a variable_creator scope that
3003#   changes variable creation calls (e.g. to make mirrored
3004#   variables). This is intended as an outer scope that users enter once
3005#   around their model creation and graph definition. There is no
3006#   anticipated need to define descendants of _CurrentDistributionContext.
3007#   It sets the current Strategy for purposes of
3008#   `get_strategy()` and `has_strategy()`
3009#   and switches the thread mode to a "cross-replica context".
3010class ReplicaContextBase(object):
3011  """A class with a collection of APIs that can be called in a replica context.
3012
3013  You can use `tf.distribute.get_replica_context` to get an instance of
3014  `ReplicaContext`, which can only be called inside the function passed to
3015  `tf.distribute.Strategy.run`.
3016
3017  >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
3018  >>> def func():
3019  ...   replica_context = tf.distribute.get_replica_context()
3020  ...   return replica_context.replica_id_in_sync_group
3021  >>> strategy.run(func)
3022  PerReplica:{
3023    0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
3024    1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
3025  }
3026  """
3027
3028  def __init__(self, strategy, replica_id_in_sync_group):
3029    """Creates a ReplicaContext.
3030
3031    Args:
3032      strategy: A `tf.distribute.Strategy`.
3033      replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
3034        integer whenever possible to avoid issues with nested `tf.function`. It
3035        accepts a `Tensor` only to be compatible with `tpu.replicate`.
3036    """
3037    self._strategy = strategy
3038    self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
3039        self)
3040    if not (replica_id_in_sync_group is None or
3041            tensor_util.is_tf_type(replica_id_in_sync_group) or
3042            isinstance(replica_id_in_sync_group, int)):
3043      raise ValueError(
3044          "replica_id_in_sync_group can only be an integer, a Tensor or None.")
3045    self._replica_id_in_sync_group = replica_id_in_sync_group
3046    # We need this check because TPUContext extends from ReplicaContext and
3047    # does not pass a strategy object since it is used by TPUEstimator.
3048    if strategy:
3049      self._local_replica_id = strategy.extended._get_local_replica_id(
3050          replica_id_in_sync_group)
3051    self._summary_recording_distribution_strategy = None
3052
3053  @doc_controls.do_not_generate_docs
3054  def __enter__(self):
3055    _push_per_thread_mode(self._thread_context)
3056
3057    def replica_id_is_zero():
3058      return math_ops.equal(self.replica_id_in_sync_group,
3059                            constant_op.constant(0))
3060
3061    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
3062    self._summary_recording_distribution_strategy = (
3063        summary_state.is_recording_distribution_strategy)
3064    summary_state.is_recording_distribution_strategy = replica_id_is_zero
3065
3066  @doc_controls.do_not_generate_docs
3067  def __exit__(self, exception_type, exception_value, traceback):
3068    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
3069    summary_state.is_recording_distribution_strategy = (
3070        self._summary_recording_distribution_strategy)
3071    _pop_per_thread_mode()
3072
3073  def merge_call(self, merge_fn, args=(), kwargs=None):
3074    """Merge args across replicas and run `merge_fn` in a cross-replica context.
3075
3076    This allows communication and coordination when there are multiple calls
3077    to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
3078
3079    See `tf.distribute.Strategy.run` for an explanation.
3080
3081    If not inside a distributed scope, this is equivalent to:
3082
3083    ```
3084    strategy = tf.distribute.get_strategy()
3085    with cross-replica-context(strategy):
3086      return merge_fn(strategy, *args, **kwargs)
3087    ```
3088
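    For example, a minimal sketch (assuming a `MirroredStrategy` with two
    replicas) that sums a per-replica value across replicas from inside
    `tf.distribute.Strategy.run`:

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def merge_fn(strategy, per_replica_value):
      # Runs once in the cross-replica context; `per_replica_value` holds the
      # value contributed by every replica.
      return strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

    def step_fn():
      ctx = tf.distribute.get_replica_context()
      value = tf.identity(1.)
      # Every replica receives the merged (summed) result, i.e. 2.0.
      return ctx.merge_call(merge_fn, args=(value,))

    strategy.run(step_fn)
    ```
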
3089    Args:
3090      merge_fn: Function that joins arguments from threads that are given as
        PerReplica. It accepts a `tf.distribute.Strategy` object as
3092        the first argument.
3093      args: List or tuple with positional per-thread arguments for `merge_fn`.
3094      kwargs: Dict with keyword per-thread arguments for `merge_fn`.
3095
3096    Returns:
3097      The return value of `merge_fn`, except for `PerReplica` values which are
3098      unpacked.
3099    """
3100    require_replica_context(self)
3101    if kwargs is None:
3102      kwargs = {}
3103
3104    merge_fn = autograph.tf_convert(
3105        merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
3106    return self._merge_call(merge_fn, args, kwargs)
3107
3108  def _merge_call(self, merge_fn, args, kwargs):
3109    """Default implementation for single replica."""
3110    _push_per_thread_mode(  # thread-local, so not needed with multiple threads
3111        distribution_strategy_context._CrossReplicaThreadMode(self._strategy))  # pylint: disable=protected-access
3112    try:
3113      return merge_fn(self._strategy, *args, **kwargs)
3114    finally:
3115      _pop_per_thread_mode()
3116
3117  @property
3118  def num_replicas_in_sync(self):
3119    """Returns number of replicas that are kept in sync."""
3120    return self._strategy.num_replicas_in_sync
3121
3122  @property
3123  def replica_id_in_sync_group(self):
3124    """Returns the id of the replica.
3125
3126    This identifies the replica among all replicas that are kept in sync. The
3127    value of the replica id can range from 0 to
3128    `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
3129
    NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
    for low-level operations such as collective_permute.
3132
3133    Returns:
3134      a `Tensor`.
3135    """
3136    # It's important to prefer making the Tensor at call time whenever possible.
3137    # Keeping Tensors in global states doesn't work well with nested
3138    # tf.function, since it's possible that the tensor is generated in one func
3139    # graph, and gets captured by another, which will result in a subtle "An op
3140    # outside of the function building code is being passed a Graph tensor"
    # error. Making the tensor at call time ensures it is in the same graph
    # where it's used. However, to be compatible with tpu.replicate(),
3143    # self._replica_id_in_sync_group can also be a Tensor.
3144    if tensor_util.is_tf_type(self._replica_id_in_sync_group):
3145      return self._replica_id_in_sync_group
3146    return constant_op.constant(
3147        self._replica_id_in_sync_group,
3148        dtypes.int32,
3149        name="replica_id_in_sync_group")
3150
3151  @property
3152  def _replica_id(self):
3153    """This is the local replica id in a given sync group."""
3154    return self._local_replica_id
3155
3156  @property
3157  def strategy(self):
3158    """The current `tf.distribute.Strategy` object."""
3159    return self._strategy
3160
3161  @property
3162  @deprecation.deprecated(None, "Please avoid relying on devices property.")
3163  def devices(self):
3164    """Returns the devices this replica is to be executed on, as a tuple of strings.
3165
3166    NOTE: For `tf.distribute.MirroredStrategy` and
    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
    nested list of device strings, e.g., [["GPU:0"]].
3170    """
3171    require_replica_context(self)
3172    return (device_util.current(),)
3173
3174  def all_reduce(self, reduce_op, value, options=None):
3175    """All-reduces `value` across all replicas.
3176
3177    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3178    >>> def step_fn():
3179    ...   ctx = tf.distribute.get_replica_context()
3180    ...   value = tf.identity(1.)
3181    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
3182    >>> strategy.experimental_local_results(strategy.run(step_fn))
3183    (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3184     <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
3185
3186    It supports batched operations. You can pass a list of values and it
3187    attempts to batch them when possible. You can also specify `options`
3188    to indicate the desired batching behavior, e.g. batch the values into
3189    multiple packs so that they can better overlap with computations.
3190
3191    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3192    >>> def step_fn():
3193    ...   ctx = tf.distribute.get_replica_context()
3194    ...   value1 = tf.identity(1.)
3195    ...   value2 = tf.identity(2.)
3196    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
3197    >>> strategy.experimental_local_results(strategy.run(step_fn))
3198    ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3199    <tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
3200    [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3201    <tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
3202
    Note that all replicas need to participate in the all-reduce, otherwise this
    operation hangs. If there are multiple all-reduces, they need to execute in
    the same order on all replicas. Dispatching all-reduce based on conditions
    is usually error-prone.
3207
    Known limitation: if `value` contains `tf.IndexedSlices`, attempting to
    compute the gradient w.r.t. `value` results in an error.
3210
3211    This API currently can only be called in the replica context. Other
3212    variants to reduce values across replicas are:
3213    * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
3214      in the cross-replica context.
3215    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
3216      all-reduce API in the cross-replica context.
3217    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
3218      to the host in cross-replica context.
3219
3220    Args:
3221      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
3222        be combined. Allows using string representation of the enum such as
3223        "SUM", "MEAN".
      value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices`
        which `tf.nest.flatten` accepts. The structure and the shapes of `value`
        need to be the same on all replicas.
3227      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3228        perform collective operations. This overrides the default options if the
3229        `tf.distribute.Strategy` takes one in the constructor. See
3230        `tf.distribute.experimental.CommunicationOptions` for details of the
3231        options.
3232
3233    Returns:
3234       A nested structure of `tf.Tensor` with the reduced values. The structure
3235       is the same as `value`.
3236    """
3237    flattened_value = nest.flatten(value)
3238    has_indexed_slices = False
3239
3240    for v in flattened_value:
3241      if isinstance(v, indexed_slices.IndexedSlices):
3242        has_indexed_slices = True
3243
3244    if isinstance(reduce_op, six.string_types):
3245      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
3246    if options is None:
3247      options = collective_util.Options()
3248
3249    def batch_all_reduce(strategy, *value_flat):
3250      return strategy.extended.batch_reduce_to(
3251          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
3252          options)
3253
3254    # Due to the use of `capture_call_time_value` in collective ops, we have
3255    # to maintain two branches: one w/ merge_call and one w/o. Details can be
3256    # found in b/184009754.
3257    if self._strategy.extended._use_merge_call():  # pylint: disable=protected-access
3258      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
3259      if has_indexed_slices:
3260        return nest.pack_sequence_as(
3261            value,
3262            self.merge_call(batch_all_reduce, args=flattened_value))
3263
3264      @custom_gradient.custom_gradient
3265      def grad_wrapper(*xs):
3266        ys = self.merge_call(batch_all_reduce, args=xs)
3267        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3268        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3269      return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3270    else:
3271      if has_indexed_slices:
3272        return nest.pack_sequence_as(
3273            value,
3274            self._strategy.extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
3275                reduce_op, flattened_value, options))
3276
3277      @custom_gradient.custom_gradient
3278      def grad_wrapper(*xs):
3279        ys = self._strategy.extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
3280            reduce_op, xs, options)
3281        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3282        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3283
3284      return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3285
3286  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
3287  # all-reduce. It would return a function returning the result of reducing `t`
3288  # across all replicas. The caller would wait to call this function until they
3289  # needed the reduce result, allowing an efficient implementation:
3290  # * With eager execution, the reduction could be performed asynchronously
3291  #   in the background, not blocking until the result was needed.
3292  # * When constructing a graph, it could batch up all reduction requests up
3293  #   to that point that the first result is needed. Most likely this can be
3294  #   implemented in terms of `merge_call()` and `batch_reduce_to()`.
3295
3296
3297@tf_export("distribute.ReplicaContext", v1=[])
3298class ReplicaContext(ReplicaContextBase):
3299
3300  __doc__ = ReplicaContextBase.__doc__
3301
3302  def all_gather(self, value, axis, options=None):
3303    """All-gathers `value` across all replicas along `axis`.
3304
3305    Note: An `all_gather` method can only be called in replica context. For
3306    a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
3307    All replicas need to participate in the all-gather, otherwise this
3308    operation hangs. So if `all_gather` is called in any replica, it must be
3309    called in all replicas.
3310
3311    Note: If there are multiple `all_gather` calls, they need to be executed in
3312    the same order on all replicas. Dispatching `all_gather` based on conditions
3313    is usually error-prone.
3314
3315    For all strategies except `tf.distribute.TPUStrategy`, the input
3316    `value` on different replicas must have the same rank, and their shapes must
3317    be the same in all dimensions except the `axis`-th dimension. In other
    words, their shapes cannot be different in a dimension `d` where `d` does
    not equal the `axis` argument. For example, given a
3320    `tf.distribute.DistributedValues` with component tensors of shape
3321    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
3322    `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
3323    or `all_gather(..., axis=2, ...)`. However, with
3324    `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
3325    same shape.
3326
    Note: The input `value` must have a non-zero rank. Otherwise, consider using
    `tf.expand_dims` before gathering.
3329
3330    You can pass in a single tensor to all-gather:
3331
3332    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3333    >>> @tf.function
3334    ... def gather_value():
3335    ...   ctx = tf.distribute.get_replica_context()
3336    ...   local_value = tf.constant([1, 2, 3])
3337    ...   return ctx.all_gather(local_value, axis=0)
3338    >>> result = strategy.run(gather_value)
3339    >>> result
3340    PerReplica:{
3341      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3342      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3343    }
3344    >>> strategy.experimental_local_results(result)
3345    (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3346    dtype=int32)>,
3347    <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3348    dtype=int32)>)
3349
3350
3351    You can also pass in a nested structure of tensors to all-gather, say, a
3352    list:
3353
3354    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3355    >>> @tf.function
3356    ... def gather_nest():
3357    ...   ctx = tf.distribute.get_replica_context()
3358    ...   value_1 = tf.constant([1, 2, 3])
3359    ...   value_2 = tf.constant([[1, 2], [3, 4]])
3360    ...   # all_gather a nest of `tf.distribute.DistributedValues`
3361    ...   return ctx.all_gather([value_1, value_2], axis=0)
3362    >>> result = strategy.run(gather_nest)
3363    >>> result
3364    [PerReplica:{
3365      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3366      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3367    }, PerReplica:{
3368      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3369    array([[1, 2],
3370           [3, 4],
3371           [1, 2],
3372           [3, 4]], dtype=int32)>,
3373      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3374    array([[1, 2],
3375           [3, 4],
3376           [1, 2],
3377           [3, 4]], dtype=int32)>
3378    }]
3379    >>> strategy.experimental_local_results(result)
3380    ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3381    <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3382    array([[1, 2],
3383           [3, 4],
3384           [1, 2],
3385           [3, 4]], dtype=int32)>],
3386           [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3387           <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3388    array([[1, 2],
3389           [3, 4],
3390           [1, 2],
3391           [3, 4]], dtype=int32)>])
3392
3393
3394    What if you are all-gathering tensors with different shapes on different
3395    replicas? Consider the following example with two replicas, where you have
3396    `value` as a nested structure consisting of two items to all-gather, `a` and
3397    `b`.
3398
3399    * On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
3400    * On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
3401    * Result for `all_gather` with `axis=0` (on each of the replicas) is:
3402
3403      ```
      {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
3405      ```
3406
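    A runnable sketch of this case (assuming two replicas; `value_fn` below is
    a hypothetical helper that gives each replica its own component):

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def value_fn(ctx):
      # Give each replica a differently-shaped component for key 'b'.
      if ctx.replica_id_in_sync_group == 0:
        return {'a': tf.constant([0]), 'b': tf.constant([[0, 1]])}
      return {'a': tf.constant([1]), 'b': tf.constant([[2, 3], [4, 5]])}

    distributed_values = (
        strategy.experimental_distribute_values_from_function(value_fn))

    def step_fn(value):
      ctx = tf.distribute.get_replica_context()
      return ctx.all_gather(value, axis=0)

    # Each replica gets {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}.
    strategy.run(step_fn, args=(distributed_values,))
    ```
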
3407    Args:
3408      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
3409        or a `tf.distribute.DistributedValues` instance. The structure of the
        `tf.Tensor` needs to be the same on all replicas. The underlying tensor
3411        constructs can only be dense tensors with non-zero rank, NOT
3412        `tf.IndexedSlices`.
3413      axis: 0-D int32 Tensor. Dimension along which to gather.
3414      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3415        perform collective operations. This overrides the default options if the
3416        `tf.distribute.Strategy` takes one in the constructor. See
3417        `tf.distribute.experimental.CommunicationOptions` for details of the
3418        options.
3419
3420    Returns:
3421       A nested structure of `tf.Tensor` with the gathered values. The structure
3422       is the same as `value`.
3423    """
3424    for v in nest.flatten(value):
3425      if isinstance(v, indexed_slices.IndexedSlices):
3426        raise NotImplementedError("all_gather does not support IndexedSlices")
3427
3428    if options is None:
3429      options = collective_util.Options()
3430
3431    def batch_all_gather(strategy, *value_flat):
3432      return strategy.extended._batch_gather_to(  # pylint: disable=protected-access
3433          [(v, _batch_reduce_destination(v)) for v in value_flat], axis,
3434          options)
3435
3436    @custom_gradient.custom_gradient
3437    def grad_wrapper(*xs):
3438      ys = self.merge_call(batch_all_gather, args=xs)
3439
3440      def grad(*dy_s):
3441        grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
3442        new_grads = []
3443        for i, grad in enumerate(grads):
3444          input_shape = array_ops.shape(xs[i])
3445          axis_dim = array_ops.reshape(input_shape[axis], [1])
3446          with ops.control_dependencies([array_ops.identity(grads)]):
3447            d = self.all_gather(axis_dim, axis=0)
3448            begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
3449            end_dim = begin_dim + array_ops.shape(xs[i])[axis]
3450            new_grad = array_ops.gather(
3451                grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
3452            new_grads.append(new_grad)
3453        return new_grads
3454
3455      return ys, grad
3456
3457    return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
3458
3459  def _update(self, var, fn, args=(), kwargs=None, group=True):
3460    """Run `fn` to update `var` with `args` and `kwargs` in replica context.
3461
3462    `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var`
3463    to be updated, an update function `fn`, and `args` and `kwargs` for `fn`.
3464    `fn` applies to each component variable of `var` with corresponding input
3465    values from `args` and `kwargs`.
3466
3467    Example usage:
3468
3469    >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3470    >>> with strategy.scope():
3471    ...   distributed_variable = tf.Variable(5.0)
3472    >>> distributed_variable
3473    MirroredVariable:{
3474      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
3475      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0>
3476    }
3477    >>> def replica_fn(v):
3478    ...   value = tf.identity(1.0)
3479    ...   replica_context = tf.distribute.get_replica_context()
3480    ...   update_fn = lambda var, value: var.assign(value)
3481    ...   replica_context._update(v, update_fn, args=(value,))
3482    >>> strategy.run(replica_fn, args=(distributed_variable,))
3483    >>> distributed_variable
3484    MirroredVariable:{
3485      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
3486      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
3487    }
3488
3489    This API must be called in a replica context.
3490
    Note that if `var` is a MirroredVariable (i.e., the type of variable created
    under the scope of a synchronous strategy, synchronized on-write; see
    `tf.VariableSynchronization` for more information) and `args`/`kwargs`
    contain different values for different replicas, `var` will be dangerously
    out of synchronization. We therefore recommend using `variable.assign(value)`
    whenever you can, since it aggregates the updates under the hood and
    guarantees synchronization. The case where you actually want this API
    instead of `variable.assign(value)` is when, before assigning `value` to the
    `variable`, you'd like to perform some pre-`assign` computation colocated
    with the variable devices (i.e. where the variables reside; for
    MirroredStrategy these are the same as the compute devices, while for
    ParameterServerStrategy they are the parameter servers). E.g.,
3503
3504    ```python
3505    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3506    with strategy.scope():
3507      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
3508    def replica_fn(inputs):
3509      value = computation(inputs)
3510      replica_context = tf.distribute.get_replica_context()
      reduced_value = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, value)
3512
3513      def update_fn(var, value):
3514        # this computation will colocate with `var`'s device
3515        updated_value = post_reduce_pre_update_computation(value)
        var.assign(updated_value)
3517
3518      replica_context._update(v, update_fn, args=(reduced_value,))
3519
3520    strategy.run(replica_fn, args=(inputs,))
3521    ```
3522
    This code snippet is consistent across all strategies. If you instead
    compute and `assign` directly in the replica context without wrapping the
    computation in `update`, then for strategies with fewer variable devices
    than compute devices (e.g., parameter server strategy, usually), the
    `post_reduce_pre_update_computation` will run N == number_of_compute_devices
    times, which is less performant.
3529
3530
3531    Args:
3532      var: Variable, possibly distributed to multiple devices, to operate on.
3533      fn: Function to call. Should take the variable as the first argument.
3534      args: Tuple or list. Additional positional arguments to pass to `fn()`.
3535      kwargs: Dict with keyword arguments to pass to `fn()`.
      group: Boolean. Defaults to True. Most strategies enter a merge_call to
        conduct the update in cross-replica context, and group=True guarantees
        that updates on all replicas are executed.
3539
3540    Returns:
3541      The return value of `fn` for the local replica.
3542    """
3543    if kwargs is None:
3544      kwargs = {}
3545    return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group)  # pylint: disable=protected-access
3546
3547
3548@tf_export(v1=["distribute.ReplicaContext"])
3549class ReplicaContextV1(ReplicaContextBase):
3550  __doc__ = ReplicaContextBase.__doc__
3551
3552
3553def _batch_reduce_destination(x):
3554  """Returns the destinations for batch all-reduce."""
3555  if isinstance(x, ops.Tensor):
3556    # If this is a one device strategy.
3557    return x.device
3558  else:
3559    return x
3560# ------------------------------------------------------------------------------
3561
3562
3563_creating_default_strategy_singleton = False
3564
3565
3566class _DefaultDistributionStrategyV1(StrategyV1):
3567  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3568
3569  def __init__(self):
3570    if not _creating_default_strategy_singleton:
3571      raise RuntimeError("Should only create a single instance of "
3572                         "_DefaultDistributionStrategy")
3573    super(_DefaultDistributionStrategyV1,
3574          self).__init__(_DefaultDistributionExtended(self))
3575
3576  def __deepcopy__(self, memo):
3577    del memo
3578    raise RuntimeError("Should only create a single instance of "
3579                       "_DefaultDistributionStrategy")
3580
3581
3582class _DefaultDistributionStrategy(Strategy):
3583  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3584
3585  def __init__(self):
3586    if not _creating_default_strategy_singleton:
3587      raise RuntimeError("Should only create a single instance of "
3588                         "_DefaultDistributionStrategy")
3589    super(_DefaultDistributionStrategy, self).__init__(
3590        _DefaultDistributionExtended(self))
3591
3592  def __deepcopy__(self, memo):
3593    del memo
3594    raise RuntimeError("Should only create a single instance of "
3595                       "_DefaultDistributionStrategy")
3596
3597
3598class _DefaultDistributionContext(object):
3599  """Context manager setting the default `tf.distribute.Strategy`."""
3600
3601  __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
3602
3603  def __init__(self, strategy):
3604
3605    def creator(next_creator, **kwargs):
3606      _require_strategy_scope_strategy(strategy)
3607      return next_creator(**kwargs)
3608
3609    self._var_creator_scope = variable_scope.variable_creator_scope(creator)
3610    self._strategy = strategy
3611    self._nested_count = 0
3612
3613  def __enter__(self):
3614    # Entering the default scope is allowed while only the default strategy is
    # active; has_strategy() is True only for non-default strategies.
3615    if distribution_strategy_context.has_strategy():
3616      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
3617    if self._nested_count == 0:
3618      self._var_creator_scope.__enter__()
3619    self._nested_count += 1
3620    return self._strategy
3621
3622  def __exit__(self, exception_type, exception_value, traceback):
3623    self._nested_count -= 1
3624    if self._nested_count == 0:
3625      try:
3626        self._var_creator_scope.__exit__(
3627            exception_type, exception_value, traceback)
3628      except RuntimeError as e:
3629        six.raise_from(
3630            RuntimeError("Variable creator scope nesting error: move call to "
3631                         "tf.distribute.set_strategy() out of `with` scope."),
3632            e)
3633
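# For illustration only (hypothetical direct use; in practice this context is
# obtained through _DefaultDistributionExtended._scope). Re-entering the same
# default context only increments the nesting count, while entering it when a
# non-default strategy scope is active raises:
#
#   ctx = _DefaultDistributionContext(strategy)
#   with ctx:      # installs the variable creator scope
#     with ctx:    # nested entry only bumps _nested_count
#       pass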
3634
3635class _DefaultDistributionExtended(StrategyExtendedV1):
3636  """Implementation of _DefaultDistributionStrategy."""
3637
3638  def __init__(self, container_strategy):
3639    super(_DefaultDistributionExtended, self).__init__(container_strategy)
3640    self._retrace_functions_for_each_device = False
3641
3642  def _scope(self, strategy):
3643    """Context manager setting a variable creator and `self` as current."""
3644    return _DefaultDistributionContext(strategy)
3645
3646  def colocate_vars_with(self, colocate_with_variable):
3647    """Does not require `self.scope`."""
3648    _require_strategy_scope_extended(self)
3649    return ops.colocate_with(colocate_with_variable)
3650
3651  def variable_created_in_scope(self, v):
3652    return v._distribute_strategy is None  # pylint: disable=protected-access
3653
3654  def _experimental_distribute_dataset(self, dataset, options):
3655    return dataset
3656
3657  def _distribute_datasets_from_function(self, dataset_fn, options):
3658    return dataset_fn(InputContext())
3659
3660  def _experimental_distribute_values_from_function(self, value_fn):
3661    return value_fn(ValueContext())
3662
3663  def _make_dataset_iterator(self, dataset):
3664    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
3665
3666  def _make_input_fn_iterator(self,
3667                              input_fn,
3668                              replication_mode=InputReplicationMode.PER_WORKER):
3669    dataset = input_fn(InputContext())
3670    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
3671
3672  def _experimental_make_numpy_dataset(self, numpy_input, session):
3673    numpy_flat = nest.flatten(numpy_input)
3674    vars_flat = tuple(
3675        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
3676                                trainable=False, use_resource=True)
3677        for i in numpy_flat
3678    )
3679    for v, i in zip(vars_flat, numpy_flat):
3680      numpy_dataset.init_var_from_numpy(v, i, session)
3681    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
3682    return dataset_ops.Dataset.from_tensor_slices(vars_nested)
3683
3684  def _broadcast_to(self, tensor, destinations):
3685    if destinations is None:
3686      return tensor
3687    else:
3688      raise NotImplementedError("TODO")
3689
3690  def _call_for_each_replica(self, fn, args, kwargs):
3691    with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
3692      return fn(*args, **kwargs)
3693
3694  def _reduce_to(self, reduce_op, value, destinations, options):
3695    # TODO(josh11b): Use destinations?
3696    del reduce_op, destinations, options
3697    return value
3698
3699  def _gather_to_implementation(self, value, destinations, axis, options):
3700    del destinations, axis, options
3701    return value
3702
3703  def _update(self, var, fn, args, kwargs, group):
3704    # The implementations of _update() and _update_non_slot() are identical
3705    # except _update() passes `var` as the first argument to `fn()`.
3706    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
3707
3708  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
3709    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
3710    # once that value is used for something.
3711    with UpdateContext(colocate_with):
3712      result = fn(*args, **kwargs)
3713      if should_group:
3714        return result
3715      else:
3716        return nest.map_structure(self._local_results, result)
3717
3718  def read_var(self, replica_local_var):
3719    return array_ops.identity(replica_local_var)
3720
3721  def _local_results(self, distributed_value):
3722    return (distributed_value,)
3723
3724  def value_container(self, value):
3725    return value
3726
3727  @property
3728  def _num_replicas_in_sync(self):
3729    return 1
3730
3731  @property
3732  def worker_devices(self):
3733    raise RuntimeError("worker_devices() method unsupported by default "
3734                       "tf.distribute.Strategy.")
3735
3736  @property
3737  def parameter_devices(self):
3738    raise RuntimeError("parameter_devices() method unsupported by default "
3739                       "tf.distribute.Strategy.")
3740
3741  def non_slot_devices(self, var_list):
3742    return min(var_list, key=lambda x: x.name)
3743
3744  def _in_multi_worker_mode(self):
3745    """Whether this strategy indicates working in multi-worker settings."""
3746    # Default strategy doesn't indicate multi-worker training.
3747    return False
3748
3749  @property
3750  def should_checkpoint(self):
3751    return True
3752
3753  @property
3754  def should_save_summary(self):
3755    return True
3756
3757  def _get_local_replica_id(self, replica_id_in_sync_group):
3758    return replica_id_in_sync_group
3759
3760  def _get_replica_id_in_sync_group(self, replica_id):
3761    return replica_id
3762
3763  # TODO(priyag): This should inherit from `InputIterator`, once dependency
3764  # issues have been resolved.
3765  class DefaultInputIterator(object):
3766    """Default implementation of `InputIterator` for default strategy."""
3767
3768    def __init__(self, dataset):
3769      self._dataset = dataset
3770      if eager_context.executing_eagerly():
3771        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
3772      else:
3773        self._iterator = dataset_ops.make_initializable_iterator(dataset)
3774
3775    def get_next(self):
3776      return self._iterator.get_next()
3777
3778    def get_next_as_optional(self):
3779      return self._iterator.get_next_as_optional()
3780
3781    @deprecated(None, "Use the iterator's `initializer` property instead.")
3782    def initialize(self):
3783      """Initialize underlying iterators.
3784
3785      Returns:
3786        A list of any initializer ops that should be run.
3787      """
3788      if eager_context.executing_eagerly():
3789        self._iterator = self._dataset.make_one_shot_iterator()
3790        return []
3791      else:
3792        return [self._iterator.initializer]
3793
3794    @property
3795    def initializer(self):
3796      """Returns a list of ops that initialize the iterator."""
3797      return self.initialize()
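
    # For illustration only (graph mode, hypothetical `dataset`): run the
    # returned initializer ops before the first get_next(), e.g.
    #
    #   it = _DefaultDistributionExtended.DefaultInputIterator(dataset)
    #   with tf.compat.v1.Session() as sess:
    #     sess.run(it.initializer)
    #     batch = sess.run(it.get_next())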
3798
3799  # TODO(priyag): Delete this once all strategies use global batch size.
3800  @property
3801  def _global_batch_size(self):
3802    """Global and per-replica batching are equivalent for this strategy."""
3803    return True
3804
3805
3806class _DefaultReplicaContext(ReplicaContext):
3807  """ReplicaContext for _DefaultDistributionStrategy."""
3808
3809  @property
3810  def replica_id_in_sync_group(self):
3811    # Return 0 instead of a constant tensor to avoid creating a new node for
3812    # users who don't use a distribution strategy.
3813    return 0
3814
3815
3816# ------------------------------------------------------------------------------
3817# We haven't yet implemented deserialization for DistributedVariables.
3818# So here we catch any attempts to deserialize variables
3819# when using distribution strategies.
3820# pylint: disable=protected-access
3821_original_from_proto = resource_variable_ops._from_proto_fn
3822
3823
3824def _from_proto_fn(v, import_scope=None):
3825  if distribution_strategy_context.has_strategy():
3826    raise NotImplementedError(
3827        "Deserialization of variables is not yet supported when using a "
3828        "tf.distribute.Strategy.")
3829  else:
3830    return _original_from_proto(v, import_scope=import_scope)
3831
3832resource_variable_ops._from_proto_fn = _from_proto_fn
3833# pylint: enable=protected-access
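# For illustration only (hypothetical call pattern, not part of this module):
# any graph import that re-creates variables from their protos while a strategy
# scope is active will hit the check above, e.g.
#
#   with tf.distribute.MirroredStrategy().scope():
#     tf.compat.v1.train.import_meta_graph("model.meta")  # NotImplementedError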
3834
3835
3836# ------------------------------------------------------------------------------
3837# Shorthand for some methods from distribution_strategy_context.
3838_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
3839_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
3840_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
3841_get_default_replica_mode = (
3842    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access
3843
3844
3845# ------------------------------------------------------------------------------
3846# Metrics to track which distribution strategy is being used
3847distribution_strategy_gauge = monitoring.StringGauge(
3848    "/tensorflow/api/distribution_strategy",
3849    "Gauge to track the type of distribution strategy used.", "TFVersion")
3850distribution_strategy_replica_gauge = monitoring.IntGauge(
3851    "/tensorflow/api/distribution_strategy/replica",
3852    "Gauge to track the number of replica each distribution strategy used.",
3853    "CountType")
3854distribution_strategy_input_api_counter = monitoring.Counter(
3855    "/tensorflow/api/distribution_strategy/input_api",
3856    "Counter to track the usage of the input APIs", "strategy", "api")
3857