1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for running a computation across multiple devices.
16
17See the guide for overview and examples:
18[TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training),
19[TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb).  # pylint: disable=line-too-long
20
21The intent of this library is that you can write an algorithm in a stylized way
22and it will be usable with a variety of different `tf.distribute.Strategy`
23implementations. Each descendant will implement a different strategy for
24distributing the algorithm across multiple devices/machines.  Furthermore, these
25changes can be hidden inside the specific layers and other library classes that
26need special treatment to run in a distributed setting, so that most users'
27model definition code can run unchanged. The `tf.distribute.Strategy` API works
28the same way with eager and graph execution.
29
30*Glossary*
31
32* _Data parallelism_ is where we run multiple copies of the model
33  on different slices of the input data. This is in contrast to
34  _model parallelism_ where we divide up a single copy of a model
35  across multiple devices.
36  Note: we only support data parallelism for now, but
37  hope to add support for model parallelism in the future.
38* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
39  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
40  devices on a single machine, or be connected to devices on multiple
41  machines. Devices used to run computations are called _worker devices_.
42  Devices used to store variables are _parameter devices_. For some strategies,
43  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
44  will be the same (see mirrored variables below). For others they will be
45  different.  For example, `tf.distribute.experimental.CentralStorageStrategy`
46  puts the variables on a single device (which may be a worker device or may be
47  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
48  variables on separate machines called parameter servers (see below).
49* A _replica_ is one copy of the model, running on one slice of the
50  input data. Right now each replica is executed on its own
51  worker device, but once we add support for model parallelism
52  a replica may span multiple worker devices.
53* A _host_ is the CPU device on a machine with worker devices, typically
54  used for running input pipelines.
55* A _worker_ is defined to be the physical machine(s) containing the physical
56  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
57  worker may contain one or more replicas, but contains at least one
58  replica. Typically one worker will correspond to one machine, but in the case
59  of very large models with model parallelism, one worker may span multiple
60  machines. We typically run one input pipeline per worker, feeding all the
61  replicas on that worker.
62* _Synchronous_, or more commonly _sync_, training is where the updates from
63  each replica are aggregated together before updating the model variables. This
64  is in contrast to _asynchronous_, or _async_ training, where each replica
65  updates the model variables independently. You may also have replicas
66  partitioned into groups which are in sync within each group but async between
67  groups.
68* _Parameter servers_: These are machines that hold a single copy of
69  parameters/variables, used by some strategies (right now just
70  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
71  to operate on a variable retrieve it at the beginning of a step and send an
72  update to be applied at the end of the step. These can in principle support
73  either sync or async training, but right now we only have support for async
74  training with parameter servers. Compare to
75  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
76  on a single device on the same machine (and does sync training), and
77  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
78  (see below).
79* _Mirrored variables_: These are variables that are copied to multiple
80  devices, where we keep the copies in sync by applying the same
81  updates to every copy. These would normally only be used with sync training.
82* Reductions and all-reduce: A _reduction_ is some method of aggregating
83  multiple values into one value, like "sum" or "mean". If a strategy is doing
84  sync training, we will perform a reduction on the gradients to a parameter
85  from all replicas before applying the update. _All-reduce_ is an algorithm for
86  performing a reduction on values from multiple devices and making the result
87  available on all of those devices.
88
89Note that we provide a default version of `tf.distribute.Strategy` that is
90used when no other strategy is in scope and that provides the same API with
91reasonable default behavior.
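
As a minimal illustration of sync data parallelism (a sketch, not a complete
recipe; it assumes accelerators are visible to TensorFlow and uses a small
Keras model):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # mirrors variables across devices
with strategy.scope():
  # Variables created here are mirrored; gradients are aggregated in sync
  # across replicas via an all-reduce.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="sgd", loss="mse")
```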
92"""
93
94from __future__ import absolute_import
95from __future__ import division
96from __future__ import print_function
97
98import copy
99import enum  # pylint: disable=g-bad-import-order
100import threading
101import weakref
102
103import six
104
105from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
106from tensorflow.python.autograph.impl import api as autograph
107from tensorflow.python.data.ops import dataset_ops
108from tensorflow.python.distribute import device_util
109from tensorflow.python.distribute import distribution_strategy_context
110from tensorflow.python.distribute import numpy_dataset
111from tensorflow.python.distribute import reduce_util
112from tensorflow.python.eager import context as eager_context
113from tensorflow.python.eager import monitoring
114from tensorflow.python.framework import constant_op
115from tensorflow.python.framework import dtypes
116from tensorflow.python.framework import ops
117from tensorflow.python.framework import tensor_shape
118from tensorflow.python.ops import array_ops
119from tensorflow.python.ops import control_flow_ops
120from tensorflow.python.ops import custom_gradient
121from tensorflow.python.ops import math_ops
122from tensorflow.python.ops import resource_variable_ops
123from tensorflow.python.ops import summary_ops_v2
124from tensorflow.python.ops import variable_scope
125from tensorflow.python.ops.losses import loss_reduction
126from tensorflow.python.ops.losses import losses_impl
127from tensorflow.python.platform import tf_logging
128from tensorflow.python.training.tracking import base as trackable
129from tensorflow.python.util import nest
130from tensorflow.python.util import tf_contextlib
131from tensorflow.python.util.deprecation import deprecated
132from tensorflow.python.util.tf_export import tf_export
133from tensorflow.tools.docs import doc_controls
134
135
136# ------------------------------------------------------------------------------
137# Context tracking whether in a strategy.update() or .update_non_slot() call.
138
139
140_update_replica_id = threading.local()
141
142
143def get_update_replica_id():
144  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
145  try:
146    return _update_replica_id.current
147  except AttributeError:
148    return None
149
150
151class UpdateContext(object):
152  """Context manager when you are in `update()` or `update_non_slot()`."""
153
154  def __init__(self, replica_id):
155    self._replica_id = replica_id
156    self._old_replica_id = None
157
158  def __enter__(self):
159    self._old_replica_id = get_update_replica_id()
160    _update_replica_id.current = self._replica_id
161
162  def __exit__(self, exception_type, exception_value, traceback):
163    del exception_type, exception_value, traceback
164    _update_replica_id.current = self._old_replica_id
165
166
167# ------------------------------------------------------------------------------
168# Public utility functions.
169
170
171@tf_export(v1=["distribute.get_loss_reduction"])
172def get_loss_reduction():
173  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.
174
175  This is used to decide whether loss should be scaled in the optimizer (used
176  only for the estimator + v1 optimizer use case).
177
178  Returns:
179    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
180    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
181  """
182  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
183    # If we are not in Estimator context then return 'SUM'. We do not need to
184    # scale loss in the optimizer.
185    return reduce_util.ReduceOp.SUM
186  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
187  if (last_reduction == losses_impl.Reduction.SUM or
188      last_reduction == loss_reduction.ReductionV2.SUM):
189    return reduce_util.ReduceOp.SUM
190  return reduce_util.ReduceOp.MEAN
191
192
193# ------------------------------------------------------------------------------
194# Internal API for validating the current thread mode
195
196
197def _require_cross_replica_or_default_context_extended(extended):
198  """Verify in cross-replica context."""
199  context = _get_per_thread_mode()
200  cross_replica = context.cross_replica_context
201  if cross_replica is not None and cross_replica.extended is extended:
202    return
203  if context is _get_default_replica_mode():
204    return
205  strategy = extended._container_strategy()  # pylint: disable=protected-access
206  # We have an error to report, figure out the right message.
207  if context.strategy is not strategy:
208    _wrong_strategy_scope(strategy, context)
209  assert cross_replica is None
210  raise RuntimeError("Method requires being in cross-replica context, use "
211                     "get_replica_context().merge_call()")
212
213
214def _wrong_strategy_scope(strategy, context):
215  # Figure out the right error message.
216  if not distribution_strategy_context.has_strategy():
217    raise RuntimeError(
218        'Need to be inside "with strategy.scope()" for %s' %
219        (strategy,))
220  else:
221    raise RuntimeError(
222        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
223        (context.strategy, strategy))
224
225
226def require_replica_context(replica_ctx):
227  """Verify in `replica_ctx` replica context."""
228  context = _get_per_thread_mode()
229  if context.replica_context is replica_ctx: return
230  # We have an error to report, figure out the right message.
231  if context.replica_context is None:
232    raise RuntimeError("Need to be inside `call_for_each_replica()`")
233  if context.strategy is replica_ctx.strategy:
234    # Two different ReplicaContexts with the same tf.distribute.Strategy.
235    raise RuntimeError("Mismatching ReplicaContext.")
236  raise RuntimeError(
237      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
238      (context.strategy, replica_ctx.strategy))
239
240
241def _require_strategy_scope_strategy(strategy):
242  """Verify in a `strategy.scope()` in this thread."""
243  context = _get_per_thread_mode()
244  if context.strategy is strategy: return
245  _wrong_strategy_scope(strategy, context)
246
247
248def _require_strategy_scope_extended(extended):
249  """Verify in a `distribution_strategy.scope()` in this thread."""
250  context = _get_per_thread_mode()
251  if context.strategy.extended is extended: return
252  # Report error.
253  strategy = extended._container_strategy()  # pylint: disable=protected-access
254  _wrong_strategy_scope(strategy, context)
255
256
257# ------------------------------------------------------------------------------
258# Internal context managers used to implement the DistributionStrategy
259# base class
260
261
262class _CurrentDistributionContext(object):
263  """Context manager setting the current `tf.distribute.Strategy`.
264
265  Also: overrides the variable creator and optionally the current device.
266  """
267
268  def __init__(self,
269               strategy,
270               var_creator_scope,
271               var_scope=None,
272               default_device=None):
273    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
274        strategy)
275    self._var_creator_scope = var_creator_scope
276    self._var_scope = var_scope
277    if default_device:
278      self._device_scope = ops.device(default_device)
279    else:
280      self._device_scope = None
281    self._same_scope_again_count = 0
282
283  def __enter__(self):
284    # Allow this scope to be entered if this strategy is already in scope.
285    if distribution_strategy_context.has_strategy():
286      _require_cross_replica_or_default_context_extended(
287          self._context.strategy.extended)
288      self._same_scope_again_count += 1
289    else:
290      _push_per_thread_mode(self._context)
291      if self._var_scope:
292        self._var_scope.__enter__()
293      self._var_creator_scope.__enter__()
294      if self._device_scope:
295        self._device_scope.__enter__()
296    return self._context.strategy
297
298  def __exit__(self, exception_type, exception_value, traceback):
299    if self._same_scope_again_count > 0:
300      self._same_scope_again_count -= 1
301      return
302    if self._device_scope:
303      try:
304        self._device_scope.__exit__(exception_type, exception_value, traceback)
305      except RuntimeError as e:
306        six.raise_from(
307            RuntimeError("Device scope nesting error: move call to "
308                         "tf.distribute.set_strategy() out of `with` scope."),
309            e)
310
311    try:
312      self._var_creator_scope.__exit__(
313          exception_type, exception_value, traceback)
314    except RuntimeError as e:
315      six.raise_from(
316          RuntimeError("Variable creator scope nesting error: move call to "
317                       "tf.distribute.set_strategy() out of `with` scope."),
318          e)
319
320    if self._var_scope:
321      try:
322        self._var_scope.__exit__(exception_type, exception_value, traceback)
323      except RuntimeError as e:
324        six.raise_from(
325            RuntimeError("Variable scope nesting error: move call to "
326                         "tf.distribute.set_strategy() out of `with` scope."),
327            e)
328    _pop_per_thread_mode()
329
330
331# TODO(yuefengz): add more replication modes.
332@tf_export("distribute.InputReplicationMode")
333class InputReplicationMode(enum.Enum):
334  """Replication mode for input function.
335
336  * `PER_WORKER`: The input function will be called on each worker
337    independently, creating as many input pipelines as number of workers.
338    Replicas will dequeue from the local Dataset on their worker.
339    `tf.distribute.Strategy` doesn't manage any state sharing between such
340    separate input pipelines.
341  """
342  PER_WORKER = "PER_WORKER"
343
344
345@tf_export("distribute.InputContext")
346class InputContext(object):
347  """A class wrapping information needed by an input function.
348
349  This is a context class that is passed to the user's input function and
350  contains information about the compute replicas and input pipelines. The
351  number of compute replicas (in sync training) helps compute the local batch
352  size from the desired global batch size for each replica. The input pipeline
353  information can be used to return a different subset of the input in each
354  replica (e.g. shard the input pipeline, use a different input
355  source, etc.).
356  """
357
358  def __init__(self,
359               num_input_pipelines=1,
360               input_pipeline_id=0,
361               num_replicas_in_sync=1):
362    """Initializes an InputContext object.
363
364    Args:
365      num_input_pipelines: the number of input pipelines in a cluster.
366      input_pipeline_id: the current input pipeline id, should be an int in
367        [0,`num_input_pipelines`).
368      num_replicas_in_sync: the number of replicas that are in sync.
369    """
370    self._num_input_pipelines = num_input_pipelines
371    self._input_pipeline_id = input_pipeline_id
372    self._num_replicas_in_sync = num_replicas_in_sync
373
374  @property
375  def num_replicas_in_sync(self):
376    """Returns the number of compute replicas in sync."""
377    return self._num_replicas_in_sync
378
379  @property
380  def input_pipeline_id(self):
381    """Returns the input pipeline ID."""
382    return self._input_pipeline_id
383
384  @property
385  def num_input_pipelines(self):
386    """Returns the number of input pipelines."""
387    return self._num_input_pipelines
388
389  def get_per_replica_batch_size(self, global_batch_size):
390    """Returns the per-replica batch size.
391
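    For example (a quick sketch): with a global batch size of 64 and 8
    replicas in sync, each replica receives a batch of 8.

    ```python
    context = tf.distribute.InputContext(num_replicas_in_sync=8)
    assert context.get_per_replica_batch_size(64) == 8
    ```
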
392    Args:
393      global_batch_size: the global batch size which should be divisible by
394        `num_replicas_in_sync`.
395
396    Returns:
397      the per-replica batch size.
398
399    Raises:
400      ValueError: if `global_batch_size` not divisible by
401        `num_replicas_in_sync`.
402    """
403    if global_batch_size % self._num_replicas_in_sync != 0:
404      raise ValueError("The `global_batch_size` %r is not divisible by "
405                       "`num_replicas_in_sync` %r " %
406                       (global_batch_size, self._num_replicas_in_sync))
407    return global_batch_size // self._num_replicas_in_sync
408
409  def __str__(self):
410    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
411        self.input_pipeline_id, self.num_input_pipelines)
412
413
414# ------------------------------------------------------------------------------
415# Base classes for all distribution strategies.
416
417
418# pylint: disable=line-too-long
419@tf_export("distribute.Strategy", v1=[])
420class Strategy(object):
421  """A state & compute distribution policy on a list of devices.
422
423  See [the guide](https://www.tensorflow.org/guide/distributed_training)
424  for overview and examples.
425
426  In short:
427
428  * To use it with Keras `compile`/`fit`,
429    [please
430    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
431  * You may pass a descendant of `tf.distribute.Strategy` to
432    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
433    should distribute its computation. See
434    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
435  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
436    strategy should be used when building and executing your model.
437    (This puts you in the "cross-replica context" for this strategy, which
438    means the strategy is put in control of things like variable placement.)
439  * If you are writing a custom training loop, you will need to call a few more
440    methods,
441    [see the
442    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
443
444      * Start by either creating a `tf.data.Dataset` normally or using
445        `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of
446        a `numpy` array.
447      * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
448        a `tf.data.Dataset` to something that produces "per-replica" values.
449        If you want to manually specify how the dataset should be partitioned
450        across replicas, use
451        `tf.distribute.Strategy.experimental_distribute_datasets_from_function`
452        instead.
453      * Use `tf.distribute.Strategy.experimental_run_v2` to run a function
454        once per replica, taking values that may be "per-replica" (e.g.
455        from a distributed dataset) and returning "per-replica" values.
456        This function is executed in "replica context", which means each
457        operation is performed separately on each replica.
458      * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
459        convert the resulting "per-replica" values into ordinary `Tensor`s.
460
461  A custom training loop can be as simple as:
462
463  ```
464  with my_strategy.scope():
465    @tf.function
466    def distribute_train_epoch(dataset):
467      def replica_fn(input):
468        # process input and return result
469        return result
470
471      total_result = 0
472      for x in dataset:
473        per_replica_result = my_strategy.experimental_run_v2(replica_fn,
474                                                             args=(x,))
475        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
476                                           per_replica_result, axis=None)
477      return total_result
478
479    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
480    for _ in range(EPOCHS):
481      train_result = distribute_train_epoch(dist_dataset)
482  ```
483
484  This takes an ordinary `dataset` and `replica_fn` and runs them in a
485  distributed fashion using a particular `tf.distribute.Strategy` named
486  `my_strategy` above. Any variables created in `replica_fn` are created
487  using `my_strategy`'s policy, and library functions called by
488  `replica_fn` can use the `get_replica_context()` API to implement
489  distributed-specific behavior.
490
491  You can use the `reduce` API to aggregate results across replicas and use
492  this as a return value from one iteration over the distributed dataset. Or
493  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
494  accumulate metrics across steps in a given epoch.
495
496  See the
497  [custom training loop
498  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
499  for a more detailed example.
500
501  Note: `tf.distribute.Strategy` currently does not support TensorFlow's
502  partitioned variables (where a single variable is split across multiple
503  devices).
504  """
505  # pylint: enable=line-too-long
506
507  # TODO(josh11b): Partitioned computations, state; sharding
508  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
509
510  def __init__(self, extended):
511    self._extended = extended
512
513    # Flag that is used to indicate whether distribution strategy is used with
514    # Estimator. This is required for backward compatibility of loss scaling
515    # when using v1 optimizer with estimator.
516    self._scale_loss_for_estimator = False
517
518    if not hasattr(extended, "_retrace_functions_for_each_device"):
519      # pylint: disable=protected-access
520      try:
521        extended._retrace_functions_for_each_device = (
522            len(extended.worker_devices) > 1)
523        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
524            self.num_replicas_in_sync)
525      except:  # pylint: disable=bare-except
526        # Default for the case where extended.worker_devices can't return
527        # a sensible value.
528        extended._retrace_functions_for_each_device = True
529
530  @property
531  def extended(self):
532    """`tf.distribute.StrategyExtended` with additional methods."""
533    return self._extended
534
535  @tf_contextlib.contextmanager
536  def _scale_loss_for_estimator_enabled(self):
537    """Scope which sets a flag used for scaling losses in optimizer.
538
539    Yields:
540      `_scale_loss_for_estimator_enabled` is a context manager with a
541      side effect, but doesn't return a value.
542    """
543    self._scale_loss_for_estimator = True
544    try:
545      yield
546    finally:
547      self._scale_loss_for_estimator = False
548
549  def scope(self):
550    """Returns a context manager selecting this Strategy as current.
551
552    Inside a `with strategy.scope():` code block, this thread
553    will use a variable creator set by `strategy`, and will
554    enter its "cross-replica context".
555
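    For example (a brief sketch; `MyModel` is a hypothetical `tf.keras.Model`
    subclass):

    ```python
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Variables created here are placed according to the strategy's policy.
      model = MyModel()
      optimizer = tf.keras.optimizers.SGD()
    ```
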
556    Returns:
557      A context manager.
558    """
559    return self._extended._scope(self)  # pylint: disable=protected-access
560
561  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
562  def colocate_vars_with(self, colocate_with_variable):
563    """DEPRECATED: use extended.colocate_vars_with() instead."""
564    return self._extended.colocate_vars_with(colocate_with_variable)
565
566  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
567  def make_dataset_iterator(self, dataset):
568    """DEPRECATED TF 1.x ONLY."""
569    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
570
571  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
572  def make_input_fn_iterator(self,
573                             input_fn,
574                             replication_mode=InputReplicationMode.PER_WORKER):
575    """DEPRECATED TF 1.x ONLY."""
576    if replication_mode != InputReplicationMode.PER_WORKER:
577      raise ValueError(
578          "Input replication mode not supported: %r" % replication_mode)
579    with self.scope():
580      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
581          input_fn, replication_mode=replication_mode)
582
583  def experimental_make_numpy_dataset(self, numpy_input):
584    """Makes a `tf.data.Dataset` for input provided via a numpy array.
585
586    This avoids adding `numpy_input` as a large constant in the graph,
587    and copies the data to the machine or machines that will be processing
588    the input.
589
590    Note that you will likely need to use `experimental_distribute_dataset`
591    with the returned dataset to further distribute it with the strategy.
592
593    Example:
594    ```
595    numpy_input = np.ones([10], dtype=np.float32)
596    dataset = strategy.experimental_make_numpy_dataset(numpy_input)
597    dist_dataset = strategy.experimental_distribute_dataset(dataset)
598    ```
599
600    Args:
601      numpy_input: A nest of NumPy input arrays that will be converted into a
602        dataset. Note that lists of NumPy arrays are stacked, as that is normal
603        `tf.data.Dataset` behavior.
604
605    Returns:
606      A `tf.data.Dataset` representing `numpy_input`.
607    """
608    return self.extended.experimental_make_numpy_dataset(
609        numpy_input, session=None)
610
611  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
612  def experimental_run(self, fn, input_iterator=None):
613    """DEPRECATED TF 1.x ONLY."""
614    with self.scope():
615      args = (input_iterator.get_next(),) if input_iterator is not None else ()
616    return self.experimental_run_v2(fn, args=args)
617
618  def experimental_distribute_dataset(self, dataset):
619    """Distributes a tf.data.Dataset instance provided via `dataset`.
620
621    The returned distributed dataset can be iterated over similarly to how
622    regular datasets can.
623    NOTE: Currently, the user cannot add any more transformations to a
624    distributed dataset.
625
626    The following is an example:
627
628    ```python
629    strategy = tf.distribute.MirroredStrategy()
630
631    # Create a dataset
632    dataset = dataset_ops.Dataset.TFRecordDataset([
633      "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])
634
635    # Distribute that dataset
636    dist_dataset = strategy.experimental_distribute_dataset(dataset)
637
638    # Iterate over the distributed dataset
639    for x in dist_dataset:
640      # process dataset elements
641      strategy.experimental_run_v2(train_step, args=(x,))
642    ```
643
644    We will assume that the input dataset is batched by the
645    global batch size. With this assumption, we will make a best effort to
646    divide each batch across all the replicas (one or more workers).
647
648    In a multi-worker setting, we will first attempt to distribute the dataset
649    by attempting to detect whether the dataset is being created out of
650    ReaderDatasets (e.g. TFRecordDataset, TextLineDataset, etc.) and if so,
651    attempting to shard the input files. Note that there has to be at least one
652    input file per worker. If you have fewer input files than workers, we
653    suggest that you disable distributing your dataset using the method
654    below.
655
656    If that attempt is unsuccessful (e.g. the dataset is created from a
657    Dataset.range), we will shard the dataset evenly by appending a
658    `.shard` operation to the end of the processing pipeline. This will cause
659    the entire preprocessing pipeline for all the data to be run on every
660    worker, and each worker will do redundant work. We will print a warning
661    if this method of sharding is selected.
662
663    You can disable dataset sharding across workers using the
664    `auto_shard_policy` option in `tf.data.experimental.DistributeOptions`.
665
666    Within each worker, we will also split the data among all the worker
667    devices (if more than one is present), and this will happen even if
668    multi-worker sharding is disabled using the method above.
669
670    If the above batch splitting and dataset sharding logic is undesirable,
671    please use `experimental_distribute_datasets_from_function` instead, which
672    does not do any automatic splitting or sharding.
673
674    You can also use the `element_spec` property of the distributed dataset
675    returned by this API to query the `tf.TypeSpec` of the elements returned
676    by the iterator. This can be used to set the `input_signature` property
677    of a `tf.function`.
678
679    ```python
680    strategy = tf.distribute.MirroredStrategy()
681
682    # Create a dataset
683    dataset = dataset_ops.Dataset.TFRecordDataset([
684      "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])
685
686    # Distribute that dataset
687    dist_dataset = strategy.experimental_distribute_dataset(dataset)
688
689    @tf.function(input_signature=[dist_dataset.element_spec])
690    def train_step(inputs):
691      # train model with inputs
692      return
693
694    # Iterate over the distributed dataset
695    for x in dist_dataset:
696      # process dataset elements
697      strategy.experimental_run_v2(train_step, args=(x,))
698    ```
699
700    Args:
701      dataset: `tf.data.Dataset` that will be sharded across all replicas using
702        the rules stated above.
703
704    Returns:
705      A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
706      it produces "per-replica" values.
707    """
708    return self._extended._experimental_distribute_dataset(dataset)  # pylint: disable=protected-access
709
710  def experimental_distribute_datasets_from_function(self, dataset_fn):
711    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
712
713    `dataset_fn` will be called once for each worker in the strategy. Each
714    replica on that worker will dequeue one batch of inputs from the local
715    `Dataset` (i.e. if a worker has two replicas, two batches will be dequeued
716    from the `Dataset` every step).
717
718    This method can be used for several purposes. For example, where
719    `experimental_distribute_dataset` is unable to shard the input files, this
720    method might be used to manually shard the dataset (avoiding the slow
721    fallback behavior in `experimental_distribute_dataset`). In cases where the
722    dataset is infinite, this sharding can be done by creating dataset replicas
723    that differ only in their random seed.
724    `experimental_distribute_dataset` may also sometimes fail to split the
725    batch across replicas on a worker. In that case, this method can be used
726    where that limitation does not exist.
727
728    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
729    information about batching and input replication can be accessed:
730
731    ```
732    def dataset_fn(input_context):
733      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
734      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
735      return d.shard(
736          input_context.num_input_pipelines, input_context.input_pipeline_id)
737
738    inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
739
740    for batch in inputs:
741      replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,))
742    ```
743
744    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
745    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
746    the global batch size.  This may be computed using
747    `input_context.get_per_replica_batch_size`.
748
749    To query the `tf.TypeSpec` of the elements in the distributed dataset
750    returned by this API, you need to use the `element_spec` property of the
751    distributed iterator. This `tf.TypeSpec` can be used to set the
752    `input_signature` property of a `tf.function`.
753
754    ```python
755    # If you want to specify `input_signature` for a `tf.function` you must
756    # first create the iterator.
757    iterator = iter(inputs)
758
759    @tf.function(input_signature=[iterator.element_spec])
760    def replica_fn_with_signature(inputs):
761      # train the model with inputs
762      return
763
764    for _ in range(steps):
765      strategy.experimental_run_v2(replica_fn_with_signature,
766          args=(next(iterator),))
767    ```
768
769    Args:
770      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
771        returning a `tf.data.Dataset`.
772
773    Returns:
774      A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
775      it produces "per-replica" values.
776    """
777    return self._extended._experimental_distribute_datasets_from_function(  # pylint: disable=protected-access
778        dataset_fn)
779
780  def experimental_run_v2(self, fn, args=(), kwargs=None):
781    """Run `fn` on each replica, with the given arguments.
782
783    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
784    "per-replica" values, such as those produced by a "distributed `Dataset`",
785    when `fn` is executed on a particular replica, it will be executed with the
786    component of those "per-replica" values that correspond to that replica.
787
788    `fn` may call `tf.distribute.get_replica_context()` to access members such
789    as `all_reduce`.
790
791    All arguments in `args` or `kwargs` should either be nests of tensors or
792    per-replica objects containing tensors or composite tensors.
793
794    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
795    whether eager execution is enabled, `fn` may be called one or more times
796    (once for each replica).
797
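    For example (a rough sketch, assuming `dist_dataset` was created with
    `strategy.experimental_distribute_dataset` and that `train_step` accepts a
    single batch):

    ```python
    @tf.function
    def train_step(batch):
      return tf.reduce_sum(batch)

    for batch in dist_dataset:
      per_replica_results = strategy.experimental_run_v2(train_step,
                                                         args=(batch,))
    ```
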
798    Args:
799      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
800      args: (Optional) Positional arguments to `fn`.
801      kwargs: (Optional) Keyword arguments to `fn`.
802
803    Returns:
804      Merged return value of `fn` across replicas. The structure of the return
805      value is the same as the return value from `fn`. Each element in the
806      structure can either be "per-replica" `Tensor` objects or `Tensor`s
807      (for example, if running on a single replica).
808    """
809    if not isinstance(args, (list, tuple)):
810      raise ValueError(
811          "positional args must be a list or tuple, got {}".format(type(args)))
812
813    with self.scope():
814      # tf.distribute supports Eager functions, so AutoGraph should not be
815      # applied when the caller is also in Eager mode.
816      fn = autograph.tf_convert(
817          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
818      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
819
820  def reduce(self, reduce_op, value, axis):
821    """Reduce `value` across replicas.
822
823    Given a per-replica value returned by `experimental_run_v2`, say a
824    per-example loss, the batch will be divided across all the replicas.  This
825    function allows you to aggregate across replicas and optionally also across
826    batch elements.  For example, if you have a global batch size of 8 and 2
827    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
828    `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just
829    aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful
830    when each replica is computing a scalar or some other value that doesn't
831    have a "batch" dimension (like a gradient). More often you will want to
832    aggregate across the global batch, which you can get by specifying the batch
833    dimension as the `axis`, typically `axis=0`. In this case it would return a
834    scalar `0+1+2+3+4+5+6+7`.
835
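    For example (a small sketch; it assumes `iterator` was created with
    `iter(strategy.experimental_distribute_dataset(dataset))` and that there
    are 2 replicas with a global batch of 8):

    ```python
    @tf.function
    def step(batch):
      return batch  # per-example values, shape [4] on each replica

    per_replica = strategy.experimental_run_v2(step, args=(next(iterator),))
    # Aggregate across replicas only: shape [4].
    summed = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
    # Aggregate across replicas and the batch dimension: a scalar.
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=0)
    ```
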
836    If there is a last partial batch, you will need to specify an axis so
837    that the resulting shape is consistent across replicas. So if the last
838    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
839    would get a shape mismatch unless you specify `axis=0`. If you specify
840    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
841    denominator of 6. Contrast this with using `reduce_mean` to get a
842    scalar value on each replica and then using this function to average those
843    means, which would weigh some values `1/8` and others `1/4`.
844
845    Args:
846      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
847        be combined.
848      value: A "per replica" value, e.g. returned by `experimental_run_v2` to
849        be combined into a single tensor.
850      axis: Specifies the dimension to reduce along within each
851        replica's tensor. Should typically be set to the batch dimension, or
852        `None` to only reduce across replicas (e.g. if the tensor has no batch
853        dimension).
854
855    Returns:
856      A `Tensor`.
857    """
858    # TODO(josh11b): support `value` being a nest.
859    _require_cross_replica_or_default_context_extended(self._extended)
860    if isinstance(reduce_op, six.string_types):
861      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
862    if axis is None:
863      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
864    if reduce_op == reduce_util.ReduceOp.SUM:
865      value = self.experimental_run_v2(
866          lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,))
867      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
868    if reduce_op != reduce_util.ReduceOp.MEAN:
869      raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
870                      "not: %r" % reduce_op)
871    # TODO(josh11b): Support list/tuple and tensor axis values.
872    if not isinstance(axis, six.integer_types):
873      raise TypeError("Expected `axis` to be an integer not: %r" % axis)
874
875    def mean_reduce_helper(v, axis=axis):
876      """Computes the numerator and denominator on each replica."""
877      numer = math_ops.reduce_sum(v, axis=axis)
878      if v.shape.rank is not None:
879        # Note(joshl): We support axis < 0 to be consistent with the
880        # tf.math.reduce_* operations.
881        if axis < 0:
882          if axis + v.shape.rank < 0:
883            raise ValueError(
884                "`axis` = %r out of range for `value` with rank %d" %
885                (axis, v.shape.rank))
886          axis += v.shape.rank
887        elif axis >= v.shape.rank:
888          raise ValueError(
889              "`axis` = %r out of range for `value` with rank %d" %
890              (axis, v.shape.rank))
891        # TF v2 returns `None` for unknown dimensions and an integer for
892        # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
893        # or tensor_shape.Dimension(integer). `dimension_value` hides this
894        # difference, always returning `None` or an integer.
895        dim = tensor_shape.dimension_value(v.shape[axis])
896        if dim is not None:
897          # By returning a python value in the static shape case, we can
898          # maybe get a fast path for reducing the denominator.
899          return numer, array_ops.constant(dim, dtype=dtypes.int64)
900      elif axis < 0:
901        axis = axis + array_ops.rank(v)
902      if v.shape.rank == 1:
903        # TODO(b/139422050): Currently tf.shape is not supported in TPU dynamic
904        # padder, use tf.size instead to workaround if the rank is 1.
905        denom = array_ops.size(v, out_type=dtypes.int64)
906      else:
907        denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis]
908      # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
909      # reduce is complete?
910      return numer, denom
911
912    numer, denom = self.experimental_run_v2(mean_reduce_helper, args=(value,))
913    # TODO(josh11b): Should batch reduce here instead of doing two.
914    numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer)  # pylint: disable=protected-access
915    denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom)  # pylint: disable=protected-access
916    denom = math_ops.cast(denom, numer.dtype)
917    return math_ops.truediv(numer, denom)
918
919  @doc_controls.do_not_doc_inheritable  # DEPRECATED
920  def unwrap(self, value):
921    """Returns the list of all local per-replica values contained in `value`.
922
923    DEPRECATED: Please use `experimental_local_results` instead.
924
925    Note: This only returns values on the worker initiated by this client.
926    When using a `tf.distribute.Strategy` like
927    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
928    will be its own client, and this function will only return values
929    computed on that worker.
930
931    Args:
932      value: A value returned by `experimental_run()`,
933        `extended.call_for_each_replica()`, or a variable created in `scope`.
934
935    Returns:
936      A tuple of values contained in `value`. If `value` represents a single
937      value, this returns `(value,).`
938    """
939    return self._extended._local_results(value)  # pylint: disable=protected-access
940
941  def experimental_local_results(self, value):
942    """Returns the list of all local per-replica values contained in `value`.
943
944    Note: This only returns values on the worker initiated by this client.
945    When using a `tf.distribute.Strategy` like
946    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
947    will be its own client, and this function will only return values
948    computed on that worker.
949
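    For example (a quick sketch): if `per_replica` was returned by
    `experimental_run_v2` on a strategy with two local replicas, this returns a
    tuple containing the two per-replica components.

    ```python
    per_replica = strategy.experimental_run_v2(replica_fn, args=(x,))
    local_values = strategy.experimental_local_results(per_replica)  # a tuple
    ```
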
950    Args:
951      value: A value returned by `experimental_run()`, `experimental_run_v2()`,
952        `extended.call_for_each_replica()`, or a variable created in `scope`.
953
954    Returns:
955      A tuple of values contained in `value`. If `value` represents a single
956      value, this returns `(value,).`
957    """
958    return self._extended._local_results(value)  # pylint: disable=protected-access
959
960  @doc_controls.do_not_doc_inheritable  # DEPRECATED: TF v1.x only
961  def group(self, value, name=None):
962    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
963    return self._extended._group(value, name)  # pylint: disable=protected-access
964
965  @property
966  def num_replicas_in_sync(self):
967    """Returns number of replicas over which gradients are aggregated."""
968    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access
969
970  @doc_controls.do_not_doc_inheritable  # DEPRECATED: see doc string
971  def configure(self,
972                session_config=None,
973                cluster_spec=None,
974                task_type=None,
975                task_id=None):
976    # pylint: disable=g-doc-return-or-yield,g-doc-args
977    """DEPRECATED: use `update_config_proto` instead.
978
979    Configures the strategy class.
980
981    DEPRECATED: This method's functionality has been split into the strategy
982    constructor and `update_config_proto`. In the future, we will allow passing
983    cluster and config_proto to the constructor to configure the strategy. And
984    `update_config_proto` can be used to update the config_proto based on the
985    specific strategy.
986    """
987    return self._extended._configure(  # pylint: disable=protected-access
988        session_config, cluster_spec, task_type, task_id)
989
990  @doc_controls.do_not_generate_docs  # DEPRECATED
991  def update_config_proto(self, config_proto):
992    """DEPRECATED TF 1.x ONLY."""
993    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
994
995  def __deepcopy__(self, memo):
996    # First do a regular deepcopy of `self`.
997    cls = self.__class__
998    result = cls.__new__(cls)
999    memo[id(self)] = result
1000    for k, v in self.__dict__.items():
1001      setattr(result, k, copy.deepcopy(v, memo))
1002    # One little fix-up: we want `result._extended` to reference `result`
1003    # instead of `self`.
1004    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
1005    return result
1006
1007  def __copy__(self):
1008    raise RuntimeError("Must only deepcopy DistributionStrategy.")
1009
1010
1011# TF v1.x version has additional deprecated APIs
1012@tf_export(v1=["distribute.Strategy"])
1013class StrategyV1(Strategy):
1014  """A list of devices with a state & compute distribution policy.
1015
1016  See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
1017  for overview and examples.
1018
1019  Note: Not all `tf.distribute.Strategy` implementations currently support
1020  TensorFlow's partitioned variables (where a single variable is split across
1021  multiple devices) at this time.
1022  """
1023
1024  def make_dataset_iterator(self, dataset):
1025    """Makes an iterator for input provided via `dataset`.
1026
1027    DEPRECATED: This method is not available in TF 2.x.
1028
1029    Data from the given dataset will be distributed evenly across all the
1030    compute replicas. We will assume that the input dataset is batched by the
1031    global batch size. With this assumption, we will make a best effort to
1032    divide each batch across all the replicas (one or more workers).
1033    If this effort fails, an error will be thrown, and the user should instead
1034    use `make_input_fn_iterator` which provides more control to the user, and
1035    does not try to divide a batch across replicas.
1036
1037    The user could also use `make_input_fn_iterator` if they want to
1038    customize which input is fed to which replica/worker etc.
1039
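    For example (a brief TF 1.x-style sketch; `global_batch_size` and `sess`
    are assumed to be defined elsewhere):

    ```python
    dataset = tf.data.Dataset.range(100).batch(global_batch_size)
    iterator = strategy.make_dataset_iterator(dataset)
    sess.run(iterator.initialize())
    ```
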
1040    Args:
1041      dataset: `tf.data.Dataset` that will be distributed evenly across all
1042        replicas.
1043
1044    Returns:
1045      A `tf.distribute.InputIterator` which returns inputs for each step of the
1046      computation. The user should call `initialize` on the returned iterator.
1047    """
1048    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
1049
1050  def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
1051                             input_fn,
1052                             replication_mode=InputReplicationMode.PER_WORKER):
1053    """Returns an iterator split across replicas created from an input function.
1054
1055    DEPRECATED: This method is not available in TF 2.x.
1056
1057    The `input_fn` should take a `tf.distribute.InputContext` object where
1058    information about batching and input sharding can be accessed:
1059
1060    ```
1061    def input_fn(input_context):
1062      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
1063      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
1064      return d.shard(input_context.num_input_pipelines,
1065                     input_context.input_pipeline_id)
1066    with strategy.scope():
1067      iterator = strategy.make_input_fn_iterator(input_fn)
1068      replica_results = strategy.experimental_run(replica_fn, iterator)
1069    ```
1070
1071    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
1072    batch size, which may be computed using
1073    `input_context.get_per_replica_batch_size`.
1074
1075    Args:
1076      input_fn: A function taking a `tf.distribute.InputContext` object and
1077        returning a `tf.data.Dataset`.
1078      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
1079        Only `PER_WORKER` is supported currently, which means there will be
1080        a single call to `input_fn` per worker. Replicas will dequeue from the
1081        local `tf.data.Dataset` on their worker.
1082
1083    Returns:
1084      An iterator object that should first be `.initialize()`-ed. It may then
1085      either be passed to `strategy.experimental_run()` or you can
1086      `iterator.get_next()` to get the next value to pass to
1087      `strategy.extended.call_for_each_replica()`.
1088    """
1089    return super(StrategyV1, self).make_input_fn_iterator(
1090        input_fn, replication_mode)
1091
1092  def experimental_make_numpy_dataset(self, numpy_input, session=None):
1093    """Makes a tf.data.Dataset for input provided via a numpy array.
1094
1095    This avoids adding `numpy_input` as a large constant in the graph,
1096    and copies the data to the machine or machines that will be processing
1097    the input.
1098
1099    Note that you will likely need to use
1100    tf.distribute.Strategy.experimental_distribute_dataset
1101    with the returned dataset to further distribute it with the strategy.
1102
1103    Example:
1104    ```
1105    numpy_input = np.ones([10], dtype=np.float32)
1106    dataset = strategy.experimental_make_numpy_dataset(numpy_input)
1107    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1108    ```
1109
1110    Args:
1111      numpy_input: A nest of NumPy input arrays that will be converted into a
1112        dataset. Note that lists of NumPy arrays are stacked, as that is normal
1113        `tf.data.Dataset` behavior.
1114      session: (TensorFlow v1.x graph execution only) A session used for
1115        initialization.
1116
1117    Returns:
1118      A `tf.data.Dataset` representing `numpy_input`.
1119    """
1120    return self.extended.experimental_make_numpy_dataset(
1121        numpy_input, session=session)
1122
1123  def experimental_run(self, fn, input_iterator=None):  # pylint: disable=useless-super-delegation
1124    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
1125
1126    DEPRECATED: This method is not available in TF 2.x. Please switch
1127    to using `experimental_run_v2` instead.
1128
1129    When eager execution is enabled, executes ops specified by `fn` on each
1130    replica. Otherwise, builds a graph to execute the ops on each replica.
1131
1132    Each replica will take a single, different input from the inputs provided by
1133    one `get_next` call on the input iterator.
1134
1135    `fn` may call `tf.distribute.get_replica_context()` to access members such
1136    as `replica_id_in_sync_group`.
1137
1138    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
1139    used, and whether eager execution is enabled, `fn` may be called one or more
1140    times (once for each replica).
1141
1142    Args:
1143      fn: The function to run. The inputs to the function must match the outputs
1144        of `input_iterator.get_next()`. The output must be a `tf.nest` of
1145        `Tensor`s.
1146      input_iterator: (Optional) input iterator from which the inputs are taken.
1147
1148    Returns:
1149      Merged return value of `fn` across replicas. The structure of the return
1150      value is the same as the return value from `fn`. Each element in the
1151      structure can either be `PerReplica` (if the values are unsynchronized),
1152      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
1153      single replica).
1154    """
1155    return super(StrategyV1, self).experimental_run(
1156        fn, input_iterator)
1157
1158  def reduce(self, reduce_op, value, axis=None):
1159    return super(StrategyV1, self).reduce(reduce_op, value, axis)
1160
1161  reduce.__doc__ = Strategy.reduce.__doc__
1162
1163  def update_config_proto(self, config_proto):
1164    """Returns a copy of `config_proto` modified for use with this strategy.
1165
1166    DEPRECATED: This method is not available in TF 2.x.
1167
1168    The updated config contains whatever is needed to run the strategy, e.g.
1169    configuration to run collective ops, or device filters to improve
1170    distributed training performance.
1171
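    For example (a small sketch for TF 1.x graph execution):

    ```python
    config = tf.compat.v1.ConfigProto()
    updated_config = strategy.update_config_proto(config)
    with tf.compat.v1.Session(config=updated_config) as sess:
      ...
    ```
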
1172    Args:
1173      config_proto: a `tf.ConfigProto` object.
1174
1175    Returns:
1176      The updated copy of the `config_proto`.
1177    """
1178    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
1179
1180
1181# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
1182# instead descend from StrategyExtendedV1.
1183@tf_export("distribute.StrategyExtended", v1=[])
1184class StrategyExtendedV2(object):
1185  """Additional APIs for algorithms that need to be distribution-aware.
1186
1187  Note: For most usage of `tf.distribute.Strategy`, there should be no need to
1188  call these methods, since TensorFlow libraries (such as optimizers) already
1189  call these methods when needed on your behalf.
1190
1191  Lower-level concepts:
1192
1193  * Wrapped values: In order to represent values parallel across devices
1194    (either replicas or the devices associated with a particular value), we
1195    wrap them in a "PerReplica" or "Mirrored" object that contains a map
1196    from replica id to values. "PerReplica" is used when the value may be
1197    different across replicas, and "Mirrored" when the values are the same.
1198  * Unwrapping and merging: Consider calling a function `fn` on multiple
1199    replicas, like `experimental_run_v2(fn, args=[w])` with an
1200    argument `w` that is a wrapped value. This means `w` will have a map taking
1201    replica id `0` to `w0`, replica id `1` to `w1`, etc.
1202    `experimental_run_v2()` unwraps `w` before calling `fn`, so
1203    it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc.  It then merges the return
1204    values from `fn()`, which can possibly result in wrapped values. For
1205    example, let's say `fn()` returns a tuple with three components: `(x, a,
1206    v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component
1207    is the same object `x` from every replica, then the first component of the
1208    merged result will also be `x`. If the second component is different (`a`,
1209    `b`, ...)  from each replica, then the merged value will have a wrapped map
1210    from replica device to the different values. If the third component is the
1211    members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.),
1212    then the merged result will be that mirrored variable (`v`).
1213  * Worker devices vs. parameter devices: Most replica computations will
1214    happen on worker devices. Since we don't yet support model
1215    parallelism, there will be one worker device per replica. When using
1216    parameter servers or central storage, the set of devices holding
1217    variables may be different; otherwise the parameter devices will
1218    match the worker devices.
1219
1220  *Replica context vs. Cross-replica context*
1221
1222  A _replica context_ applies when we are in some function that is being called
1223  once for each replica.  Otherwise we are in cross-replica context, which is
1224  useful for calling `tf.distribute.Strategy` methods which operate across the
1225  replicas (like `reduce_to()`). By default you start in a replica context
1226  (the "default single replica context") and then some methods can switch you
1227  back and forth. There is a third mode you can be in called _update context_
1228  used when updating variables.
1229
1230  * `tf.distribute.Strategy.scope`: enters cross-replica context when
1231    no other strategy is in scope.
1232  * `tf.distribute.Strategy.experimental_run_v2`: calls a function in
1233    replica context.
1234  * `tf.distribute.ReplicaContext.merge_call`: transitions from replica
1235    context to cross-replica context.
1236  * `tf.distribute.StrategyExtended.update`: calls a function in an update
1237    context from a cross-replica context.
1238
1239  In a replica context, you may freely read the values of variables, but
1240  you may only update their value if they specify a way to aggregate the
1241  update using the `aggregation` parameter in the variable's constructor.
1242  In a cross-replica context, you may read or write variables (writes may
1243  need to be broadcast to all copies of the variable if it is mirrored).
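
  As an illustrative sketch (assuming a `tf.distribute.MirroredStrategy`
  named `strategy` and a user-defined `step_fn`), you can check which
  context you are in:

  ```python
  def step_fn(x):
    # Inside `experimental_run_v2` we are in a replica context.
    assert tf.distribute.get_replica_context() is not None
    return x + 1.

  with strategy.scope():
    # Inside `scope()` (and outside `experimental_run_v2`) we are in a
    # cross-replica context.
    assert tf.distribute.in_cross_replica_context()
    strategy.experimental_run_v2(step_fn, args=(tf.constant(1.),))
  ```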
1244
1245  *Sync on read variables*
1246
1247  In some cases, such as a metric, we want to accumulate many updates on
1248  each replica independently and only aggregate when reading. This can be a big
1249  performance win when the value is read only rarely (maybe the value is only
1250  read at the end of an epoch or when checkpointing).  These are variables
1251  created by passing `synchronization=ON_READ` to the variable's constructor
1252  (and some value for `aggregation`).
1253
1254  The strategy may choose to put the variable on multiple devices, like mirrored
1255  variables, but unlike mirrored variables we don't synchronize the updates to
1256  them to make sure they have the same value. Instead, the synchronization is
1257  performed when reading in cross-replica context.  In a replica context, reads
1258  and writes are performed on the local copy (we allow reads so you can write
1259  code like `v = 0.9*v + 0.1*update`).  We don't allow operations like
1260  `v.assign_add` in a cross-replica context for sync on read variables; right
1261  now we don't have a use case for such updates and depending on the aggregation
1262  mode such updates may not be sensible.
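
  For illustration, a minimal sketch of a metric-style sync on read variable
  (`strategy` and `per_replica_losses` are assumed to exist):

  ```python
  with strategy.scope():
    total_loss = tf.Variable(
        0.,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

  def step_fn(per_example_loss):
    # In a replica context each replica updates its own local copy.
    total_loss.assign_add(tf.reduce_sum(per_example_loss))

  strategy.experimental_run_v2(step_fn, args=(per_replica_losses,))
  # Reading in cross-replica context aggregates (here: sums) the copies.
  tf.print(total_loss.read_value())
  ```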
1263
1264  *Locality*
1265
1266  Depending on how a value is produced, it will have a type that will determine
1267  how it may be used.
1268
1269  "Per-replica" values exist on the worker devices, with a different value for
1270  each replica. They are produced by iterating through a "distributed `Dataset`"
1271  returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
1272  `tf.distribute.Strategy.experimental_distribute_datasets_from_function`.  They
1273  are also the typical result returned by
1274  `tf.distribute.Strategy.experimental_run_v2`. You typically can't use a
1275  per-replica value directly in a cross-replica context, without first resolving
1276  how to aggregate the values across replicas, for instance by using
1277  `tf.distribute.Strategy.reduce`.
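
  For example (a sketch, assuming `strategy`, a `step_fn` returning a
  per-replica loss, and some `inputs`):

  ```python
  per_replica_loss = strategy.experimental_run_v2(step_fn, args=(inputs,))
  # Resolve the per-replica value into a single tensor before using it in
  # cross-replica context.
  mean_loss = strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
  ```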
1278
1279  "Mirrored" values are like per-replica values, except we know that the value
1280  on all replicas are the same. We can safely read a mirrored value in a
1281  cross-replica context by using the value on any replica. You can convert
1282  a per-replica value into a mirrored value by using
1283  `tf.distribute.ReplicaContext.all_reduce`.
1284
1285  Values can also have the same locality as a variable, which is a mirrored
1286  value but residing on the same devices as the variable (as opposed to the
1287  compute devices). Such values may be passed to a call to
1288  `tf.distribute.StrategyExtended.update` to update the value of a variable.
1289  You may use `tf.distribute.StrategyExtended.colocate_vars_with` to give a
1290  variable the same locality as another variable. This is useful, for example,
1291  for "slot" variables used by an optimizer for keeping track of statistics
1292  used to update a primary/model variable. You may convert a per-replica
1293  value to a variable's locality by using
1294  `tf.distribute.StrategyExtended.reduce_to` or
1295  `tf.distribute.StrategyExtended.batch_reduce_to`.
1296
1297  In addition to slot variables which should be colocated with their primary
1298  variables, optimizers also define non-slot variables. These can be things like
1299  "number of step updates performed" or "beta1^t" and "beta2^t".  Each strategy
1300  has some policy for which devices those variables should be copied to, called
1301  the "non-slot devices" (some subset of the parameter devices). We require that
1302  all non-slot variables are allocated on the same device, or mirrored across
1303  the same set of devices. You can use
1304  `tf.distribute.StrategyExtended.non_slot_devices` to pick a consistent set of
1305  devices to pass to both `tf.distribute.StrategyExtended.colocate_vars_with`
1306  and `tf.distribute.StrategyExtended.update_non_slot`.
1307
1308  *How to update a variable*
1309
1310  The standard pattern for updating variables (sketched in code below) is to:
1311
1312  1. In your function passed to `tf.distribute.Strategy.experimental_run_v2`,
1313     compute a list of (update, variable) pairs. For example, the update might
1314     be the gradient of the loss with respect to the variable.
1315  2. Switch to cross-replica mode by calling
1316     `tf.distribute.get_replica_context().merge_call()` with the updates and
1317     variables as arguments.
1318  3. Call
1319     `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)`
1320     (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
1321     (for a list of variables) to sum the updates and broadcast the
1322     result to the variable's devices.
1323  4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
1324     its value.
1325
1326  Steps 2 through 4 are done automatically by class
1327  `tf.keras.optimizers.Optimizer` if you call its
1328  `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.
1329  They are also done automatically if you call an `assign*` method on a (non
1330  sync-on-read) variable that was constructed with an aggregation method (which
1331  is used to determine the reduction used in step 3).
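
  A condensed sketch of steps 1-4 above (assuming a `strategy`, a `model`,
  and a `loss_fn`; in practice `Optimizer.apply_gradients` does steps 2-4
  for you):

  ```python
  def step_fn(features, labels):
    # Step 1: compute (update, variable) pairs in replica context.
    with tf.GradientTape() as tape:
      loss = loss_fn(labels, model(features))
    grads = tape.gradient(loss, model.trainable_variables)
    grads_and_vars = list(zip(grads, model.trainable_variables))

    def merge_fn(strategy, grads_and_vars):
      # Step 3: sum the per-replica gradients onto each variable's devices.
      reduced = strategy.extended.batch_reduce_to(
          tf.distribute.ReduceOp.SUM, grads_and_vars)
      # Step 4: apply the reduced update to each variable.
      for (_, v), g in zip(grads_and_vars, reduced):
        strategy.extended.update(
            v, lambda var, grad: var.assign_sub(0.01 * grad), args=(g,))

    # Step 2: switch to cross-replica context.
    tf.distribute.get_replica_context().merge_call(
        merge_fn, args=(grads_and_vars,))

  # `features`/`labels` come from a (distributed) input pipeline.
  strategy.experimental_run_v2(step_fn, args=(features, labels))
  ```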
1332
1333  *Distribute-aware layers*
1334
1335  Layers are generally called in a replica context, except when defining a
1336  functional model. `tf.distribute.in_cross_replica_context` will let you
1337  determine which case you are in. If in a replica context,
1338  the `tf.distribute.get_replica_context` function will return a
1339  `tf.distribute.ReplicaContext` object. The `ReplicaContext` object has an
1340  `all_reduce` method for aggregating across all replicas. Alternatively, you
1341  can update variables following steps 2-4 above.
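
  A sketch of a distribution-aware layer method (the attribute names are
  illustrative, not part of any API):

  ```python
  def call(self, inputs):
    outputs = tf.matmul(inputs, self.kernel)
    if not tf.distribute.in_cross_replica_context():
      ctx = tf.distribute.get_replica_context()
      # Aggregate a statistic over the global batch, across all replicas.
      self.batch_mean = ctx.all_reduce(
          tf.distribute.ReduceOp.MEAN, tf.reduce_mean(outputs, axis=0))
    return outputs
  ```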
1342
1343  Note: For new `tf.distribute.Strategy` implementations, please put all logic
1344  in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
1345  the `tf.distribute.Strategy` subclass is for instantiating your subclass of
1346  `tf.distribute.StrategyExtended` in the `__init__` method.
1347  """
1348
1349  def __init__(self, container_strategy):
1350    self._container_strategy_weakref = weakref.ref(container_strategy)
1351    self._default_device = None
1352    # This property is used to determine if we should set drop_remainder=True
1353    # when creating Datasets from numpy array inputs.
1354    self._require_static_shapes = False
1355
1356  def _container_strategy(self):
1357    """Get the containing `tf.distribute.Strategy`.
1358
1359    This should not generally be needed except when creating a new
1360    `ReplicaContext` and to validate that the caller is in the correct
1361    `scope()`.
1362
1363    Returns:
1364      The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
1365    """
1366    container_strategy = self._container_strategy_weakref()
1367    assert container_strategy is not None
1368    return container_strategy
1369
1370  def _scope(self, strategy):
1371    """Implementation of tf.distribute.Strategy.scope()."""
1372
1373    def creator_with_resource_vars(next_creator, **kwargs):
1374      """Variable creator to use in `_CurrentDistributionContext`."""
1375      _require_strategy_scope_extended(self)
1376      kwargs["use_resource"] = True
1377      kwargs["distribute_strategy"] = strategy
1378
1379      # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
1380      # dereferencing a `Tensor` that is without a `name`.
1381      # TODO(b/138130844): Revisit the following check once
1382      # `CheckpointInitialValue` class is removed.
1383      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
1384        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
1385
1386      return self._create_variable(next_creator, **kwargs)
1387
1388    def distributed_getter(getter, *args, **kwargs):
1389      if not self._allow_variable_partition():
1390        if kwargs.pop("partitioner", None) is not None:
1391          tf_logging.log_first_n(
1392              tf_logging.WARN, "Partitioned variables are disabled when using "
1393              "current tf.distribute.Strategy.", 1)
1394      return getter(*args, **kwargs)
1395
1396    return _CurrentDistributionContext(
1397        strategy,
1398        variable_scope.variable_creator_scope(creator_with_resource_vars),
1399        variable_scope.variable_scope(
1400            variable_scope.get_variable_scope(),
1401            custom_getter=distributed_getter), self._default_device)
1402
1403  def _allow_variable_partition(self):
1404    return False
1405
1406  def _create_variable(self, next_creator, **kwargs):
1407    # Note: should support "colocate_with" argument.
1408    raise NotImplementedError("must be implemented in descendants")
1409
1410  def variable_created_in_scope(self, v):
1411    """Tests whether `v` was created while this strategy scope was active.
1412
1413    Variables created inside the strategy scope are "owned" by it:
1414
1415    ```python
1416    strategy = tf.distribute.MirroredStrategy()
1417    with strategy.scope():
1418      v = tf.Variable(1.)
1419    strategy.extended.variable_created_in_scope(v)
1420    True
1421    ```
1422
1423    Variables created outside the strategy are not owned by it:
1424
1425    ```python
1426    v = tf.Variable(1.)
1427    strategy.extended.variable_created_in_scope(v)
1428    False
1429    ```
1430
1431    Args:
1432      v: A `tf.Variable` instance.
1433
1434    Returns:
1435      True if `v` was created inside the scope, False if not.
1436    """
1437    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access
1438
1439  def colocate_vars_with(self, colocate_with_variable):
1440    """Scope that controls which devices variables will be created on.
1441
1442    No operations should be added to the graph inside this scope; it
1443    should only be used when creating variables (some implementations
1444    work by changing variable creation, others work by using a
1445    tf.compat.v1.colocate_with() scope).
1446
1447    This may only be used inside `self.scope()`.
1448
1449    Example usage:
1450
1451    ```
1452    with strategy.scope():
1453      var1 = tf.Variable(...)
1454      with strategy.extended.colocate_vars_with(var1):
1455        # var2 and var3 will be created on the same device(s) as var1
1456        var2 = tf.Variable(...)
1457        var3 = tf.Variable(...)
1458
1459      def fn(v1, v2, v3):
1460        # operates on v1 from var1, v2 from var2, and v3 from var3
1461
1462      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
1463      # too.
1464      strategy.extended.update(var1, fn, args=(var2, var3))
1465    ```
1466
1467    Args:
1468      colocate_with_variable: A variable created in this strategy's `scope()`.
1469        Variables created while in the returned context manager will be on the
1470        same set of devices as `colocate_with_variable`.
1471
1472    Returns:
1473      A context manager.
1474    """
1475
1476    def create_colocated_variable(next_creator, **kwargs):
1477      _require_strategy_scope_extended(self)
1478      kwargs["use_resource"] = True
1479      kwargs["colocate_with"] = colocate_with_variable
1480      return next_creator(**kwargs)
1481
1482    _require_strategy_scope_extended(self)
1483    self._validate_colocate_with_variable(colocate_with_variable)
1484    return variable_scope.variable_creator_scope(create_colocated_variable)
1485
1486  def _validate_colocate_with_variable(self, colocate_with_variable):
1487    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
1488    pass
1489
1490  def _make_dataset_iterator(self, dataset):
1491    raise NotImplementedError("must be implemented in descendants")
1492
1493  def _make_input_fn_iterator(self, input_fn, replication_mode):
1494    raise NotImplementedError("must be implemented in descendants")
1495
1496  def _experimental_distribute_dataset(self, dataset):
1497    raise NotImplementedError("must be implemented in descendants")
1498
1499  def _experimental_distribute_datasets_from_function(self, dataset_fn):
1500    raise NotImplementedError("must be implemented in descendants")
1501
1502  def _reduce(self, reduce_op, value):
1503    # Default implementation until we have an implementation for each strategy.
1504    return self._local_results(
1505        self._reduce_to(reduce_op, value,
1506                        device_util.current() or "/device:CPU:0"))[0]
1507
1508  def reduce_to(self, reduce_op, value, destinations):
1509    """Combine (via e.g. sum or mean) values across replicas.
1510
1511    Args:
1512      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
1513      value: A per-replica value to combine, with one component per replica.
1514      destinations: A mirrored variable, a per-replica tensor, or a device
1515        string. The return value will be copied to all destination devices (or
1516        all the devices where the `destinations` value resides). To perform an
1517        all-reduction, pass `value` to `destinations`.
1518
1519    Returns:
1520      A tensor or value mirrored to `destinations`.
1521    """
1522    # TODO(josh11b): More docstring
1523    _require_cross_replica_or_default_context_extended(self)
1524    assert not isinstance(destinations, (list, tuple))
1525    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
1526    if isinstance(reduce_op, six.string_types):
1527      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1528    assert (reduce_op == reduce_util.ReduceOp.SUM or
1529            reduce_op == reduce_util.ReduceOp.MEAN)
1530    return self._reduce_to(reduce_op, value, destinations)
1531
1532  def _reduce_to(self, reduce_op, value, destinations):
1533    raise NotImplementedError("must be implemented in descendants")
1534
1535  def batch_reduce_to(self, reduce_op, value_destination_pairs):
1536    """Combine multiple `reduce_to` calls into one for faster execution.
1537
1538    Args:
1539      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
1540      value_destination_pairs: A sequence of (value, destinations)
1541        pairs. See `reduce_to()` for a description.
1542
1543    Returns:
1544      A list of mirrored values, one per pair in `value_destination_pairs`.
1545    """
1546    # TODO(josh11b): More docstring
1547    _require_cross_replica_or_default_context_extended(self)
1548    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
1549    if isinstance(reduce_op, six.string_types):
1550      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1551    return self._batch_reduce_to(reduce_op, value_destination_pairs)
1552
1553  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
1554    return [
1555        self.reduce_to(reduce_op, t, destinations=v)
1556        for t, v in value_destination_pairs
1557    ]
1558
1559  def update(self, var, fn, args=(), kwargs=None, group=True):
1560    """Run `fn` to update `var` using inputs mirrored to the same devices.
1561
1562    If `var` is mirrored across multiple devices, then this implements
1563    logic like:
1564
1565    ```
1566    results = {}
1567    for device, v in var:
1568      with tf.device(device):
1569        # args and kwargs will be unwrapped if they are mirrored.
1570        results[device] = fn(v, *args, **kwargs)
1571    return merged(results)
1572    ```
1573
1574    Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
1575
1576    Neither `args` nor `kwargs` may contain per-replica values.
1577    If they contain mirrored values, they will be unwrapped before
1578    calling `fn`.
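
    For example (a sketch, in cross-replica context, with a mirrored
    variable `var` and a mirrored value `delta`):

    ```python
    strategy.extended.update(
        var, lambda v, d: v.assign_add(d), args=(delta,))
    ```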
1579
1580    Args:
1581      var: Variable, possibly mirrored to multiple devices, to operate on.
1582      fn: Function to call. Should take the variable as the first argument.
1583      args: Tuple or list. Additional positional arguments to pass to `fn()`.
1584      kwargs: Dict with keyword arguments to pass to `fn()`.
1585      group: Boolean. Defaults to True. If False, the return value will be
1586        unwrapped.
1587
1588    Returns:
1589      By default, the merged return value of `fn` across all replicas.  The
1590      merged result has dependencies to make sure that if it is evaluated at
1591      all, the side effects (updates) will happen on every replica. If instead
1592      "group=False" is specified, this function will return a nest of lists
1593      where each list has an element per replica, and the caller is responsible
1594      for ensuring all elements are executed.
1595    """
1596    _require_cross_replica_or_default_context_extended(self)
1597    if kwargs is None:
1598      kwargs = {}
1599    fn = autograph.tf_convert(
1600        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1601    with self._container_strategy().scope():
1602      return self._update(var, fn, args, kwargs, group)
1603
1604  def _update(self, var, fn, args, kwargs, group):
1605    raise NotImplementedError("must be implemented in descendants")
1606
1607  def update_non_slot(
1608      self, colocate_with, fn, args=(), kwargs=None, group=True):
1609    """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
1610
1611    Args:
1612      colocate_with: The return value of `non_slot_devices()`.
1613      fn: Function to execute.
1614      args: Tuple or list. Positional arguments to pass to `fn()`.
1615      kwargs: Dict with keyword arguments to pass to `fn()`.
1616      group: Boolean. Defaults to True. If False, the return value will be
1617        unwrapped.
1618
1619    Returns:
1620      Return value of `fn`, possibly merged across devices.
1621    """
1622    _require_cross_replica_or_default_context_extended(self)
1623    if kwargs is None:
1624      kwargs = {}
1625    fn = autograph.tf_convert(
1626        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1627    with self._container_strategy().scope():
1628      return self._update_non_slot(colocate_with, fn, args, kwargs, group)
1629
1630  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
1631    raise NotImplementedError("must be implemented in descendants")
1632
1633  def _local_results(self, distributed_value):
1634    raise NotImplementedError("must be implemented in descendants")
1635
1636  def value_container(self, value):
1637    """Returns the container that this per-replica `value` belongs to.
1638
1639    Args:
1640      value: A value returned by `experimental_run_v2()` or a variable
1641        created in `scope()`.
1642
1643    Returns:
1644      A container that `value` belongs to.
1645      If value does not belong to any container (including the case of
1646      container having been destroyed), returns the value itself.
1647      `value in experimental_local_results(value_container(value))` will
1648      always be true.
1649    """
1650    raise NotImplementedError("must be implemented in descendants")
1651
1652  def _group(self, value, name=None):
1653    """Implementation of `group`."""
1654    value = nest.flatten(self._local_results(value))
1655
1656    if len(value) != 1 or name is not None:
1657      return control_flow_ops.group(value, name=name)
1658    # Special handling for the common case of one op.
1659    v, = value
1660    if hasattr(v, "op"):
1661      v = v.op
1662    return v
1663
1664  @property
1665  def experimental_require_static_shapes(self):
1666    """Returns `True` if static shape is required; `False` otherwise."""
1667    return self._require_static_shapes
1668
1669  @property
1670  def _num_replicas_in_sync(self):
1671    """Returns number of replicas over which gradients are aggregated."""
1672    raise NotImplementedError("must be implemented in descendants")
1673
1674  @property
1675  def worker_devices(self):
1676    """Returns the tuple of all devices used to for compute replica execution.
1677    """
1678    # TODO(josh11b): More docstring
1679    raise NotImplementedError("must be implemented in descendants")
1680
1681  @property
1682  def parameter_devices(self):
1683    """Returns the tuple of all devices used to place variables."""
1684    # TODO(josh11b): More docstring
1685    raise NotImplementedError("must be implemented in descendants")
1686
1687  def non_slot_devices(self, var_list):
1688    """Device(s) for non-slot variables.
1689
1690    Create variables on these devices in a
1691    `with colocate_vars_with(non_slot_devices(...)):` block.
1692    Update those using `update_non_slot()`.
1693
1694    Args:
1695      var_list: The list of variables being optimized, needed with the
1696        default `tf.distribute.Strategy`.
1697    Returns:
1698      A sequence of devices for non-slot variables.
1699    """
1700    raise NotImplementedError("must be implemented in descendants")
1701
1702  def _configure(self,
1703                 session_config=None,
1704                 cluster_spec=None,
1705                 task_type=None,
1706                 task_id=None):
1707    """Configures the strategy class."""
1708    del session_config, cluster_spec, task_type, task_id
1709
1710  def _update_config_proto(self, config_proto):
1711    return copy.deepcopy(config_proto)
1712
1713  def _in_multi_worker_mode(self):
1714    """Whether this strategy indicates working in multi-worker settings.
1715
1716    Multi-worker training refers to the setup where the training is
1717    distributed across multiple workers, as opposed to the case where
1718    only a local process performs the training. This function is
1719    used by higher-level apis such as Keras' `model.fit()` to infer
1720    for example whether or not a distribute coordinator should be run,
1721    and thus TensorFlow servers should be started for communication
1722    with other servers in the cluster, or whether or not saving/restoring
1723    checkpoints is relevant for preemption fault tolerance.
1724
1725    Subclasses should override this to provide whether the strategy is
1726    currently in multi-worker setup.
1727
1728    Experimental. Signature and implementation are subject to change.
1729    """
1730    raise NotImplementedError("must be implemented in descendants")
1731
1732
1733@tf_export(v1=["distribute.StrategyExtended"])  # pylint: disable=missing-docstring
1734class StrategyExtendedV1(StrategyExtendedV2):
1735
1736  __doc__ = StrategyExtendedV2.__doc__
1737
1738  def experimental_make_numpy_dataset(self, numpy_input, session=None):
1739    """Makes a dataset for input provided via a numpy array.
1740
1741    This avoids adding `numpy_input` as a large constant in the graph,
1742    and copies the data to the machine or machines that will be processing
1743    the input.
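
    For example (a sketch, TF 1.x graph mode, assuming `strategy` and a
    `tf.compat.v1.Session` named `sess` already exist):

    ```python
    import numpy as np

    features = np.random.rand(100, 10).astype(np.float32)
    dataset = strategy.extended.experimental_make_numpy_dataset(
        features, session=sess)
    ```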
1744
1745    Args:
1746      numpy_input: A nest of NumPy input arrays that will be distributed evenly
1747        across all replicas. Note that lists of NumPy arrays are stacked, as
1748        that is normal `tf.data.Dataset` behavior.
1749      session: (TensorFlow v1.x graph execution only) A session used for
1750        initialization.
1751
1752    Returns:
1753      A `tf.data.Dataset` representing `numpy_input`.
1754    """
1755    _require_cross_replica_or_default_context_extended(self)
1756    return self._experimental_make_numpy_dataset(numpy_input, session=session)
1757
1758  def _experimental_make_numpy_dataset(self, numpy_input, session):
1759    raise NotImplementedError("must be implemented in descendants")
1760
1761  def broadcast_to(self, tensor, destinations):
1762    """Mirror a tensor on one device to all worker devices.
1763
1764    Args:
1765      tensor: A Tensor value to broadcast.
1766      destinations: A mirrored variable or device string specifying the
1767        destination devices to copy `tensor` to.
1768
1769    Returns:
1770      A value mirrored to `destinations` devices.
1771    """
1772    assert destinations is not None  # from old strategy.broadcast()
1773    # TODO(josh11b): More docstring
1774    _require_cross_replica_or_default_context_extended(self)
1775    assert not isinstance(destinations, (list, tuple))
1776    return self._broadcast_to(tensor, destinations)
1777
1778  def _broadcast_to(self, tensor, destinations):
1779    raise NotImplementedError("must be implemented in descendants")
1780
1781  def experimental_run_steps_on_iterator(self,
1782                                         fn,
1783                                         iterator,
1784                                         iterations=1,
1785                                         initial_loop_values=None):
1786    """DEPRECATED: please use `experimental_run_v2` instead.
1787
1788    Run `fn` with input from `iterator` for `iterations` times.
1789
1790    This method can be used to run a step function for training a number of
1791    times using input from a dataset.
1792
1793    Args:
1794      fn: function to run using this distribution strategy. The function must
1795        have the following signature: `def fn(context, inputs)`. `context` is an
1796          instance of `MultiStepContext` that will be passed when `fn` is run.
1797          `context` can be used to specify the outputs to be returned from `fn`
1798          by calling `context.set_last_step_output`. It can also be used to
1799          capture non tensor outputs by `context.set_non_tensor_output`. See
1800          `MultiStepContext` documentation for more information. `inputs` will
1801          have same type/structure as `iterator.get_next()`. Typically, `fn`
1802          will use `call_for_each_replica` method of the strategy to distribute
1803          the computation over multiple replicas.
1804      iterator: Iterator of a dataset that represents the input for `fn`. The
1805        caller is responsible for initializing the iterator as needed.
1806      iterations: (Optional) Number of iterations that `fn` should be run.
1807        Defaults to 1.
1808      initial_loop_values: (Optional) Initial values to be passed into the
1809        loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
1810          initial_loop_values argument when we have a mechanism to infer the
1811          outputs of `fn`.
1812
1813    Returns:
1814      Returns the `MultiStepContext` object which has the following properties,
1815      among other things:
1816        - run_op: An op that runs `fn` `iterations` times.
1817        - last_step_outputs: A dictionary containing tensors set using
1818        `context.set_last_step_output`. Evaluating this returns the value of
1819        the tensors after the last iteration.
1820        - non_tensor_outputs: A dictionary containing anything that was set by
1821          `fn` by calling `context.set_non_tensor_output`.
1822    """
1823    _require_cross_replica_or_default_context_extended(self)
1824    with self._container_strategy().scope():
1825      return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
1826                                                      initial_loop_values)
1827
1828  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
1829                                          initial_loop_values):
1830    raise NotImplementedError("must be implemented in descendants")
1831
1832  def call_for_each_replica(self, fn, args=(), kwargs=None):
1833    """Run `fn` once per replica.
1834
1835    `fn` may call `tf.get_replica_context()` to access methods such as
1836    `replica_id_in_sync_group` and `merge_call()`.
1837
1838    `merge_call()` is used to communicate between the replicas and
1839    re-enter the cross-replica context. All replicas pause their execution
1840    when they encounter a `merge_call()`. After that, `merge_fn` is
1841    executed. Its results are then unwrapped and given back to each
1842    replica call. After that, execution resumes until `fn` is complete or
1843    encounters another `merge_call()`.  Example:
1844
1845    ```python
1846    # Called once in "cross-replica" context.
1847    def merge_fn(distribution, three_plus_replica_id):
1848      # sum the values across replicas
1849      return sum(distribution.experimental_local_results(three_plus_replica_id))
1850
1851    # Called once per replica in `distribution`, in a "replica" context.
1852    def fn(three):
1853      replica_ctx = tf.get_replica_context()
1854      v = three + replica_ctx.replica_id_in_sync_group
1855      # Computes the sum of the `v` values across all replicas.
1856      s = replica_ctx.merge_call(merge_fn, args=(v,))
1857      return s + v
1858
1859    with distribution.scope():
1860      # in "cross-replica" context
1861      ...
1862      merged_results = distribution.experimental_run_v2(fn, args=[3])
1863      # merged_results has the values from every replica execution of `fn`.
1864      # This statement prints a list:
1865      print(distribution.experimental_local_results(merged_results))
1866    ```
1867
1868    Args:
1869      fn: function to run (will be run once per replica).
1870      args: Tuple or list with positional arguments for `fn`.
1871      kwargs: Dict with keyword arguments for `fn`.
1872
1873    Returns:
1874      Merged return value of `fn` across all replicas.
1875    """
1876    _require_cross_replica_or_default_context_extended(self)
1877    if kwargs is None:
1878      kwargs = {}
1879    with self._container_strategy().scope():
1880      return self._call_for_each_replica(fn, args, kwargs)
1881
1882  def _call_for_each_replica(self, fn, args, kwargs):
1883    raise NotImplementedError("must be implemented in descendants")
1884
1885  def read_var(self, v):
1886    """Reads the value of a variable.
1887
1888    Returns the aggregate value of a replica-local variable, or the
1889    (read-only) value of any other variable.
1890
1891    Args:
1892      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
1893
1894    Returns:
1895      A tensor representing the value of `v`, aggregated across replicas if
1896      necessary.
1897    """
1898    raise NotImplementedError("must be implemented in descendants")
1899
1900  @property
1901  def experimental_between_graph(self):
1902    """Whether the strategy uses between-graph replication or not.
1903
1904      This is expected to return a constant value that will not be changed
1905      throughout its life cycle.
1906    """
1907    raise NotImplementedError("must be implemented in descendants")
1908
1909  @property
1910  def experimental_should_init(self):
1911    """Whether initialization is needed."""
1912    raise NotImplementedError("must be implemented in descendants")
1913
1914  @property
1915  def should_checkpoint(self):
1916    """Whether checkpointing is needed."""
1917    raise NotImplementedError("must be implemented in descendants")
1918
1919  @property
1920  def should_save_summary(self):
1921    """Whether saving summaries is needed."""
1922    raise NotImplementedError("must be implemented in descendants")
1923
1924
1925# A note about the difference between the context managers
1926# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
1927# (defined above) used by `tf.distribute.Strategy.scope()`:
1928#
1929# * a ReplicaContext is only present during a `experimental_run_v2()`
1930#   call (except during a `merge_call`) and in such a scope it
1931#   will be returned by calls to `get_replica_context()`.  Implementers of new
1932#   Strategy descendants will frequently also need to
1933#   define a descendant of ReplicaContext, and are responsible for
1934#   entering and exiting this context.
1935#
1936# * Strategy.scope() sets up a variable_creator scope that
1937#   changes variable creation calls (e.g. to make mirrored
1938#   variables). This is intended as an outer scope that users enter once
1939#   around their model creation and graph definition. There is no
1940#   anticipated need to define descendants of _CurrentDistributionContext.
1941#   It sets the current Strategy for purposes of
1942#   `get_strategy()` and `has_strategy()`
1943#   and switches the thread mode to a "cross-replica context".
1944@tf_export("distribute.ReplicaContext")
1945class ReplicaContext(object):
1946  """`tf.distribute.Strategy` API when in a replica context.
1947
1948  You can use `tf.distribute.get_replica_context` to get an instance of
1949  `ReplicaContext`. This should be inside your replicated step function, such
1950  as in a `tf.distribute.Strategy.experimental_run_v2` call.
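
  For example (a sketch, assuming a `strategy` and some tensor `x`):

  ```python
  def step_fn(x):
    ctx = tf.distribute.get_replica_context()
    return x * tf.cast(ctx.num_replicas_in_sync, x.dtype)

  strategy.experimental_run_v2(step_fn, args=(x,))
  ```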
1951  """
1952
1953  def __init__(self, strategy, replica_id_in_sync_group):
1954    self._strategy = strategy
1955    self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
1956        self)
1957    self._replica_id_in_sync_group = replica_id_in_sync_group
1958    self._summary_recording_distribution_strategy = None
1959
1960  def __enter__(self):
1961    _push_per_thread_mode(self._thread_context)
1962
1963    def replica_id_is_zero():
1964      return math_ops.equal(self._replica_id_in_sync_group,
1965                            constant_op.constant(0))
1966
1967    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
1968    self._summary_recording_distribution_strategy = (
1969        summary_state.is_recording_distribution_strategy)
1970    summary_state.is_recording_distribution_strategy = replica_id_is_zero
1971
1972  def __exit__(self, exception_type, exception_value, traceback):
1973    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
1974    summary_state.is_recording_distribution_strategy = (
1975        self._summary_recording_distribution_strategy)
1976    _pop_per_thread_mode()
1977
1978  def merge_call(self, merge_fn, args=(), kwargs=None):
1979    """Merge args across replicas and run `merge_fn` in a cross-replica context.
1980
1981    This allows communication and coordination when there are multiple calls
1982    to the step_fn triggered by a call to
1983    `strategy.experimental_run_v2(step_fn, ...)`.
1984
1985    See `tf.distribute.Strategy.experimental_run_v2` for an
1986    explanation.
1987
1988    If not inside a distributed scope, this is equivalent to:
1989
1990    ```
1991    strategy = tf.distribute.get_strategy()
1992    with cross-replica-context(strategy):
1993      return merge_fn(strategy, *args, **kwargs)
1994    ```
1995
1996    Args:
1997      merge_fn: Function that joins arguments from threads that are given as
1998        PerReplica. It accepts a `tf.distribute.Strategy` object as
1999        the first argument.
2000      args: List or tuple with positional per-thread arguments for `merge_fn`.
2001      kwargs: Dict with keyword per-thread arguments for `merge_fn`.
2002
2003    Returns:
2004      The return value of `merge_fn`, except for `PerReplica` values which are
2005      unpacked.
2006    """
2007    require_replica_context(self)
2008    if kwargs is None:
2009      kwargs = {}
2010    merge_fn = autograph.tf_convert(
2011        merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2012    return self._merge_call(merge_fn, args, kwargs)
2013
2014  def _merge_call(self, merge_fn, args, kwargs):
2015    """Default implementation for single replica."""
2016    _push_per_thread_mode(  # thread-local, so not needed with multiple threads
2017        distribution_strategy_context._CrossReplicaThreadMode(self._strategy))  # pylint: disable=protected-access
2018    try:
2019      return merge_fn(self._strategy, *args, **kwargs)
2020    finally:
2021      _pop_per_thread_mode()
2022
2023  @property
2024  def num_replicas_in_sync(self):
2025    """Returns number of replicas over which gradients are aggregated."""
2026    return self._strategy.num_replicas_in_sync
2027
2028  @property
2029  def replica_id_in_sync_group(self):
2030    """Returns the id of the replica being defined.
2031
2032    This identifies the replica that is part of a sync group. Currently we
2033    assume that all sync groups contain the same number of replicas. The value
2034    of the replica id can range from 0 to `num_replicas_in_sync` - 1.
2035
2036    NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
2037    for low-level operations such as collective_permute.
2038    """
2039    require_replica_context(self)
2040    return self._replica_id_in_sync_group
2041
2042  @property
2043  def strategy(self):
2044    """The current `tf.distribute.Strategy` object."""
2045    return self._strategy
2046
2047  @property
2048  def devices(self):
2049    """The devices this replica is to be executed on, as a tuple of strings."""
2050    require_replica_context(self)
2051    return (device_util.current(),)
2052
2053  def all_reduce(self, reduce_op, value):
2054    """All-reduces the given `value Tensor` nest across replicas.
2055
2056    If `all_reduce` is called in any replica, it must be called in all replicas.
2057    The nested structure and `Tensor` shapes must be identical in all replicas.
2058
2059    IMPORTANT: The ordering of communications must be identical in all replicas.
2060
2061    Example with two replicas:
2062      Replica 0 `value`: {'a': 1, 'b': [40, 1]}
2063      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
2064
2065      If `reduce_op` == `SUM`:
2066        Result (on all replicas): {'a': 4, 'b': [42, 99]}
2067
2068      If `reduce_op` == `MEAN`:
2069        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
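
    Expressed in code (a sketch; this must be called from within a
    replicated step function):

    ```python
    ctx = tf.distribute.get_replica_context()
    # `value` is a nest of tensors computed on this replica.
    result = ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
    ```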
2070
2071    Args:
2072      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
2073      value: The nested structure of `Tensor`s to all-reduce. The structure must
2074        be compatible with `tf.nest`.
2075
2076    Returns:
2077       A `Tensor` nest with the reduced `value`s from each replica.
2078    """
2079    if isinstance(reduce_op, six.string_types):
2080      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2081
2082    def batch_all_reduce(strategy, *value_flat):
2083      return strategy.extended.batch_reduce_to(
2084          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
2085
2086    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
2087      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
2088      @custom_gradient.custom_gradient
2089      def grad_wrapper(*xs):
2090        ys = self.merge_call(batch_all_reduce, args=xs)
2091        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
2092        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
2093      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
2094    else:
2095      # TODO(cjfj): Implement gradients for other reductions.
2096      reduced = nest.pack_sequence_as(
2097          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
2098      return nest.map_structure(array_ops.prevent_gradient, reduced)
2099
2100  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
2101  # all-reduce. It would return a function returning the result of reducing `t`
2102  # across all replicas. The caller would wait to call this function until they
2103  # needed the reduce result, allowing an efficient implementation:
2104  # * With eager execution, the reduction could be performed asynchronously
2105  #   in the background, not blocking until the result was needed.
2106  # * When constructing a graph, it could batch up all reduction requests up
2107  #   to that point that the first result is needed. Most likely this can be
2108  #   implemented in terms of `merge_call()` and `batch_reduce_to()`.
2109
2110
2111def _batch_reduce_destination(x):
2112  """Returns the destinations for batch all-reduce."""
2113  if isinstance(x, ops.Tensor):
2114    # If this is a one device strategy.
2115    return x.device
2116  else:
2117    return x
2118
2119
2120# ------------------------------------------------------------------------------
2121
2122
2123_creating_default_strategy_singleton = False
2124
2125
2126class _DefaultDistributionStrategy(StrategyV1):
2127  """Default `tf.distribute.Strategy` if none is explicitly selected."""
2128
2129  def __init__(self):
2130    if not _creating_default_strategy_singleton:
2131      raise RuntimeError("Should only create a single instance of "
2132                         "_DefaultDistributionStrategy")
2133    super(_DefaultDistributionStrategy, self).__init__(
2134        _DefaultDistributionExtended(self))
2135
2136  def __deepcopy__(self, memo):
2137    del memo
2138    raise RuntimeError("Should only create a single instance of "
2139                       "_DefaultDistributionStrategy")
2140
2141
2142class _DefaultDistributionContext(object):
2143  """Context manager setting the default `tf.distribute.Strategy`."""
2144
2145  def __init__(self, strategy):
2146
2147    def creator(next_creator, **kwargs):
2148      _require_strategy_scope_strategy(strategy)
2149      return next_creator(**kwargs)
2150
2151    self._var_creator_scope = variable_scope.variable_creator_scope(creator)
2152    self._strategy = strategy
2153    self._nested_count = 0
2154
2155  def __enter__(self):
2156    # Allow this scope to be entered if this strategy is already in scope.
2157    if distribution_strategy_context.has_strategy():
2158      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
2159    if self._nested_count == 0:
2160      self._var_creator_scope.__enter__()
2161    self._nested_count += 1
2162    return self._strategy
2163
2164  def __exit__(self, exception_type, exception_value, traceback):
2165    self._nested_count -= 1
2166    if self._nested_count == 0:
2167      try:
2168        self._var_creator_scope.__exit__(
2169            exception_type, exception_value, traceback)
2170      except RuntimeError as e:
2171        six.raise_from(
2172            RuntimeError("Variable creator scope nesting error: move call to "
2173                         "tf.distribute.set_strategy() out of `with` scope."),
2174            e)
2175
2176
2177class _DefaultDistributionExtended(StrategyExtendedV1):
2178  """Implementation of _DefaultDistributionStrategy."""
2179
2180  def __init__(self, container_strategy):
2181    super(_DefaultDistributionExtended, self).__init__(container_strategy)
2182    self._retrace_functions_for_each_device = False
2183
2184  def _scope(self, strategy):
2185    """Context manager setting a variable creator and `self` as current."""
2186    return _DefaultDistributionContext(strategy)
2187
2188  def colocate_vars_with(self, colocate_with_variable):
2189    """Does not require `self.scope`."""
2190    _require_strategy_scope_extended(self)
2191    return ops.colocate_with(colocate_with_variable)
2192
2193  def variable_created_in_scope(self, v):
2194    return v._distribute_strategy is None  # pylint: disable=protected-access
2195
2196  def _experimental_distribute_dataset(self, dataset):
2197    return dataset
2198
2199  def _experimental_distribute_datasets_from_function(self, dataset_fn):
2200    return dataset_fn(InputContext())
2201
2202  def _make_dataset_iterator(self, dataset):
2203    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
2204
2205  def _make_input_fn_iterator(self,
2206                              input_fn,
2207                              replication_mode=InputReplicationMode.PER_WORKER):
2208    dataset = input_fn(InputContext())
2209    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
2210
2211  def _experimental_make_numpy_dataset(self, numpy_input, session):
2212    numpy_flat = nest.flatten(numpy_input)
2213    vars_flat = tuple(
2214        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
2215                                trainable=False, use_resource=True)
2216        for i in numpy_flat
2217    )
2218    for v, i in zip(vars_flat, numpy_flat):
2219      numpy_dataset.init_var_from_numpy(v, i, session)
2220    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
2221    return dataset_ops.Dataset.from_tensor_slices(vars_nested)
2222
2223  def _broadcast_to(self, tensor, destinations):
2224    if destinations is None:
2225      return tensor
2226    else:
2227      raise NotImplementedError("TODO")
2228
2229  def _call_for_each_replica(self, fn, args, kwargs):
2230    with ReplicaContext(
2231        self._container_strategy(),
2232        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
2233      return fn(*args, **kwargs)
2234
2235  def _reduce_to(self, reduce_op, value, destinations):
2236    # TODO(josh11b): Use destinations?
2237    del reduce_op, destinations
2238    return value
2239
2240  def _update(self, var, fn, args, kwargs, group):
2241    # The implementations of _update() and _update_non_slot() are identical
2242    # except _update() passes `var` as the first argument to `fn()`.
2243    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
2244
2245  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
2246    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
2247    # once that value is used for something.
2248    with UpdateContext(colocate_with):
2249      result = fn(*args, **kwargs)
2250      if should_group:
2251        return result
2252      else:
2253        return nest.map_structure(self._local_results, result)
2254
2255  def read_var(self, replica_local_var):
2256    return array_ops.identity(replica_local_var)
2257
2258  def _local_results(self, distributed_value):
2259    return (distributed_value,)
2260
2261  def value_container(self, value):
2262    return value
2263
2264  @property
2265  def _num_replicas_in_sync(self):
2266    return 1
2267
2268  @property
2269  def worker_devices(self):
2270    raise RuntimeError("worker_devices() method unsupported by default "
2271                       "tf.distribute.Strategy.")
2272
2273  @property
2274  def parameter_devices(self):
2275    raise RuntimeError("parameter_devices() method unsupported by default "
2276                       "tf.distribute.Strategy.")
2277
2278  def non_slot_devices(self, var_list):
2279    return min(var_list, key=lambda x: x.name)
2280
2281  def _in_multi_worker_mode(self):
2282    """Whether this strategy indicates working in multi-worker settings."""
2283    # Default strategy doesn't indicate multi-worker training.
2284    return False
2285
2286  # TODO(priyag): This should inherit from `InputIterator`, once dependency
2287  # issues have been resolved.
2288  class DefaultInputIterator(object):
2289    """Default implementation of `InputIterator` for default strategy."""
2290
2291    def __init__(self, dataset):
2292      self._dataset = dataset
2293      if eager_context.executing_eagerly():
2294        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
2295      else:
2296        self._iterator = dataset_ops.make_initializable_iterator(dataset)
2297
2298    def get_next(self):
2299      return self._iterator.get_next()
2300
2301    @deprecated(None, "Use the iterator's `initializer` property instead.")
2302    def initialize(self):
2303      """Initialize underlying iterators.
2304
2305      Returns:
2306        A list of any initializer ops that should be run.
2307      """
2308      if eager_context.executing_eagerly():
2309        self._iterator = self._dataset.make_one_shot_iterator()
2310        return []
2311      else:
2312        return [self._iterator.initializer]
2313
2314    @property
2315    def initializer(self):
2316      """Returns a list of ops that initialize the iterator."""
2317      return self.initialize()
2318
2319  # TODO(priyag): Delete this once all strategies use global batch size.
2320  @property
2321  def _global_batch_size(self):
2322    """Global and per-replica batching are equivalent for this strategy."""
2323    return True
2324
2325
2326# ------------------------------------------------------------------------------
2327# We haven't yet implemented deserialization for DistributedVariables.
2328# So here we catch any attempts to deserialize variables
2329# when using distribution strategies.
2330# pylint: disable=protected-access
2331_original_from_proto = resource_variable_ops._from_proto_fn
2332
2333
2334def _from_proto_fn(v, import_scope=None):
2335  if distribution_strategy_context.has_strategy():
2336    raise NotImplementedError(
2337        "Deserialization of variables is not yet supported when using a "
2338        "tf.distribute.Strategy.")
2339  else:
2340    return _original_from_proto(v, import_scope=import_scope)
2341
2342resource_variable_ops._from_proto_fn = _from_proto_fn
2343# pylint: enable=protected-access
2344
2345
2346#-------------------------------------------------------------------------------
2347# Shorthand for some methods from distribution_strategy_context.
2348_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
2349_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
2350_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
2351_get_default_replica_mode = (
2352    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access
2353
2354
2355# ------------------------------------------------------------------------------
2356# Metrics to track which distribution strategy is being called
2357distribution_strategy_gauge = monitoring.StringGauge(
2358    "/tensorflow/api/distribution_strategy",
2359    "Gauge to track the type of distribution strategy used.", "TFVersion")
2360distribution_strategy_replica_gauge = monitoring.IntGauge(
2361    "/tensorflow/api/distribution_strategy/replica",
2362    "Gauge to track the number of replica each distribution strategy used.",
2363    "CountType")
2364