1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for running a computation across multiple devices."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import threading
23import weakref
24import enum
25
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute import distribution_strategy_context
29from tensorflow.python.distribute import numpy_dataset
30from tensorflow.python.distribute import reduce_util
31from tensorflow.python.eager import context as eager_context
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import custom_gradient
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import resource_variable_ops
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.platform import tf_logging
42from tensorflow.python.util import nest
43from tensorflow.python.util.tf_export import tf_export
44from tensorflow.tools.docs import doc_controls
45
46
47# ------------------------------------------------------------------------------
48# Context tracking whether in a strategy.update() or .update_non_slot() call.
49
50
51_update_device = threading.local()
52
53
54def get_update_device():
55  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
56  try:
57    return _update_device.current
58  except AttributeError:
59    return None
60
61
62class UpdateContext(object):
63  """Context manager when you are in `update()` or `update_non_slot()`."""
64
65  def __init__(self, device):
66    self._device = device
67    self._old_device = None
68
69  def __enter__(self):
70    self._old_device = get_update_device()
71    _update_device.current = self._device
72
73  def __exit__(self, exception_type, exception_value, traceback):
74    del exception_type, exception_value, traceback
75    _update_device.current = self._old_device
76
77
78# ------------------------------------------------------------------------------
79# Public utility functions.
80
81
82@tf_export(v1=["distribute.get_loss_reduction"])
83def get_loss_reduction():
84  """DEPRECATED: Now always returns `tf.distribute.ReduceOp.SUM`.
85
86  We now always make the complete adjustment when computing the loss, so
87  code should always add gradients/losses across replicas, never average.
88  """
89  return reduce_util.ReduceOp.SUM
90
91
92# ------------------------------------------------------------------------------
93# Internal API for validating the current thread mode
94
95
96def _require_cross_replica_or_default_context_extended(extended):
97  """Verify in cross-replica context."""
98  context = _get_per_thread_mode()
99  cross_replica = context.cross_replica_context
100  if cross_replica is not None and cross_replica.extended is extended:
101    return
102  if context is _get_default_replica_mode():
103    return
104  strategy = extended._container_strategy()  # pylint: disable=protected-access
105  # We have an error to report, figure out the right message.
106  if context.strategy is not strategy:
107    _wrong_strategy_scope(strategy, context)
108  assert cross_replica is None
109  raise RuntimeError("Method requires being in cross-replica context, use "
110                     "get_replica_context().merge_call()")
111
112
113def _wrong_strategy_scope(strategy, context):
114  # Figure out the right error message.
115  if not distribution_strategy_context.has_strategy():
116    raise RuntimeError(
117        'Need to be inside "with strategy.scope()" for %s' %
118        (strategy,))
119  else:
120    raise RuntimeError(
121        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
122        (context.strategy, strategy))
123
124
125def require_replica_context(replica_ctx):
126  """Verify in `replica_ctx` replica context."""
127  context = _get_per_thread_mode()
128  if context.replica_context is replica_ctx: return
129  # We have an error to report, figure out the right message.
130  if context.replica_context is None:
131    raise RuntimeError("Need to be inside `call_for_each_replica()`")
132  if context.strategy is replica_ctx.strategy:
133    # Two different ReplicaContexts with the same tf.distribute.Strategy.
134    raise RuntimeError("Mismatching ReplicaContext.")
135  raise RuntimeError(
136      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
137      (context.strategy, replica_ctx.strategy))
138
139
140def _require_strategy_scope_strategy(strategy):
141  """Verify in a `strategy.scope()` in this thread."""
142  context = _get_per_thread_mode()
143  if context.strategy is strategy: return
144  _wrong_strategy_scope(strategy, context)
145
146
147def _require_strategy_scope_extended(extended):
148  """Verify in a `distribution_strategy.scope()` in this thread."""
149  context = _get_per_thread_mode()
150  if context.strategy.extended is extended: return
151  # Report error.
152  strategy = extended._container_strategy()  # pylint: disable=protected-access
153  _wrong_strategy_scope(strategy, context)
154
155
156# ------------------------------------------------------------------------------
157# Internal context managers used to implement the DistributionStrategy
158# base class
159
160
161class _CurrentDistributionContext(object):
162  """Context manager setting the current `tf.distribute.Strategy`.
163
164  Also: overrides the variable creator and optionally the current device.
165  """
166
167  def __init__(self,
168               strategy,
169               var_creator_scope,
170               var_scope=None,
171               default_device=None):
172    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
173        strategy)
174    self._var_creator_scope = var_creator_scope
175    self._var_scope = var_scope
176    if default_device:
177      self._device_scope = ops.device(default_device)
178    else:
179      self._device_scope = None
180
181  def __enter__(self):
182    _push_per_thread_mode(self._context)
183    if self._var_scope:
184      self._var_scope.__enter__()
185    self._var_creator_scope.__enter__()
186    if self._device_scope:
187      self._device_scope.__enter__()
188    return self._context.strategy
189
190  def __exit__(self, exception_type, exception_value, traceback):
191    if self._device_scope:
192      self._device_scope.__exit__(exception_type, exception_value, traceback)
193    self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
194    if self._var_scope:
195      self._var_scope.__exit__(exception_type, exception_value, traceback)
196    _pop_per_thread_mode()
197
198
199class _SameScopeAgainContext(object):
200  """Trivial context manager when you are already in `scope()`."""
201
202  def __init__(self, strategy):
203    self._strategy = strategy
204
205  def __enter__(self):
206    return self._strategy
207
208  def __exit__(self, exception_type, exception_value, traceback):
209    del exception_type, exception_value, traceback
210
211
212# TODO(yuefengz): add more replication modes.
213@tf_export("distribute.InputReplicationMode")
214class InputReplicationMode(enum.Enum):
215  """Replication mode for input function.
216
  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as there are workers.
    Replicas will dequeue from the local `tf.data.Dataset` on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
222  """
223  PER_WORKER = "PER_WORKER"
224
225
226@tf_export("distribute.InputContext")
227class InputContext(object):
228  """A class wrapping information needed by an input function.
229
  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the batch size
  for each input pipeline from the desired global batch size. Input pipeline
  information can be used to return a different subset of the input in each
  pipeline (e.g. to shard the input data or use a different input source).
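
  For example, a minimal sketch of an input function that uses this context
  (the global batch size and dataset contents are illustrative values):

  ```python
  def input_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
    # Each input pipeline reads a different shard of the data.
    return d.shard(input_context.num_input_pipelines,
                   input_context.input_pipeline_id)
  ```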
236  """
237
238  def __init__(self,
239               num_input_pipelines=1,
240               input_pipeline_id=0,
241               num_replicas_in_sync=1):
242    """Initializes an InputContext object.
243
244    Args:
245      num_input_pipelines: the number of input pipelines in a cluster.
246      input_pipeline_id: the current input pipeline id, should be an int in
247        [0,`num_input_pipelines`).
248      num_replicas_in_sync: the number of replicas that are in sync.
249    """
250    self._num_input_pipelines = num_input_pipelines
251    self._input_pipeline_id = input_pipeline_id
252    self._num_replicas_in_sync = num_replicas_in_sync
253
254  @property
255  def num_replicas_in_sync(self):
256    """Returns the number of compute replicas in sync."""
257    return self._num_replicas_in_sync
258
259  @property
260  def input_pipeline_id(self):
261    """Returns the input pipeline ID."""
262    return self._input_pipeline_id
263
264  @property
265  def num_input_pipelines(self):
266    """Returns the number of input pipelines."""
267    return self._num_input_pipelines
268
269  def get_per_replica_batch_size(self, global_batch_size):
270    """Returns the per-replica batch size.
271
272    Args:
273      global_batch_size: the global batch size which should be divisible by
274        `num_replicas_in_sync`.
275
276    Returns:
277      the per-replica batch size.
278
279    Raises:
280      ValueError: if `global_batch_size` not divisible by
281        `num_replicas_in_sync`.
282    """
283    if global_batch_size % self._num_replicas_in_sync != 0:
284      raise ValueError("The `global_batch_size` %r is not divisible by "
285                       "`num_replicas_in_sync` %r " %
286                       (global_batch_size, self._num_replicas_in_sync))
287    return global_batch_size // self._num_replicas_in_sync
288
289
290# ------------------------------------------------------------------------------
291# Base classes for all distribution strategies.
292
293
294@tf_export("distribute.Strategy")
295class DistributionStrategy(object):
296  """A list of devices with a state & compute distribution policy.
297
298  See [tensorflow/contrib/distribute/README.md](
299  https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
300  for overview and examples.
301  """
302
303  # TODO(josh11b): Raise an exception if variable partitioning requested before
304  #   we add support.
305  # TODO(josh11b): Also `parameter_device_index` property?
306  # TODO(josh11b): `map()`
307  # TODO(josh11b): ClusterSpec/ClusterResolver
308  # TODO(josh11b): Partitioned computations, state; sharding
309  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
310  # TODO(josh11b): List of replicas with their worker and parameter devices
311  #   (where the parameter devices may overlap in the ps case).
312
313  def __init__(self, extended):
314    self._extended = extended
315
316  @property
317  def extended(self):
318    """`tf.distribute.StrategyExtended` with additional methods."""
319    return self._extended
320
321  def scope(self):
322    """Returns a context manager selecting this Strategy as current.
323
324    Inside a `with strategy.scope():` code block, this thread
325    will use a variable creator set by `strategy`, and will
326    enter its "cross-replica context".
327
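    For example, a minimal sketch (using `tf.distribute.MirroredStrategy`
    purely for illustration):

    ```python
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Variables created here use the creator set by `strategy` and are
      # distributed according to its policy.
      v = tf.Variable(1.)
    ```
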
328    Returns:
329      A context manager.
330    """
331    return self._extended._scope(self)  # pylint: disable=protected-access
332
333  @doc_controls.do_not_generate_docs  # DEPRECATED, moving to `extended`
334  def colocate_vars_with(self, colocate_with_variable):
335    """DEPRECATED: use extended.colocate_vars_with() instead."""
336    return self._extended.colocate_vars_with(colocate_with_variable)
337
338  def make_dataset_iterator(self, dataset):
339    """Makes an iterator for input provided via `dataset`.
340
    Data from the given dataset will be distributed evenly across all the
    compute replicas. We assume that the input dataset is batched by the
    global batch size. Under this assumption, we make a best effort to
    divide each batch across all the replicas (one or more workers).
    If this effort fails, an error is raised, and the user should instead
    use `make_input_fn_iterator`, which provides more control and does not
    try to divide a batch across replicas.

    The user could also use `make_input_fn_iterator` to customize which
    input is fed to which replica/worker.
351
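    For example, a rough sketch of the expected usage; `strategy`, `session`,
    and `replica_fn` are assumed to be defined elsewhere, and the dataset
    contents and batch size are illustrative:

    ```python
    global_batch_size = 8
    dataset = tf.data.Dataset.from_tensors(([1.],)).repeat().batch(
        global_batch_size)
    with strategy.scope():
      iterator = strategy.make_dataset_iterator(dataset)
      session.run(iterator.initialize())
      replica_results = strategy.extended.call_for_each_replica(
          replica_fn, args=(iterator.get_next(),))
    ```
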
352    Args:
353      dataset: `tf.data.Dataset` that will be distributed evenly across all
354        replicas.
355
356    Returns:
      A `tf.distribute.InputIterator` which returns inputs for each step of the
      computation. The user should call `initialize` on the returned iterator.
359    """
360    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
361
362  def make_input_fn_iterator(self,
363                             input_fn,
364                             replication_mode=InputReplicationMode.PER_WORKER):
365    """Returns an iterator split across replicas created from an input function.
366
367    The `input_fn` should take an `tf.distribute.InputContext` object where
368    information about batching and input sharding can be accessed:
369
370    ```
371    def input_fn(input_context):
372      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
373      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
374      return d.shard(input_context.num_input_pipelines,
375                     input_context.input_pipeline_id)
376    with strategy.scope():
377      iterator = strategy.make_input_fn_iterator(input_fn)
378      replica_results = strategy.experimental_run(replica_fn, iterator)
379    ```
380
381    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
382    batch size, which may be computed using
383    `input_context.get_per_replica_batch_size`.
384
385    Args:
386      input_fn: A function taking a `tf.distribute.InputContext` object and
387        returning a `tf.data.Dataset`.
388      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
389        Only `PER_WORKER` is supported currently, which means there will be
390        a single call to `input_fn` per worker. Replicas will dequeue from the
391        local `tf.data.Dataset` on their worker.
392
393    Returns:
394      An iterator object that should first be `.initialize()`-ed. It may then
395      either be passed to `strategy.experimental_run()` or you can
396      `iterator.get_next()` to get the next value to pass to
397      `strategy.extended.call_for_each_replica()`.
398    """
399    if replication_mode != InputReplicationMode.PER_WORKER:
400      raise ValueError(
401          "Input replication mode not supported: %r" % replication_mode)
402    with self.scope():
403      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
404          input_fn, replication_mode=replication_mode)
405
406  def experimental_make_numpy_iterator(
407      self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None):
408    """Makes an iterator for input provided via a nest of numpy arrays.
409
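    For example, a minimal sketch (the NumPy data, batch size, and shuffle
    buffer are illustrative; `session` is only needed for TF v1.x graph
    execution):

    ```python
    import numpy as np

    features = np.random.rand(64, 10).astype(np.float32)
    labels = np.random.randint(2, size=(64, 1))
    with strategy.scope():
      iterator = strategy.experimental_make_numpy_iterator(
          (features, labels), batch_size=8, num_epochs=None, shuffle=1024,
          session=session)
    ```
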
410    Args:
411      numpy_input: A nest of NumPy input arrays that will be distributed evenly
412        across all replicas. Note that lists of Numpy arrays are stacked,
413        as that is normal `tf.data.Dataset` behavior.
414      batch_size: The number of entries from the array we should consume in one
415        step of the computation, across all replicas. This is the global batch
416        size. It should be divisible by `num_replicas_in_sync`.
417      num_epochs: The number of times to iterate through the examples. A value
418        of `None` means repeat forever.
419      shuffle: Size of buffer to use for shuffling the input examples.
420        Use `None` to disable shuffling.
421      session: (TensorFlow v1.x graph execution only) A session used for
422        initialization.
423
424    Returns:
      A `tf.distribute.InputIterator` which returns inputs for each step of the
      computation. The user should call `initialize` on the returned iterator.
427    """
428    ds = self.extended.experimental_make_numpy_dataset(
429        numpy_input, session=session)
430    if shuffle:
431      ds = ds.shuffle(shuffle)
432    if num_epochs != 1:
433      ds = ds.repeat(num_epochs)
434    # We need to use the drop_remainder argument to get a known static
435    # input shape which is required for TPUs.
436    drop_remainder = self.extended.experimental_require_static_shapes
437    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
438    return self.make_dataset_iterator(ds)
439
440  def experimental_run(self, fn, input_iterator=None):
441    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
442
443    When eager execution is enabled, executes ops specified by `fn` on each
444    replica. Otherwise, builds a graph to execute the ops on each replica.
445
446    Each replica will take a single, different input from the inputs provided by
447    one `get_next` call on the input iterator.
448
449    `fn` may call `tf.distribute.get_replica_context()` to access members such
450    as `replica_id_in_sync_group`.
451
452    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
453    used, and whether eager execution is enabled, `fn` may be called one or more
454    times (once for each replica).
455
456    Args:
457      fn: The function to run. The inputs to the function must match the outputs
458        of `input_iterator.get_next()`. The output must be a `tf.nest` of
459        `Tensor`s.
460      input_iterator: (Optional) input iterator from which the inputs are taken.
461
462    Returns:
463      Merged return value of `fn` across replicas. The structure of the return
464      value is the same as the return value from `fn`. Each element in the
465      structure can either be `PerReplica` (if the values are unsynchronized),
466      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
467      single replica).
468    """
469    with self.scope():
470      args = (input_iterator.get_next(),) if input_iterator is not None else ()
471    return self.experimental_run_v2(fn, args=args)
472
473  def experimental_run_v2(self, fn, args=(), kwargs=None):
474    """Runs ops in `fn` on each replica, with the given arguments.
475
476    When eager execution is enabled, executes ops specified by `fn` on each
477    replica. Otherwise, builds a graph to execute the ops on each replica.
478
479    `fn` may call `tf.distribute.get_replica_context()` to access members such
480    as `replica_id_in_sync_group`.
481
482    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
483    used, and whether eager execution is enabled, `fn` may be called one or more
484    times (once for each replica).
485
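    For example, a minimal sketch (the replica function and its input are
    illustrative):

    ```python
    def replica_fn(x):
      # Runs once per replica; may call `tf.distribute.get_replica_context()`.
      return x * 2.

    result = strategy.experimental_run_v2(replica_fn, args=(tf.constant(3.),))
    ```
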
486    Args:
487      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
488      args: (Optional) Positional arguments to `fn`.
489      kwargs: (Optional) Keyword arguments to `fn`.
490
491    Returns:
492      Merged return value of `fn` across replicas. The structure of the return
493      value is the same as the return value from `fn`. Each element in the
494      structure can either be `PerReplica` (if the values are unsynchronized),
495      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
496      single replica).
497    """
498    with self.scope():
499      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
500
501  def reduce(self, reduce_op, value):
502    """Reduce `value` across replicas.
503
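    For example, a minimal sketch that sums a per-replica result; `replica_fn`
    is an illustrative per-replica function:

    ```python
    per_replica = strategy.experimental_run_v2(
        replica_fn, args=(tf.constant(3.),))
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica)
    ```
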
504    Args:
505      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
506        be combined.
507      value: A "per replica" value to be combined into a single tensor.
508
509    Returns:
510      A `Tensor`.
511    """
512    _require_cross_replica_or_default_context_extended(self._extended)
513    return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
514
515  @doc_controls.do_not_generate_docs  # DEPRECATED
516  def unwrap(self, value):
517    """Returns the list of all local per-replica values contained in `value`.
518
519    DEPRECATED: Please use `experimental_local_results` instead.
520
521    Note: This only returns values on the workers initiated by this client.
522    When using a `Strategy` like
523    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
524    will be its own client, and this function will only return values
525    computed on that worker.
526
527    Args:
528      value: A value returned by `experimental_run()`,
529        `extended.call_for_each_replica()`, or a variable created in `scope`.
530
531    Returns:
532      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
534    """
535    return self._extended._local_results(value)  # pylint: disable=protected-access
536
537  def experimental_local_results(self, value):
538    """Returns the list of all local per-replica values contained in `value`.
539
540    Note: This only returns values on the workers initiated by this client.
541    When using a `Strategy` like
542    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
543    will be its own client, and this function will only return values
544    computed on that worker.
545
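    For example, a minimal sketch (`replica_fn` is an illustrative per-replica
    function):

    ```python
    per_replica = strategy.experimental_run_v2(
        replica_fn, args=(tf.constant(3.),))
    # A tuple with one element per local replica (e.g. one per local GPU).
    local_values = strategy.experimental_local_results(per_replica)
    ```
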
546    Args:
547      value: A value returned by `experimental_run()`, `experimental_run_v2()`,
548        `extended.call_for_each_replica()`, or a variable created in `scope`.
549
550    Returns:
551      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
553    """
554    return self._extended._local_results(value)  # pylint: disable=protected-access
555
556  @doc_controls.do_not_generate_docs  # DEPRECATED: TF v1.x only
557  def group(self, value, name=None):
558    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
559    return self._extended._group(value, name)  # pylint: disable=protected-access
560
561  @property
562  def num_replicas_in_sync(self):
563    """Returns number of replicas over which gradients are aggregated."""
564    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access
565
566  @doc_controls.do_not_generate_docs  # DEPRECATED, being replaced by a new API.
567  def configure(self,
568                session_config=None,
569                cluster_spec=None,
570                task_type=None,
571                task_id=None):
572    # pylint: disable=g-doc-return-or-yield,g-doc-args
573    """DEPRECATED: use `update_config_proto` instead.
574
575    Configures the strategy class.
576
    DEPRECATED: This method's functionality has been split into the strategy
    constructor and `update_config_proto`. In the future, we will allow passing
    the cluster and config_proto to the constructor to configure the strategy,
    and `update_config_proto` can be used to update the config_proto based on
    the specific strategy.
582    """
583    return self._extended._configure(  # pylint: disable=protected-access
584        session_config, cluster_spec, task_type, task_id)
585
586  def update_config_proto(self, config_proto):
587    """Returns a copy of `config_proto` modified for use with this strategy.
588
    The updated config contains settings needed to run this strategy, e.g.
    configuration to run collective ops, or device filters to improve
    distributed training performance.
592
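    For example, a minimal TF v1.x-style sketch (assuming `strategy` is a
    concrete strategy instance):

    ```python
    config = tf.ConfigProto()
    config = strategy.update_config_proto(config)
    session = tf.Session(config=config)
    ```
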
593    Args:
594      config_proto: a `tf.ConfigProto` object.
595
596    Returns:
597      The updated copy of the `config_proto`.
598    """
599    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
600
601  def __deepcopy__(self, memo):
602    # First do a regular deepcopy of `self`.
603    cls = self.__class__
604    result = cls.__new__(cls)
605    memo[id(self)] = result
606    for k, v in self.__dict__.items():
607      setattr(result, k, copy.deepcopy(v, memo))
608    # One little fix-up: we want `result._extended` to reference `result`
609    # instead of `self`.
610    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
611    return result
612
613  def __copy__(self):
614    raise RuntimeError("Must only deepcopy DistributionStrategy.")
615
616
617@tf_export("distribute.StrategyExtended")
618class DistributionStrategyExtended(object):
619  """Additional APIs for algorithms that need to be distribution-aware.
620
621  The intent is that you can write an algorithm in a stylized way and
622  it will be usable with a variety of different
623  `tf.distribute.Strategy`
624  implementations. Each descendant will implement a different strategy
625  for distributing the algorithm across multiple devices/machines.
626  Furthermore, these changes can be hidden inside the specific layers
627  and other library classes that need special treatment to run in a
628  distributed setting, so that most users' model definition code can
629  run unchanged. The `tf.distribute.Strategy` API works the same way
630  with eager and graph execution.
631
632  First let's introduce a few high-level concepts:
633
634  * _Data parallelism_ is where we run multiple copies of the model
635    on different slices of the input data. This is in contrast to
636    _model parallelism_ where we divide up a single copy of a model
637    across multiple devices.
638    Note: we only support data parallelism for now, but
639    hope to add support for model parallelism in the future.
640  * A _replica_ is one copy of the model, running on one slice of the
641    input data.
642  * _Synchronous_, or more commonly _sync_, training is where the
643    updates from each replica are aggregated together before updating
644    the model variables. This is in contrast to _asynchronous_, or
645    _async_ training, where each replica updates the model variables
646    independently.
647  * Furthermore you might run your computation on multiple devices
648    on one machine (or "host"), or on multiple machines/hosts.
649    If you are running on multiple machines, you might have a
650    single master host that drives computation across all of them,
651    or you might have multiple clients driving the computation
652    asynchronously.
653
654  To distribute an algorithm, we might use some of these ingredients:
655
656  * Parameter servers: These are hosts that hold a single copy of
657    parameters/variables. All replicas that want to operate on a variable
658    retrieve it at the beginning of a step and send an update to be
659    applied at the end of the step. Can support either sync or async
660    training.
661  * Mirrored variables: These are variables that are copied to multiple
662    devices, where we keep the copies in sync by applying the same
663    updates to every copy. Normally would only be used with sync training.
664  * Reductions and Allreduce: A _reduction_ is some method of
665    aggregating multiple values into one value, like "sum" or
666    "mean". If doing sync training, we will perform a reduction on the
667    gradients to a parameter from all replicas before applying the
668    update. Allreduce is an algorithm for performing a reduction on
669    values from multiple devices and making the result available on
670    all of those devices.
671  * In the future we will have support for TensorFlow's partitioned
672    variables, where a single variable is split across multiple
673    devices.
674
  We then have a few approaches we want to support:
676
677  * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
678    This code should work as before, even if some of the layers, etc.
679    used by that code are written to be distribution-aware. This is done
680    by having a default `tf.distribute.Strategy` that gives ordinary behavior,
681    and by default being in a single replica context.
682  * Ordinary model code that you want to run using a specific
683    `tf.distribute.Strategy`. This can be as simple as:
684
685    ```
686    with my_strategy.scope():
687      iterator = my_strategy.make_dataset_iterator(dataset)
688      session.run(iterator.initialize())
689      replica_train_ops = my_strategy.extended.call_for_each_replica(
690          replica_fn, args=(iterator.get_next(),))
691      train_op = my_strategy.group(replica_train_ops)
692    ```
693
694    This takes an ordinary `dataset` and `replica_fn` and runs it
695    distributed using a particular `tf.distribute.Strategy` in
696    `my_strategy`. Any variables created in `replica_fn` are created
697    using `my_strategy`'s policy, and library functions called by
698    `replica_fn` can use the `get_replica_context()` API to get enhanced
699    behavior in this case.
700
701  * If you want to write a distributed algorithm, you may use any of
702    the `tf.distribute.Strategy` APIs inside a
703    `with my_strategy.scope():` block of code.
704
705  Lower-level concepts:
706
707  * Wrapped values: In order to represent values parallel across devices
708    (either replicas or the devices associated with a particular value), we
709    wrap them in a "PerReplica" or "Mirrored" object that contains a map
710    from device to values. "PerReplica" is used when the value may be
    different across replicas, and "Mirrored" when the values are the same.
712  * Unwrapping and merging: Consider calling a function `fn` on multiple
713    replicas, like `extended.call_for_each_replica(fn, args=[w])` with an
714    argument `w` that is a wrapped value. This means `w` will have a map taking
715    replica device `d0` to `w0`, replica device `d1` to `w1`,
716    etc. `extended.call_for_each_replica()` unwraps `w` before calling `fn`, so
717    it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc.  It then merges the return
718    values from `fn()`, which can possibly result in wrapped values. For
719    example, let's say `fn()` returns a tuple with three components: `(x, a,
720    v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component
721    is the same object `x` from every replica, then the first component of the
722    merged result will also be `x`. If the second component is different (`a`,
723    `b`, ...)  from each replica, then the merged value will have a wrapped map
724    from replica device to the different values. If the third component is the
725    members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.),
726    then the merged result will be that mirrored variable (`v`).
727  * Replica context vs. Cross-replica context: _replica context_ is when we
728    are in some function that is being called once for each replica.
729    Otherwise we are in cross-replica context, which is useful for
730    calling `tf.distribute.Strategy` methods which operate across the
731    replicas (like `reduce_to()`). By default you start in a replica context
732    (the default "single replica context") and then some methods can
733    switch you back and forth, as described below.
734  * Worker devices vs. parameter devices: Most replica computations will
735    happen on worker devices. Since we don't yet support model
736    parallelism, there will be one worker device per replica. When using
737    parameter servers (see above), the set of devices holding
738    variables may be different, otherwise the parameter devices might
739    match the worker devices.
740  * Non-slot devices are some subset of the parameter devices where we
741    put all the non-slot variables. We need to ensure that all
742    non-slot variables are allocated on the same device, or mirrored
743    across the same set of devices. If you have some variable you want
744    to colocate all the non-slot variables with, you can use
745    `colocate_vars_with()` to get the remaining non-slot variables on
746    the same device.  Otherwise you can use `non_slot_devices()` to
747    pick a consistent set of devices to pass to both
748    `colocate_vars_with()` and `update_non_slot()`.
749
750  When using a `tf.distribute.Strategy`, we have a new type dimension
751  called _locality_ that says what values are compatible with which
752  APIs:
753
754  * T: different value for each replica (e.g. a PerReplica-wrapped value).
755  * M: value is "mirrored" across replicas, i.e. there are copies with the
756    same value on each replica (e.g. a Mirrored-wrapped value).
757  * V(`v`): value is "mirrored" across all the devices which have a
758    copy of variable `v` (also a Mirrored-wrapped value, but over
759    parameter devices instead of worker devices).
760  * N: value is "mirrored" across all the "non-slot" devices
761
762  Rules for methods with respect to locality and single-replica vs.
763  cross-replica context:
764
765  * `with d.scope()`: default single-replica context -> cross-replica context
766    for `d`
767  * `with d.extended.colocate_vars_with(v)`: in replica/cross-replica context,
768    variables will be created with locality V(`v`). That is, if we write
769    `with d.extended.colocate_vars_with(v1): v2 = tf.get_variable(...)`,
770    then `v2` will have locality V(`v1`), i.e. locality V(`v2`) will equal
771    V(`v1`).
772  * `with d.extended.colocate_vars_with(d.extended.non_slot_devices(...))`: in
773    replica/cross-replica context, variables will be created with locality N
774  * `v = tf.get_variable(...)`: in replica/cross-replica context, creates
775    a variable (which by definition will have locality V(`v`), though
776    will match another locality if inside a `colocate_vars_with`
777    scope).
778  * `d.make_dataset_iterator(dataset)`: in cross-replica
779    context, produces an iterator with locality T
780  * `d.extended.broadcast_to(t, v)`: in cross-replica context, produces a value
781    with locality V(`v`)
782  * `d.extended.call_for_each_replica(fn, ...)`: in cross-replica context, runs
783    `fn()` in a replica context (and so may call `get_replica_context()` and
784    use its API, including `merge_call()` to get back to cross-replica
785    context), once for each replica. May use values with locality T or
786    M, and any variable.
787  * `d.extended.reduce_to(m, t, t)`: in cross-replica context, accepts t with
788    locality T and produces a value with locality M.
789  * `d.extended.reduce_to(m, t, v)`: in cross-replica context, accepts t with
790    locality T and produces a value with locality V(`v`).
  * `d.extended.batch_reduce_to(m, [(t, v)])`: see `d.extended.reduce_to()`
792  * `d.extended.update(v, fn, ...)`: in cross-replica context, runs `fn()` once
793    for each device `v` is copied to, all inputs should have locality
794    V(`v`), output will have locality V(`v`) as well.
795  * `d.extended.update_non_slot(d.extended.non_slot_devices(), fn)`: in
796    cross-replica context, like `d.extended.update()` except with locality N.
797  * `d.extended.read_var(v)`: Gets the (read-only) value of the variable `v` (on
798    the device determined by the current device scope), aggregating
799    across replicas for replica-local variables. Frequently, this will be
800    done automatically when using `v` in an expression or fetching it in
    a cross-replica context, but this function can be used to force the
    conversion to happen at a particular point in time (for example, to
803    add the result of the conversion to a graph collection).
804
805  The standard pattern for updating variables is to:
806
807  1. Create an input iterator with `d.make_dataset_iterator()`.
  2. Define each replica's computation with
     `d.extended.call_for_each_replica()` up to the point of getting a list of
     (gradient, variable) pairs.
810  3. Call `d.extended.reduce_to(VariableAggregation.SUM, t, v)` or
811     `d.extended.batch_reduce_to()` to sum the gradients (with locality T)
812     into values with locality V(`v`).
813  4. Call `d.extended.update(v)` for each variable to update its value.
814
815  Steps 3 and 4 are done automatically by class `Optimizer` if you call
816  its `apply_gradients` method in a replica context. Otherwise you can
817  manually call its `_distributed_apply` method in a cross-replica context.
818
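  For illustration, a rough sketch of these steps done by hand; `dataset`,
  `compute_gradient`, and `var` are hypothetical placeholders (learning-rate
  scaling is omitted), not part of this API:

  ```python
  with d.scope():
    iterator = d.make_dataset_iterator(dataset)

    def replica_fn(inputs):
      # Returns a single (gradient, variable) pair for brevity.
      return compute_gradient(inputs, var), var

    grad, v = d.extended.call_for_each_replica(
        replica_fn, args=(iterator.get_next(),))
    # Step 3: sum the per-replica gradients into a value with locality V(`v`).
    reduced_grad = d.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, grad, destinations=v)
    # Step 4: apply the update on every device holding `v`.
    update_op = d.extended.update(
        v, lambda var_copy, g: var_copy.assign_sub(g), args=(reduced_grad,))
  ```
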
819  Another thing you might want to do in the middle of your replica function is
820  an all-reduce of some intermediate value, using `d.extended.reduce_to()` or
821  `d.extended.batch_reduce_to()`. You simply provide the same tensor as the
822  input and destination.
823
824  Layers should expect to be called in a replica context, and can use
825  the `tf.distribute.get_replica_context` function to get a
826  `tf.distribute.ReplicaContext` object. The
827  `ReplicaContext` object has a `merge_call()` method for entering
828  cross-replica context where you can use `reduce_to()` (or
829  `batch_reduce_to()`) and then optionally `update()` to update state.
830
831  You may use this API whether or not a `tf.distribute.Strategy` is
832  being used, since there is a default implementation of
833  `ReplicaContext` and `tf.distribute.Strategy`.
834
835  NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
836  in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
837  the `tf.distribute.Strategy` subclass is for instantiating your subclass of
838  `tf.distribute.StrategyExtended` in the `__init__` method.
839  """
840
841  def __init__(self, container_strategy):
842    self._container_strategy_weakref = weakref.ref(container_strategy)
843    self._default_device = None
844    # This property is used to determine if we should set drop_remainder=True
845    # when creating Datasets from numpy array inputs.
846    self._require_static_shapes = False
847
848  def _container_strategy(self):
849    """Get the containing `DistributionStrategy`.
850
851    This should not generally be needed except when creating a new
852    `ReplicaContext` and to validate that the caller is in the correct
853    `scope()`.
854
855    Returns:
856      The `DistributionStrategy` such that `strategy.extended` is `self`.
857    """
858    container_strategy = self._container_strategy_weakref()
859    assert container_strategy is not None
860    return container_strategy
861
862  def _scope(self, strategy):
863    """Implementation of DistributionStrategy.scope()."""
864    if distribution_strategy_context.has_strategy():
865      _require_cross_replica_or_default_context_extended(self)
866      return _SameScopeAgainContext(strategy)
867
868    def creator_with_resource_vars(*args, **kwargs):
869      _require_strategy_scope_extended(self)
870      kwargs["use_resource"] = True
871      kwargs["distribute_strategy"] = strategy
872      return self._create_variable(*args, **kwargs)
873
874    def distributed_getter(getter, *args, **kwargs):
875      if not self._allow_variable_partition():
876        if kwargs.pop("partitioner", None) is not None:
877          tf_logging.log_first_n(
878              tf_logging.WARN, "Partitioned variables are disabled when using "
879              "current tf.distribute.Strategy.", 1)
880      return getter(*args, **kwargs)
881
882    return _CurrentDistributionContext(
883        strategy,
884        variable_scope.variable_creator_scope(creator_with_resource_vars),
885        variable_scope.variable_scope(
886            variable_scope.get_variable_scope(),
887            custom_getter=distributed_getter), self._default_device)
888
889  def _allow_variable_partition(self):
890    return False
891
892  def _create_variable(self, next_creator, *args, **kwargs):
893    # Note: should support "colocate_with" argument.
894    raise NotImplementedError("must be implemented in descendants")
895
896  def variable_created_in_scope(self, v):
897    """Tests whether `v` was created while this strategy scope was active.
898
899    Variables created inside the strategy scope are "owned" by it:
900
901    >>> with strategy.scope():
902    ...   v = tf.Variable(1.)
903    >>> strategy.variable_created_in_scope(v)
904    True
905
906    Variables created outside the strategy are not owned by it:
907
908    >>> v = tf.Variable(1.)
909    >>> strategy.variable_created_in_scope(v)
910    False
911
912    Args:
913      v: A `tf.Variable` instance.
914
915    Returns:
916      True if `v` was created inside the scope, False if not.
917    """
918    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access
919
920  def read_var(self, v):
921    """Reads the value of a variable.
922
923    Returns the aggregate value of a replica-local variable, or the
924    (read-only) value of any other variable.
925
926    Args:
927      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
928
929    Returns:
930      A tensor representing the value of `v`, aggregated across replicas if
931      necessary.
932    """
933    raise NotImplementedError("must be implemented in descendants")
934
935  def colocate_vars_with(self, colocate_with_variable):
936    """Scope that controls which devices variables will be created on.
937
    No operations should be added to the graph inside this scope; it
    should only be used when creating variables (some implementations
    work by changing variable creation, others work by using a
    `tf.colocate_with()` scope).
942
943    This may only be used inside `self.scope()`.
944
945    Example usage:
946
947    ```
948    with strategy.scope():
949      var1 = tf.get_variable(...)
950      with strategy.extended.colocate_vars_with(var1):
951        # var2 and var3 will be created on the same device(s) as var1
952        var2 = tf.get_variable(...)
953        var3 = tf.get_variable(...)
954
955      def fn(v1, v2, v3):
956        # operates on v1 from var1, v2 from var2, and v3 from var3
957
958      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
959      # too.
960      strategy.extended.update(var1, fn, args=(var2, var3))
961    ```
962
963    Args:
964      colocate_with_variable: A variable created in this strategy's `scope()`.
965        Variables created while in the returned context manager will be on the
966        same set of devices as `colocate_with_variable`.
967
968    Returns:
969      A context manager.
970    """
971    def create_colocated_variable(next_creator, *args, **kwargs):
972      _require_strategy_scope_extended(self)
973      kwargs["use_resource"] = True
974      kwargs["colocate_with"] = colocate_with_variable
975      return next_creator(*args, **kwargs)
976
977    _require_strategy_scope_extended(self)
978    self._validate_colocate_with_variable(colocate_with_variable)
979    return variable_scope.variable_creator_scope(create_colocated_variable)
980
981  def _validate_colocate_with_variable(self, colocate_with_variable):
982    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
983    pass
984
985  def _make_dataset_iterator(self, dataset):
986    raise NotImplementedError("must be implemented in descendants")
987
988  def _make_input_fn_iterator(self, input_fn, replication_mode):
989    raise NotImplementedError("must be implemented in descendants")
990
991  def experimental_make_numpy_dataset(self, numpy_input, session=None):
992    """Makes a dataset for input provided via a numpy array.
993
994    This avoids adding `numpy_input` as a large constant in the graph,
995    and copies the data to the machine or machines that will be processing
996    the input.
997
998    Args:
999      numpy_input: A nest of NumPy input arrays that will be distributed evenly
1000        across all replicas. Note that lists of Numpy arrays are stacked,
1001        as that is normal `tf.data.Dataset` behavior.
1002      session: (TensorFlow v1.x graph execution only) A session used for
1003        initialization.
1004
1005    Returns:
1006      A `tf.data.Dataset` representing `numpy_input`.
1007    """
1008    _require_cross_replica_or_default_context_extended(self)
1009    return self._experimental_make_numpy_dataset(numpy_input, session=session)
1010
1011  def _experimental_make_numpy_dataset(self, numpy_input, session):
1012    raise NotImplementedError("must be implemented in descendants")
1013
1014  def broadcast_to(self, tensor, destinations):
1015    """Mirror a tensor on one device to all worker devices.
1016
1017    Args:
1018      tensor: A Tensor value to broadcast.
1019      destinations: A mirrored variable or device string specifying the
1020        destination devices to copy `tensor` to.
1021
1022    Returns:
1023      A value mirrored to `destinations` devices.
1024    """
1025    assert destinations is not None  # from old strategy.broadcast()
1026    # TODO(josh11b): More docstring
1027    _require_cross_replica_or_default_context_extended(self)
1028    assert not isinstance(destinations, (list, tuple))
1029    return self._broadcast_to(tensor, destinations)
1030
1031  def _broadcast_to(self, tensor, destinations):
1032    raise NotImplementedError("must be implemented in descendants")
1033
1034  def experimental_run_steps_on_iterator(self, fn, iterator, iterations=1,
1035                                         initial_loop_values=None):
1036    """Run `fn` with input from `iterator` for `iterations` times.
1037
1038    This method can be used to run a step function for training a number of
1039    times using input from a dataset.
1040
1041    Args:
1042      fn: function to run using this distribution strategy. The function must
1043        have the following signature: `def fn(context, inputs)`.
1044        `context` is an instance of `MultiStepContext` that will be passed when
1045        `fn` is run. `context` can be used to specify the outputs to be returned
1046        from `fn` by calling `context.set_last_step_output`. It can also be used
1047        to capture non tensor outputs by `context.set_non_tensor_output`.
        to capture non-tensor outputs via `context.set_non_tensor_output`.
1049        `inputs` will have same type/structure as `iterator.get_next()`.
1050        Typically, `fn` will use `call_for_each_replica` method of the strategy
1051        to distribute the computation over multiple replicas.
1052      iterator: Iterator of a dataset that represents the input for `fn`. The
1053        caller is responsible for initializing the iterator as needed.
1054      iterations: (Optional) Number of iterations that `fn` should be run.
1055        Defaults to 1.
1056      initial_loop_values: (Optional) Initial values to be passed into the
1057        loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
1058        initial_loop_values argument when we have a mechanism to infer the
1059        outputs of `fn`.
1060
1061    Returns:
1062      Returns the `MultiStepContext` object which has the following properties,
1063      among other things:
1064        - run_op: An op that runs `fn` `iterations` times.
1065        - last_step_outputs: A dictionary containing tensors set using
1066        `context.set_last_step_output`. Evaluating this returns the value of
1067        the tensors after the last iteration.
        - non_tensor_outputs: A dictionary containing anything that was set by
1069          `fn` by calling `context.set_non_tensor_output`.
1070    """
1071    _require_cross_replica_or_default_context_extended(self)
1072    with self._container_strategy().scope():
1073      return self._experimental_run_steps_on_iterator(
1074          fn, iterator, iterations, initial_loop_values)
1075
1076  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
1077                                          initial_loop_values):
1078    raise NotImplementedError("must be implemented in descendants")
1079
1080  def call_for_each_replica(self, fn, args=(), kwargs=None):
1081    """Run `fn` once per replica.
1082
1083    `fn` may call `tf.get_replica_context()` to access methods such as
1084    `replica_id_in_sync_group` and `merge_call()`.
1085
1086    `merge_call()` is used to communicate between the replicas and
    re-enter the cross-replica context. All replicas pause their execution
    when they encounter a `merge_call()` call. After that the
    `merge_fn` function is executed. Its results are then unwrapped and
1090    given back to each replica call. After that execution resumes until
1091    `fn` is complete or encounters another `merge_call()`.  Example:
1092
1093    ```python
1094    # Called once in "cross-replica" context.
1095    def merge_fn(distribution, three_plus_replica_id):
1096      # sum the values across replicas
1097      return sum(distribution.experimental_local_results(three_plus_replica_id))
1098
1099    # Called once per replica in `distribution`, in a "replica" context.
1100    def fn(three):
1101      replica_ctx = tf.get_replica_context()
1102      v = three + replica_ctx.replica_id_in_sync_group
1103      # Computes the sum of the `v` values across all replicas.
1104      s = replica_ctx.merge_call(merge_fn, args=(v,))
1105      return s + v
1106
1107    with distribution.scope():
1108      # in "cross-replica" context
1109      ...
1110      merged_results = distribution.call_for_each_replica(fn, args=[3])
1111      # merged_results has the values from every replica execution of `fn`.
1112      # This statement prints a list:
1113      print(distribution.experimental_local_results(merged_results))
1114    ```
1115
1116    Args:
1117      fn: function to run (will be run once per replica).
1118      args: Tuple or list with positional arguments for `fn`.
1119      kwargs: Dict with keyword arguments for `fn`.
1120
1121    Returns:
1122      Merged return value of `fn` across all replicas.
1123    """
1124    _require_cross_replica_or_default_context_extended(self)
1125    if kwargs is None:
1126      kwargs = {}
1127    with self._container_strategy().scope():
1128      return self._call_for_each_replica(fn, args, kwargs)
1129
1130  def _call_for_each_replica(self, fn, args, kwargs):
1131    raise NotImplementedError("must be implemented in descendants")
1132
1133  def _reduce(self, reduce_op, value):
1134    # Default implementation until we have an implementation for each strategy.
1135    return self._local_results(
1136        self._reduce_to(reduce_op, value,
1137                        device_util.current() or "/device:CPU:0"))[0]
1138
1139  def reduce_to(self, reduce_op, value, destinations):
1140    """Combine (via e.g. sum or mean) values across replicas.
1141
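    For example, a minimal sketch in cross-replica context (`per_replica_grad`
    and `var` are hypothetical values produced earlier):

    ```python
    reduced = strategy.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, per_replica_grad, destinations=var)
    ```
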
1142    Args:
1143      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
1144      value: A per-replica value with one value per replica.
1145      destinations: A mirrored variable, a per-replica tensor, or a device
1146        string. The return value will be copied to all destination devices (or
1147        all the devices where the `destinations` value resides). To perform an
1148        all-reduction, pass `value` to `destinations`.
1149
1150    Returns:
1151      A value mirrored to `destinations`.
1152    """
1153    # TODO(josh11b): More docstring
1154    _require_cross_replica_or_default_context_extended(self)
1155    assert not isinstance(destinations, (list, tuple))
1156    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
1157    assert (reduce_op == reduce_util.ReduceOp.SUM or
1158            reduce_op == reduce_util.ReduceOp.MEAN)
1159    return self._reduce_to(reduce_op, value, destinations)
1160
1161  def _reduce_to(self, reduce_op, value, destinations):
1162    raise NotImplementedError("must be implemented in descendants")
1163
1164  def batch_reduce_to(self, reduce_op, value_destination_pairs):
1165    """Combine multiple `reduce_to` calls into one for faster execution.
1166
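    For example, a minimal sketch (`grads_and_vars` is a hypothetical list of
    (gradient, variable) pairs produced earlier in cross-replica context):

    ```python
    reduced_grads = strategy.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, grads_and_vars)
    ```
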
1167    Args:
1168      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
1169      value_destination_pairs: A sequence of (value, destinations)
1170        pairs. See `reduce_to()` for a description.
1171
1172    Returns:
1173      A list of mirrored values, one per pair in `value_destination_pairs`.
1174    """
1175    # TODO(josh11b): More docstring
1176    _require_cross_replica_or_default_context_extended(self)
1177    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
1178    return self._batch_reduce_to(reduce_op, value_destination_pairs)
1179
1180  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
1181    return [
1182        self.reduce_to(reduce_op, t, destinations=v)
1183        for t, v in value_destination_pairs
1184    ]
1185
1186  def update(self, var, fn, args=(), kwargs=None, group=True):
1187    """Run `fn` to update `var` using inputs mirrored to the same devices.
1188
1189    If `var` is mirrored across multiple devices, then this implements
1190    logic like:
1191
1192    ```
1193    results = {}
1194    for device, v in var:
1195      with tf.device(device):
1196        # args and kwargs will be unwrapped if they are mirrored.
1197        results[device] = fn(v, *args, **kwargs)
1198    return merged(results)
1199    ```
1200
1201    Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
1202
1203    Neither `args` nor `kwargs` may contain per-replica values.
1204    If they contain mirrored values, they will be unwrapped before
1205    calling `fn`.
1206
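    For example, a minimal sketch (assuming `var` was created in this
    strategy's scope and `reduced_grad` is mirrored to the devices holding
    `var`):

    ```python
    update_op = strategy.extended.update(
        var, lambda v, g: v.assign_sub(g), args=(reduced_grad,))
    ```
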
    Args:
      var: Variable, possibly mirrored to multiple devices, to operate on.
      fn: Function to call. Should take the variable as the first argument.
      args: Tuple or list. Additional positional arguments to pass to `fn()`.
      kwargs: Dict with keyword arguments to pass to `fn()`.
      group: Boolean. Defaults to True. If False, the return value will be
        unwrapped.

    Returns:
      By default, the merged return value of `fn` across all replicas.  The
      merged result has dependencies to make sure that if it is evaluated at
      all, the side effects (updates) will happen on every replica. If instead
      "group=False" is specified, this function will return a nest of lists
      where each list has an element per replica, and the caller is responsible
      for ensuring all elements are executed.
    """
    _require_cross_replica_or_default_context_extended(self)
    if kwargs is None:
      kwargs = {}
    with self._container_strategy().scope():
      return self._update(var, fn, args, kwargs, group)

  def _update(self, var, fn, args, kwargs, group):
    raise NotImplementedError("must be implemented in descendants")

  def update_non_slot(
      self, colocate_with, fn, args=(), kwargs=None, group=True):
    """Runs `fn(*args, **kwargs)` on `colocate_with` devices.

    Args:
      colocate_with: The return value of `non_slot_devices()`.
      fn: Function to execute.
      args: Tuple or list. Positional arguments to pass to `fn()`.
      kwargs: Dict with keyword arguments to pass to `fn()`.
      group: Boolean. Defaults to True. If False, the return value will be
        unwrapped.

    Returns:
      Return value of `fn`, possibly merged across devices.
    """
    _require_cross_replica_or_default_context_extended(self)
    if kwargs is None:
      kwargs = {}
    with self._container_strategy().scope():
      return self._update_non_slot(colocate_with, fn, args, kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    raise NotImplementedError("must be implemented in descendants")

  def _local_results(self, distributed_value):
    raise NotImplementedError("must be implemented in descendants")

  def value_container(self, value):
    """Returns the container that this per-replica `value` belongs to.

    Args:
      value: A value returned by `call_for_each_replica()` or a variable
        created in `scope()`.

    Returns:
      A container that `value` belongs to.
      If `value` does not belong to any container (including the case where
      the container has been destroyed), returns the value itself.
      `value in experimental_local_results(value_container(value))` will
      always be true.
    """
    raise NotImplementedError("must be implemented in descendants")

  def _group(self, value, name=None):
    """Implementation of `group`."""
    value = nest.flatten(self._local_results(value))

    if len(value) != 1 or name is not None:
      return control_flow_ops.group(value, name=name)
    # Special handling for the common case of one op.
    v, = value
    if hasattr(v, "op"):
      v = v.op
    return v

  @property
  def experimental_require_static_shapes(self):
    return self._require_static_shapes

  @property
  def _num_replicas_in_sync(self):
    """Returns number of replicas over which gradients are aggregated."""
    raise NotImplementedError("must be implemented in descendants")

  @property
  def worker_devices(self):
    """Returns the tuple of all devices used for compute replica execution."""
    # TODO(josh11b): More docstring
    raise NotImplementedError("must be implemented in descendants")

  @property
  def parameter_devices(self):
    """Returns the tuple of all devices used to place variables."""
    # TODO(josh11b): More docstring
    raise NotImplementedError("must be implemented in descendants")

  def non_slot_devices(self, var_list):
    """Device(s) for non-slot variables.

    Create variables on these devices in a
    `with colocate_vars_with(non_slot_devices(...)):` block.
    Update those using `update_non_slot()`.

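    For example, a sketch of this pattern (illustrative only; `var_list` is
    the list of model variables being optimized):

    ```
    extended = tf.distribute.get_strategy().extended
    devices = extended.non_slot_devices(var_list)
    with extended.colocate_vars_with(devices):
      step = tf.Variable(0, trainable=False, name="step")
    # Later, from a cross-replica context:
    extended.update_non_slot(devices, lambda: step.assign_add(1))
    ```
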
    Args:
      var_list: The list of variables being optimized, needed with the
        default `tf.distribute.Strategy`.
    """
    raise NotImplementedError("must be implemented in descendants")

  @property
  def experimental_between_graph(self):
    """Whether the strategy uses between-graph replication or not.

    This is expected to return a constant value that will not be changed
    throughout its life cycle.
    """
    raise NotImplementedError("must be implemented in descendants")

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class."""
    del session_config, cluster_spec, task_type, task_id

  def _update_config_proto(self, config_proto):
    return copy.deepcopy(config_proto)

  @property
  def experimental_should_init(self):
    """Whether initialization is needed."""
    raise NotImplementedError("must be implemented in descendants")

  @property
  def should_checkpoint(self):
    """Whether checkpointing is needed."""
    raise NotImplementedError("must be implemented in descendants")

  @property
  def should_save_summary(self):
    """Whether saving summaries is needed."""
    raise NotImplementedError("must be implemented in descendants")


# A note about the difference between the context managers
# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
# (defined above) used by `DistributionStrategy.scope()`:
#
# * a ReplicaContext is only present during a `call_for_each_replica()`
#   call (except during a `merge_call` call) and in such a scope it
#   will be returned by calls to `get_replica_context()`.  Implementers of new
#   DistributionStrategy descendants will frequently also need to
#   define a descendant of ReplicaContext, and are responsible for
#   entering and exiting this context.
#
# * DistributionStrategy.scope() sets up a variable_creator scope that
#   changes variable creation calls (e.g. to make mirrored
#   variables). This is intended as an outer scope that users enter once
#   around their model creation and graph definition. There is no
#   anticipated need to define descendants of _CurrentDistributionContext.
#   It sets the current DistributionStrategy for purposes of
#   `get_strategy()` and `has_strategy()`
#   and switches the thread mode to a "cross-replica context".
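#
# For example (an illustrative sketch, assuming a concrete strategy such as
# `MirroredStrategy`; `step_fn` and `inputs` are placeholders):
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():                      # cross-replica context
#     strategy.extended.call_for_each_replica(  # replica context inside step_fn
#         step_fn, args=(inputs,))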
@tf_export("distribute.ReplicaContext")
class ReplicaContext(object):
  """`tf.distribute.Strategy` API when in a replica context.

  To be used inside your replicated step function, such as in a
  `tf.distribute.StrategyExtended.call_for_each_replica` call.
  """

  def __init__(self, strategy, replica_id_in_sync_group):
    self._strategy = strategy
    self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
        self)
    self._replica_id_in_sync_group = replica_id_in_sync_group
    self._summary_recording_distribution_strategy = None

  def __enter__(self):
    _push_per_thread_mode(self._thread_context)
    ctx = eager_context.context()

    def replica_id_is_zero():
      return math_ops.equal(self._replica_id_in_sync_group,
                            constant_op.constant(0))

    self._summary_recording_distribution_strategy = (
        ctx.summary_recording_distribution_strategy)
    ctx.summary_recording_distribution_strategy = replica_id_is_zero

  def __exit__(self, exception_type, exception_value, traceback):
    ctx = eager_context.context()
    ctx.summary_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)
    _pop_per_thread_mode()

  def merge_call(self, merge_fn, args=(), kwargs=None):
    """Merge args across replicas and run `merge_fn` in a cross-replica context.

    This allows communication and coordination when there are multiple calls
    to a model function triggered by a call to
    `strategy.extended.call_for_each_replica(model_fn, ...)`.

    See `tf.distribute.StrategyExtended.call_for_each_replica` for an
    explanation.

    If not inside a distributed scope, this is equivalent to:

    ```
    strategy = tf.distribute.get_strategy()
    with cross-replica-context(strategy):
      return merge_fn(strategy, *args, **kwargs)
    ```

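    For example, a sketch (illustrative only; `per_replica_value` stands for
    a per-replica value produced inside the replicated step function):

    ```
    def merge_fn(strategy, value):
      # Runs in a cross-replica context.
      return strategy.extended.reduce_to(
          tf.distribute.ReduceOp.SUM, value, destinations=value)

    ctx = tf.distribute.get_replica_context()
    total = ctx.merge_call(merge_fn, args=(per_replica_value,))
    ```
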
    Args:
      merge_fn: Function that joins arguments from threads that are given as
        `PerReplica` values. It accepts a `tf.distribute.Strategy` object as
        its first argument.
      args: List or tuple with positional per-thread arguments for `merge_fn`.
      kwargs: Dict with keyword per-thread arguments for `merge_fn`.

    Returns:
      The return value of `merge_fn`, except for `PerReplica` values which are
      unpacked.
    """
    require_replica_context(self)
    if kwargs is None:
      kwargs = {}
    return self._merge_call(merge_fn, args, kwargs)

  def _merge_call(self, merge_fn, args, kwargs):
    """Default implementation for single replica."""
    _push_per_thread_mode(  # thread-local, so not needed with multiple threads
        distribution_strategy_context._CrossReplicaThreadMode(self._strategy))  # pylint: disable=protected-access
    try:
      return merge_fn(self._strategy, *args, **kwargs)
    finally:
      _pop_per_thread_mode()

  @property
  def num_replicas_in_sync(self):
    """Returns number of replicas over which gradients are aggregated."""
    return self._strategy.num_replicas_in_sync

  @property
  def replica_id_in_sync_group(self):
    """Which replica is being defined, from 0 to `num_replicas_in_sync - 1`."""
    require_replica_context(self)
    return self._replica_id_in_sync_group

  @property
  def strategy(self):
    """The current `tf.distribute.Strategy` object."""
    return self._strategy

  @property
  def devices(self):
    """The devices this replica is to be executed on, as a tuple of strings."""
    require_replica_context(self)
    return (device_util.current(),)

  def all_reduce(self, reduce_op, value):
    """All-reduces the given `Tensor` nest across replicas.

    If `all_reduce` is called in any replica, it must be called in all replicas.
    The nested structure and `Tensor` shapes must be identical in all replicas.

    IMPORTANT: The ordering of communications must be identical in all replicas.

    Example with two replicas:
      Replica 0 `value`: {'a': 1, 'b': [40,  1]}
      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}

      If `reduce_op` == `SUM`:
        Result (on all replicas): {'a': 4, 'b': [42, 99]}

      If `reduce_op` == `MEAN`:
        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}

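    For example, a sketch inside a replicated step function (`grads` is an
    illustrative nest of per-replica `Tensor`s):

    ```
    ctx = tf.distribute.get_replica_context()
    summed_grads = ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)
    ```
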
    Args:
      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
      value: The nested structure of `Tensor`s to all-reduce.
        The structure must be compatible with `tf.nest`.

    Returns:
      A `Tensor` nest with the reduced `value`s from each replica.
    """
    def batch_all_reduce(strategy, *value_flat):
      return strategy.extended.batch_reduce_to(
          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])

    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
      @custom_gradient.custom_gradient
      def grad_wrapper(*xs):
        ys = self.merge_call(batch_all_reduce, args=xs)
        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
    else:
      # TODO(cjfj): Implement gradients for other reductions.
      reduced = nest.pack_sequence_as(
          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
      return nest.map_structure(array_ops.prevent_gradient, reduced)

  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
  # all-reduce. It would return a function returning the result of reducing `t`
  # across all replicas. The caller would wait to call this function until they
  # needed the reduce result, allowing an efficient implementation:
  # * With eager execution, the reduction could be performed asynchronously
  #   in the background, not blocking until the result was needed.
  # * When constructing a graph, it could batch up all reduction requests up
  #   to the point where the first result is needed. Most likely this can be
  #   implemented in terms of `merge_call()` and `batch_reduce_to()`.


def _batch_reduce_destination(x):
  """Returns the destinations for batch all-reduce."""
  if isinstance(x, ops.Tensor):  # One device strategies.
    return x.device
  else:
    return x


# ------------------------------------------------------------------------------


class _DefaultDistributionStrategy(DistributionStrategy):
  """Default `tf.distribute.Strategy` if none is explicitly selected."""

  def __init__(self):
    super(_DefaultDistributionStrategy, self).__init__(
        _DefaultDistributionExtended(self))


class _DefaultDistributionExtended(DistributionStrategyExtended):
  """Implementation of _DefaultDistributionStrategy."""

  def _scope(self, strategy):
    """Context manager setting a variable creator and `self` as current."""
    if distribution_strategy_context.has_strategy():
      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_strategy_scope_strategy(strategy)
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        strategy, variable_scope.variable_creator_scope(creator))

  def colocate_vars_with(self, colocate_with_variable):
    """Does not require `self.scope`."""
    _require_strategy_scope_extended(self)
    return ops.colocate_with(colocate_with_variable)

  def variable_created_in_scope(self, v):
    return v._distribute_strategy is None  # pylint: disable=protected-access

  def _make_dataset_iterator(self, dataset):
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _make_input_fn_iterator(self,
                              input_fn,
                              replication_mode=InputReplicationMode.PER_WORKER):
    dataset = input_fn(InputContext())
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    # Create one variable per flattened numpy array, initialize each from its
    # numpy value via `session`, then build a dataset that reads from them.
    numpy_flat = nest.flatten(numpy_input)
    vars_flat = tuple(
        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                trainable=False, use_resource=True)
        for i in numpy_flat
    )
    for v, i in zip(vars_flat, numpy_flat):
      numpy_dataset.init_var_from_numpy(v, i, session)
    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
    return dataset_ops.Dataset.from_tensor_slices(vars_nested)

  def _broadcast_to(self, tensor, destinations):
    if destinations is None:
      return tensor
    else:
      raise NotImplementedError("TODO")

  def _call_for_each_replica(self, fn, args, kwargs):
    with ReplicaContext(
        self._container_strategy(),
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    # TODO(josh11b): Use destinations?
    del reduce_op, destinations
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
    # once that value is used for something.
    with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    return array_ops.identity(replica_local_var)

  def _local_results(self, distributed_value):
    return (distributed_value,)

  def value_container(self, value):
    return value

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    raise RuntimeError("worker_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  @property
  def parameter_devices(self):
    raise RuntimeError("parameter_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  # TODO(priyag): This should inherit from `InputIterator`, once dependency
  # issues have been resolved.
  class DefaultInputIterator(object):
    """Default implementation of `InputIterator` for default strategy."""

    def __init__(self, dataset):
      self._dataset = dataset
      if eager_context.executing_eagerly():
        self._iterator = dataset.make_one_shot_iterator()
      else:
        self._iterator = dataset.make_initializable_iterator()

    def get_next(self):
      return self._iterator.get_next()

    def initialize(self):
      if eager_context.executing_eagerly():
        self._iterator = self._dataset.make_one_shot_iterator()
        return []
      else:
        return [self._iterator.initializer]

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for this strategy."""
    return True


# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = resource_variable_ops._from_proto_fn


def _from_proto_fn(v, import_scope=None):
  if distribution_strategy_context.has_strategy():
    raise NotImplementedError(
        "Deserialization of variables is not yet supported when using a "
        "tf.distribute.Strategy.")
  else:
    return _original_from_proto(v, import_scope=import_scope)

resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access


#-------------------------------------------------------------------------------
# Shorthand for some methods from distribution_strategy_context.
_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
_get_default_replica_mode = (
    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access
