1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import atexit
22import collections
23import contextlib
24import copy
25import functools
26import weakref
27
28from absl import logging
29import numpy as np
30
31from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
32from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
33from tensorflow.python.autograph.impl import api as autograph
34from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
35from tensorflow.python.distribute import device_util
36from tensorflow.python.distribute import distribute_lib
37from tensorflow.python.distribute import distribute_utils
38from tensorflow.python.distribute import input_lib
39from tensorflow.python.distribute import numpy_dataset
40from tensorflow.python.distribute import reduce_util
41from tensorflow.python.distribute import tpu_util
42from tensorflow.python.distribute import tpu_values
43from tensorflow.python.distribute import values
44from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
45from tensorflow.python.eager import context
46from tensorflow.python.eager import def_function
47from tensorflow.python.eager import function
48from tensorflow.python.framework import constant_op
49from tensorflow.python.framework import device_spec
50from tensorflow.python.framework import dtypes
51from tensorflow.python.framework import ops
52from tensorflow.python.framework import sparse_tensor
53from tensorflow.python.framework import tensor_shape
54from tensorflow.python.framework import tensor_util
55from tensorflow.python.ops import array_ops
56from tensorflow.python.ops import control_flow_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import resource_variable_ops
59from tensorflow.python.ops import variables as variables_lib
60from tensorflow.python.ops.ragged import ragged_tensor
61from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
62from tensorflow.python.tpu import tpu
63from tensorflow.python.tpu import tpu_strategy_util
64from tensorflow.python.tpu import training_loop
65from tensorflow.python.tpu.ops import tpu_ops
66from tensorflow.python.util import deprecation
67from tensorflow.python.util import nest
68from tensorflow.python.util import tf_inspect
69from tensorflow.python.util.tf_export import tf_export
70
71_XLA_OP_BY_OP_INPUTS_LIMIT = 200
72
73
74@contextlib.contextmanager
75def maybe_init_scope():
76  if ops.executing_eagerly_outside_functions():
77    yield
78  else:
79    with ops.init_scope():
80      yield
81
82
83def validate_run_function(fn):
84  """Validate the function passed into strategy.run."""
85
86  # We allow three types of functions/objects passed into TPUStrategy
87  # run in eager mode:
88  #   1. a user annotated tf.function
89  #   2. a ConcreteFunction, which is mostly what you get from loading a saved
90  #      model.
91  #   3. a callable object whose `__call__` method is a tf.function.
92  #
93  # Otherwise we raise an error, because we don't support running `run`
94  # eagerly in TPUStrategy.
95
96  if context.executing_eagerly() \
97      and not isinstance(fn, def_function.Function) \
98      and not isinstance(fn, function.ConcreteFunction) \
99      and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
100    raise NotImplementedError(
101        "TPUStrategy.run(fn, ...) does not support pure eager "
102        "execution. Please make sure the function passed into "
103        "`strategy.run` is a `tf.function`, or that "
104        "`strategy.run` is called inside a `tf.function` when "
105        "eager execution is enabled.")
106
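# Illustrative note (not part of the original module): assuming a tensor `t`
# and a `TPUStrategy` instance `strategy` in eager mode,
#
#   strategy.run(tf.function(lambda x: x + 1), args=(t,))  # passes validation
#   strategy.run(lambda x: x + 1, args=(t,))               # raises
#                                                          # NotImplementedError
#
# because a plain Python callable is neither a `def_function.Function` nor a
# `ConcreteFunction`, and its `__call__` is not a `tf.function`.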
107
108def _maybe_partial_apply_variables(fn, args, kwargs):
109  """Inspects arguments to partially apply any DistributedVariable.
110
111  This avoids an automatic cast of the current variable value to tensor.
112
113  Note that a variable may be captured implicitly with Python scope instead of
114  passing it to run(), but supporting run() keeps behavior consistent
115  with MirroredStrategy.
116
117  Since positional arguments must be applied from left to right, this function
118  does some tricky function inspection to move variable positional arguments
119  into kwargs. As a result of this, we can't support passing Variables as *args,
120  nor as args to functions which combine both explicit positional arguments and
121  *args.
122
123  Args:
124    fn: The function to run, as passed to run().
125    args: Positional arguments to fn, as passed to run().
126    kwargs: Keyword arguments to fn, as passed to run().
127
128  Returns:
129    A tuple of the function (possibly wrapped), args, kwargs (both
130    possibly filtered, with members of args possibly moved to kwargs).
131    If no variables are found, this function is a noop.
132
133  Raises:
134    ValueError: If the function signature makes unsupported use of *args, or if
135      too many arguments are passed.
136  """
137
138  def is_distributed_var(x):
139    flat = nest.flatten(x)
140    return flat and isinstance(flat[0], values.DistributedVariable)
141
142  # We will split kwargs into two dicts, one of which will be applied now.
143  var_kwargs = {}
144  nonvar_kwargs = {}
145
146  if kwargs:
147    var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
148  if var_kwargs:
149    nonvar_kwargs = {
150        k: v for k, v in kwargs.items() if not is_distributed_var(v)
151    }
152
153  # Dump the argument names of `fn` to a list. This will include both positional
154  # and keyword arguments, but since positional arguments come first we can
155  # look up names of positional arguments by index.
156  positional_args = []
157  index_of_star_args = None
158  for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
159    # Instance methods define "self" as first argument, but we don't pass "self".
160    # Note that this is a heuristic, as a method can name its first argument
161    # something else, and a function can define a first argument "self" as well.
162    # In both of these cases, using a Variable will fail with an unfortunate
163    # error about the number of arguments.
164    # inspect.ismethod() seems not to work here, possibly due to the use of
165    # tf.function().
166    if i == 0 and p.name == "self":
167      continue
168
169    if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
170      positional_args.append(p.name)
171
172    elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
173      # We'll raise an error later if a variable is passed to *args, since we
174      # can neither pass it by name nor partially apply it. This case only
175      # happens once at most.
176      index_of_star_args = i
177
178    elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
179      # This is a rare Python feature, indicating a / in the arg list.
180      if var_kwargs or any(is_distributed_var(a) for a in args):
181        raise ValueError(
182            "Mixing Variables and positional-only parameters not supported by "
183            "TPUStrategy.")
184      return fn, args, kwargs
185
186  star_args = []
187  have_seen_var_arg = False
188
189  for i, a in enumerate(args):
190    if is_distributed_var(a):
191      if index_of_star_args is not None and i >= index_of_star_args:
192        raise ValueError(
193            "TPUStrategy.run() cannot handle Variables passed to *args. "
194            "Either name the function argument, or capture the Variable "
195            "implicitly.")
196      if len(positional_args) <= i:
197        raise ValueError(
198            "Too many positional arguments passed to call to TPUStrategy.run()."
199        )
200      var_kwargs[positional_args[i]] = a
201      have_seen_var_arg = True
202    else:
203      if index_of_star_args is not None and i >= index_of_star_args:
204        if have_seen_var_arg:
205          raise ValueError(
206              "TPUStrategy.run() cannot handle both Variables and a mix of "
207              "positional args and *args. Either remove the *args, or capture "
208              "the Variable implicitly.")
209        else:
210          star_args.append(a)
211          continue
212
213      if len(positional_args) <= i:
214        raise ValueError(
215            "Too many positional arguments passed to call to TPUStrategy.run()."
216        )
217      nonvar_kwargs[positional_args[i]] = a
218
219  if var_kwargs:
220    return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
221  return fn, args, kwargs
222
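# Illustrative sketch (not part of the original module), assuming a
# DistributedVariable `v` created under a TPUStrategy scope and a hypothetical
# replica function
#
#   def step_fn(v, x):
#     return v.assign_add(x)
#
# a call to _maybe_partial_apply_variables(step_fn, (v, 1.0), {}) returns
# roughly (functools.partial(step_fn, v=v), [], {"x": 1.0}): the variable is
# bound by name so it is not auto-converted to a tensor, and the remaining
# positional argument is moved into kwargs.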
223
224@tf_export("distribute.TPUStrategy", v1=[])
225class TPUStrategyV2(distribute_lib.Strategy):
226  """Synchronous training on TPUs and TPU Pods.
227
228  To construct a TPUStrategy object, you need to run the
229  initialization code as below:
230
231  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
232  >>> tf.config.experimental_connect_to_cluster(resolver)
233  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
234  >>> strategy = tf.distribute.TPUStrategy(resolver)
235
236  While using distribution strategies, the variables created within the
237  strategy's scope will be replicated across all the replicas and can be kept in
238  sync using all-reduce algorithms.
239
240  To run TF2 programs on TPUs, you can either use `.compile` and
241  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
242  training loop by calling `strategy.run` directly. Note that
243  TPUStrategy doesn't support pure eager execution, so make sure the
244  function passed into `strategy.run` is a `tf.function`, or that
245  `strategy.run` is called inside a `tf.function` when eager
246  execution is enabled. See more details at https://www.tensorflow.org/guide/tpu.
247
248  `distribute_datasets_from_function` and
249  `experimental_distribute_dataset` APIs can be used to distribute the dataset
250  across the TPU workers when writing your own training loop. If you are using
251  `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
252  handle the distribution for you.
253
254  An example of writing customized training loop on TPUs:
255
256  >>> with strategy.scope():
257  ...   model = tf.keras.Sequential([
258  ...     tf.keras.layers.Dense(2, input_shape=(5,)),
259  ...   ])
260  ...   optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
261
262  >>> def dataset_fn(ctx):
263  ...   x = np.random.random((2, 5)).astype(np.float32)
264  ...   y = np.random.randint(2, size=(2, 1))
265  ...   dataset = tf.data.Dataset.from_tensor_slices((x, y))
266  ...   return dataset.repeat().batch(1, drop_remainder=True)
267  >>> dist_dataset = strategy.distribute_datasets_from_function(
268  ...     dataset_fn)
269  >>> iterator = iter(dist_dataset)
270
271  >>> @tf.function()
272  ... def train_step(iterator):
273  ...
274  ...   def step_fn(inputs):
275  ...     features, labels = inputs
276  ...     with tf.GradientTape() as tape:
277  ...       logits = model(features, training=True)
278  ...       loss = tf.keras.losses.sparse_categorical_crossentropy(
279  ...           labels, logits)
280  ...
281  ...     grads = tape.gradient(loss, model.trainable_variables)
282  ...     optimizer.apply_gradients(zip(grads, model.trainable_variables))
283  ...
284  ...   strategy.run(step_fn, args=(next(iterator),))
285
286  >>> train_step(iterator)
287
288  For advanced use cases like model parallelism, you can set the
289  `experimental_device_assignment` argument when creating TPUStrategy to specify
290  the number of replicas and the number of logical devices. Below is an example
291  that initializes the TPU system with 2 logical devices and 1 replica.
292
293  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
294  >>> tf.config.experimental_connect_to_cluster(resolver)
295  >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
296  >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
297  ...     topology,
298  ...     computation_shape=[1, 1, 1, 2],
299  ...     num_replicas=1)
300  >>> strategy = tf.distribute.TPUStrategy(
301  ...     resolver, experimental_device_assignment=device_assignment)
302
303  Then you can run a `tf.add` operation only on logical device 0.
304
305  >>> @tf.function()
306  ... def step_fn(inputs):
307  ...   features, _ = inputs
308  ...   output = tf.add(features, features)
309  ...
310  ...   # Add operation will be executed on logical device 0.
311  ...   output = strategy.experimental_assign_to_logical_device(output, 0)
312  ...   return output
313  >>> dist_dataset = strategy.distribute_datasets_from_function(
314  ...     dataset_fn)
315  >>> iterator = iter(dist_dataset)
316  >>> strategy.run(step_fn, args=(next(iterator),))
317  """
318
319  def __init__(self,
320               tpu_cluster_resolver=None,
321               experimental_device_assignment=None):
322    """Synchronous training in TPU donuts or Pods.
323
324    Args:
325      tpu_cluster_resolver: A
326        `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
327        provides information about the TPU cluster. If None, it will assume
328        running on a local TPU worker.
329      experimental_device_assignment: Optional
330        `tf.tpu.experimental.DeviceAssignment` to specify the placement of
331        replicas on the TPU cluster.
332    """
333    super(TPUStrategyV2, self).__init__(TPUExtended(
334        self, tpu_cluster_resolver,
335        device_assignment=experimental_device_assignment))
336    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
337    distribute_lib.distribution_strategy_replica_gauge.get_cell(
338        "num_workers").set(self.extended.num_hosts)
339    distribute_lib.distribution_strategy_replica_gauge.get_cell(
340        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
341    # Packed variable is used to reduce the overhead of function execution.
342    # For a DistributedVariable, only one variable handle is captured into a
343    # function graph. It's only supported in eager mode.
344    self._enable_packed_variable_in_eager_mode = True
345
346  def run(self, fn, args=(), kwargs=None, options=None):
347    """Run the computation defined by `fn` on each TPU replica.
348
349    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
350    `tf.distribute.DistributedValues`, such as those produced by a
351    `tf.distribute.DistributedDataset` from
352    `tf.distribute.Strategy.experimental_distribute_dataset` or
353    `tf.distribute.Strategy.distribute_datasets_from_function`,
354    when `fn` is executed on a particular replica, it will be executed with the
355    component of `tf.distribute.DistributedValues` that correspond to that
356    replica.
357
358    `fn` may call `tf.distribute.get_replica_context()` to access members such
359    as `all_reduce`.
360
361    All arguments in `args` or `kwargs` should be either nests of tensors or
362    `tf.distribute.DistributedValues` containing tensors or composite tensors.
363
364    Example usage:
365
366    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
367    >>> tf.config.experimental_connect_to_cluster(resolver)
368    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
369    >>> strategy = tf.distribute.TPUStrategy(resolver)
370    >>> @tf.function
371    ... def run():
372    ...   def value_fn(value_context):
373    ...     return value_context.num_replicas_in_sync
374    ...   distributed_values = (
375    ...       strategy.experimental_distribute_values_from_function(value_fn))
376    ...   def replica_fn(input):
377    ...     return input * 2
378    ...   return strategy.run(replica_fn, args=(distributed_values,))
379    >>> result = run()
380
381    Args:
382      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
383      args: (Optional) Positional arguments to `fn`.
384      kwargs: (Optional) Keyword arguments to `fn`.
385      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
386        the options to run `fn`.
387
388    Returns:
389      Merged return value of `fn` across replicas. The structure of the return
390      value is the same as the return value from `fn`. Each element in the
391      structure can either be `tf.distribute.DistributedValues`, `Tensor`
392      objects, or `Tensor`s (for example, if running on a single replica).
393    """
394    validate_run_function(fn)
395
396    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
397
398    # Note: the target function is converted to graph even when in Eager mode,
399    # so autograph is on by default here.
400    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
401    options = options or distribute_lib.RunOptions()
402    return self.extended.tpu_run(fn, args, kwargs, options)
403
404  def experimental_assign_to_logical_device(self, tensor, logical_device_id):
405    """Adds annotation that `tensor` will be assigned to a logical device.
406
407    This adds an annotation to `tensor` specifying that operations on
408    `tensor` will be invoked on logical core device id `logical_device_id`.
409    When model parallelism is used, the default behavior is that all ops
410    are placed on the zeroth logical device.
411
412    ```python
413
414    # Initializing TPU system with 2 logical devices and 4 replicas.
415    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
416    tf.config.experimental_connect_to_cluster(resolver)
417    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
418    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
419        topology,
420        computation_shape=[1, 1, 1, 2],
421        num_replicas=4)
422    strategy = tf.distribute.TPUStrategy(
423        resolver, experimental_device_assignment=device_assignment)
424    iterator = iter(inputs)
425
426    @tf.function()
427    def step_fn(inputs):
428      output = tf.add(inputs, inputs)
429
430      # Add operation will be executed on logical device 0.
431      output = strategy.experimental_assign_to_logical_device(output, 0)
432      return output
433
434    strategy.run(step_fn, args=(next(iterator),))
435    ```
436
437    Args:
438      tensor: Input tensor to annotate.
439      logical_device_id: Id of the logical core to which the tensor will be
440        assigned.
441
442    Raises:
443      ValueError: The logical device id presented is not consistent with the
444        total number of logical devices specified by the device assignment.
445
446    Returns:
447      Annotated tensor with identical value as `tensor`.
448    """
449    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
450    if (logical_device_id < 0 or
451        logical_device_id >= num_logical_devices_per_replica):
452      raise ValueError("`logical_device_id` to assign must be lower than the "
453                       "total number of logical devices per replica. Received "
454                       "logical device id {} but there are only {} logical "
455                       "devices per replica.".format(
456                           logical_device_id, num_logical_devices_per_replica))
457    return xla_sharding.assign_device(
458        tensor, logical_device_id, use_sharding_op=True)
459
460  def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
461    """Adds annotation that `tensor` will be split across logical devices.
462
463    This adds an annotation to tensor `tensor` specifying that operations on
464    `tensor` will be split among multiple logical devices. Tensor `tensor` will
465    be split across dimensions specified by `partition_dimensions`.
466    The dimensions of `tensor` must be divisible by corresponding value in
467    `partition_dimensions`.
468
469    For example, for a system with 8 logical devices, if `tensor` is an image
470    tensor with shape (batch_size, width, height, channel) and
471    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
472    2 ways in the width dimension and 4 ways in the height dimension, and the
473    resulting split tensor values will be fed into 8 logical devices.
474
475    ```python
476    # Initializing TPU system with 8 logical devices and 1 replica.
477    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
478    tf.config.experimental_connect_to_cluster(resolver)
479    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
480    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
481        topology,
482        computation_shape=[1, 2, 2, 2],
483        num_replicas=1)
484    strategy = tf.distribute.TPUStrategy(
485        resolver, experimental_device_assignment=device_assignment)
486
487    iterator = iter(inputs)
488
489    @tf.function()
490    def step_fn(inputs):
491      inputs = strategy.experimental_split_to_logical_devices(
492        inputs, [1, 2, 4, 1])
493
494      # model() function will be executed on 8 logical devices with `inputs`
495      # split 2 * 4 ways.
496      output = model(inputs)
497      return output
498
499    strategy.run(step_fn, args=(next(iterator),))
500    ```
501    Args:
502      tensor: Input tensor to annotate.
503      partition_dimensions: An unnested list of integers with the size equal to
504        rank of `tensor` specifying how `tensor` will be partitioned. The
505        product of all elements in `partition_dimensions` must be equal to the
506        total number of logical devices per replica.
507
508    Raises:
509      ValueError: 1) If the size of `partition_dimensions` does not equal the
510        rank of `tensor`, or 2) if the product of elements of `partition_dimensions`
511        does not match the number of logical devices per replica defined by the
512        implementing DistributionStrategy's device specification or
513        3) if a known size of `tensor` is not divisible by corresponding
514        value in `partition_dimensions`.
515
516    Returns:
517      Annotated tensor with identical value as `tensor`.
518    """
519    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
520    num_partition_splits = np.prod(partition_dimensions)
521    input_shape = tensor.shape
522    tensor_rank = len(input_shape)
523
524    if tensor_rank != len(partition_dimensions):
525      raise ValueError("Length of `partition_dimensions` ({}) must be "
526                       "equal to the rank of `tensor` ({}).".format(
527                           len(partition_dimensions), tensor_rank))
528
529    for dim_index, dim_size in enumerate(input_shape):
530      if dim_size is None:
531        continue
532
533      split_size = partition_dimensions[dim_index]
534      if dim_size % split_size != 0:
535        raise ValueError("Tensor shape at dimension ({}) must be "
536                         "divisible by corresponding value specified "
537                         "by `partition_dimensions` ({}).".format(
538                             dim_index, split_size))
539
540    if num_partition_splits != num_logical_devices_per_replica:
541      raise ValueError("Number of logical devices ({}) does not match the "
542                       "number of partition splits specified ({}).".format(
543                           num_logical_devices_per_replica,
544                           num_partition_splits))
545
546    tile_assignment = np.arange(num_partition_splits).reshape(
547        partition_dimensions)
548    return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
549
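  # Illustrative note (not part of the original module): with 8 logical
  # devices per replica and `partition_dimensions=[1, 2, 4, 1]`, the tile
  # assignment computed above is np.arange(8).reshape([1, 2, 4, 1]), i.e.
  # logical devices 0..7 are laid out 2-way along the width dimension and
  # 4-way along the height dimension of the input tensor.
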
550  def experimental_replicate_to_logical_devices(self, tensor):
551    """Adds annotation that `tensor` will be replicated to all logical devices.
552
553    This adds an annotation to tensor `tensor` specifying that operations on
554    `tensor` will be invoked on all logical devices.
555
556    ```python
557    # Initializing TPU system with 8 logical devices and 1 replica.
558    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
559    tf.config.experimental_connect_to_cluster(resolver)
560    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
561    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
562        topology,
563        computation_shape=[1, 2, 2, 2],
564        num_replicas=1)
565    strategy = tf.distribute.TPUStrategy(
566        resolver, experimental_device_assignment=device_assignment)
567
568    iterator = iter(inputs)
569
570    @tf.function()
571    def step_fn(inputs):
572      images, labels = inputs
573      images = strategy.experimental_split_to_logical_devices(
574        images, [1, 2, 4, 1])
575
576      # model() function will be executed on 8 logical devices with `images`
577      # split 2 * 4 ways.
578      output = model(images)
579
580      # For loss calculation, all logical devices share the same logits
581      # and labels.
582      labels = strategy.experimental_replicate_to_logical_devices(labels)
583      output = strategy.experimental_replicate_to_logical_devices(output)
584      loss = loss_fn(labels, output)
585
586      return loss
587
588    strategy.run(step_fn, args=(next(iterator),))
589    ```
590    Args:
591      tensor: Input tensor to annotate.
592
593    Returns:
594      Annotated tensor with identical value as `tensor`.
595    """
596    return xla_sharding.replicate(tensor, use_sharding_op=True)
597
598
599@tf_export("distribute.experimental.TPUStrategy", v1=[])
600@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
601class TPUStrategy(distribute_lib.Strategy):
602  """Synchronous training on TPUs and TPU Pods.
603
604  To construct a TPUStrategy object, you need to run the
605  initialization code as below:
606
607  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
608  >>> tf.config.experimental_connect_to_cluster(resolver)
609  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
610  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
611
612  While using distribution strategies, the variables created within the
613  strategy's scope will be replicated across all the replicas and can be kept in
614  sync using all-reduce algorithms.
615
616  To run TF2 programs on TPUs, you can either use `.compile` and
617  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
618  training loop by calling `strategy.run` directly. Note that
619  TPUStrategy doesn't support pure eager execution, so make sure the
620  function passed into `strategy.run` is a `tf.function`, or that
621  `strategy.run` is called inside a `tf.function` when eager
622  execution is enabled.
623  """
624
625  def __init__(self,
626               tpu_cluster_resolver=None,
627               device_assignment=None):
628    """Synchronous training in TPU donuts or Pods.
629
630    Args:
631      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
632        which provides information about the TPU cluster.
633      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
634        specify the placement of replicas on the TPU cluster.
635    """
636    logging.warning(
637        "`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
638        "the non-experimental symbol `tf.distribute.TPUStrategy` instead.")
639
640    super(TPUStrategy, self).__init__(TPUExtended(
641        self, tpu_cluster_resolver, device_assignment=device_assignment))
642    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
643    distribute_lib.distribution_strategy_replica_gauge.get_cell(
644        "num_workers").set(self.extended.num_hosts)
645    distribute_lib.distribution_strategy_replica_gauge.get_cell(
646        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
647    # Packed variable is used to reduce the overhead of function execution.
648    # For a DistributedVariable, only one variable handle is captured into a
649    # function graph. It's only supported in eager mode.
650    self._enable_packed_variable_in_eager_mode = True
651
652  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
653  # can use the default implementation.
654  # This implementation runs a single step. It does not use infeed or outfeed.
655  def run(self, fn, args=(), kwargs=None, options=None):
656    """See base class."""
657    validate_run_function(fn)
658
659    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
660
661    # Note: the target function is converted to graph even when in Eager mode,
662    # so autograph is on by default here.
663    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
664    options = options or distribute_lib.RunOptions()
665    return self.extended.tpu_run(fn, args, kwargs, options)
666
667  @property
668  def cluster_resolver(self):
669    """Returns the cluster resolver associated with this strategy.
670
671    `tf.distribute.experimental.TPUStrategy` provides the
672    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
673    provides one in `__init__`, that instance is returned; if the user does
674    not, a default
675    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
676    """
677    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access
678
679
680@tf_export(v1=["distribute.experimental.TPUStrategy"])
681class TPUStrategyV1(distribute_lib.StrategyV1):
682  """TPU distribution strategy implementation."""
683
684  def __init__(self,
685               tpu_cluster_resolver=None,
686               steps_per_run=None,
687               device_assignment=None):
688    """Initializes the TPUStrategy object.
689
690    Args:
691      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
692          which provides information about the TPU cluster.
693      steps_per_run: Number of steps to run on device before returning to the
694          host. Note that this can have side-effects on performance, hooks,
695          metrics, summaries etc.
696          This parameter is only used when Distribution Strategy is used with
697          Estimator or Keras.
698      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
699          specify the placement of replicas on the TPU cluster. Currently only
700          supports the use case of using a single core within a TPU cluster.
701    """
702    super(TPUStrategyV1, self).__init__(TPUExtended(
703        self, tpu_cluster_resolver, steps_per_run, device_assignment))
704    distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
705    distribute_lib.distribution_strategy_replica_gauge.get_cell(
706        "num_workers").set(self.extended.num_hosts)
707    distribute_lib.distribution_strategy_replica_gauge.get_cell(
708        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
709    # Packed variable is used to reduce the overhead of function execution.
710    # For a DistributedVariable, only one variable handle is captured into a
711    # function graph. It's only supported in eager mode.
712    self._enable_packed_variable_in_eager_mode = True
713
714  @property
715  def steps_per_run(self):
716    """DEPRECATED: use .extended.steps_per_run instead."""
717    return self._extended.steps_per_run
718
719  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
720  # can use the default implementation.
721  # This implementation runs a single step. It does not use infeed or outfeed.
722  def run(self, fn, args=(), kwargs=None, options=None):
723    """Run `fn` on each replica, with the given arguments.
724
725    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
726    "per-replica" values, such as those produced by a "distributed `Dataset`",
727    when `fn` is executed on a particular replica, it will be executed with the
728    component of those "per-replica" values that correspond to that replica.
729
730    `fn` may call `tf.distribute.get_replica_context()` to access members such
731    as `all_reduce`.
732
733    All arguments in `args` or `kwargs` should be either nests of tensors or
734    per-replica objects containing tensors or composite tensors.
735
736    Users can pass strategy-specific options via the `options` argument. An
737    example that enables bucketizing dynamic shapes in `TPUStrategy.run`
738    is:
739
740    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
741    >>> tf.config.experimental_connect_to_cluster(resolver)
742    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
743    >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
744
745    >>> options = tf.distribute.RunOptions(
746    ...     experimental_bucketizing_dynamic_shape=True)
747
748    >>> dataset = tf.data.Dataset.range(
749    ...    strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
750    ...        strategy.num_replicas_in_sync, drop_remainder=True)
751    >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
752
753    >>> @tf.function()
754    ... def step_fn(inputs):
755    ...  output = tf.reduce_sum(inputs)
756    ...  return output
757
758    >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
759
760    Args:
761      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
762      args: (Optional) Positional arguments to `fn`.
763      kwargs: (Optional) Keyword arguments to `fn`.
764      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
765        the options to run `fn`.
766
767    Returns:
768      Merged return value of `fn` across replicas. The structure of the return
769      value is the same as the return value from `fn`. Each element in the
770      structure can either be "per-replica" `Tensor` objects or `Tensor`s
771      (for example, if running on a single replica).
772    """
773    validate_run_function(fn)
774
775    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
776
777    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
778    options = options or distribute_lib.RunOptions()
779    return self.extended.tpu_run(fn, args, kwargs, options)
780
781
782# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
783class TPUExtended(distribute_lib.StrategyExtendedV1):
784  """Implementation of TPUStrategy."""
785
786  def __init__(self,
787               container_strategy,
788               tpu_cluster_resolver=None,
789               steps_per_run=None,
790               device_assignment=None):
791    super(TPUExtended, self).__init__(container_strategy)
792
793    if tpu_cluster_resolver is None:
794      tpu_cluster_resolver = TPUClusterResolver("")
795
796    if steps_per_run is None:
797      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
798      # not specified.
799      steps_per_run = 1
800
801    # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a
802    # `tf.function` is passed into `strategy.run` in eager mode, the
803    # `tf.function` won't get retraced.
804    self._tpu_function_cache = weakref.WeakKeyDictionary()
805
806    self._tpu_cluster_resolver = tpu_cluster_resolver
807    self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata()
808    self._device_assignment = device_assignment
809
810    tpu_devices_flat = [
811        d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]
812
813    # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
814    # indexed using `[replica_id][logical_device_id]`.
815    if device_assignment is None:
816      self._tpu_devices = np.array(
817          [[d] for d in tpu_devices_flat], dtype=object)
818    else:
819      job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job
820
821      tpu_devices = []
822      for replica_id in range(device_assignment.num_replicas):
823        replica_devices = []
824
825        for logical_core in range(device_assignment.num_cores_per_replica):
826          replica_devices.append(
827              device_util.canonicalize(
828                  device_assignment.tpu_device(
829                      replica=replica_id,
830                      logical_core=logical_core,
831                      job=job_name)))
832
833        tpu_devices.append(replica_devices)
834      self._tpu_devices = np.array(tpu_devices, dtype=object)
835
836    self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])
837
838    # Preload the data onto the TPUs. Currently we always preload onto logical
839    # device 0 for each replica.
840    # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
841    # input onto a different logical device?
842    self._device_input_worker_devices = collections.OrderedDict()
843    self._host_input_worker_devices = collections.OrderedDict()
844    for tpu_device in self._tpu_devices[:, 0]:
845      host_device = device_util.get_host_for_device(tpu_device)
846      self._device_input_worker_devices.setdefault(host_device, [])
847      self._device_input_worker_devices[host_device].append(tpu_device)
848      self._host_input_worker_devices.setdefault(host_device, [])
849      self._host_input_worker_devices[host_device].append(host_device)
850
851    # TODO(sourabhbajaj): Remove this once performance of running one step
852    # at a time is comparable to multiple steps.
853    self.steps_per_run = steps_per_run
854    self._require_static_shapes = True
855
856    self.experimental_enable_get_next_as_optional = True
857
858    self._logical_device_stack = [0]
859
860    if context.executing_eagerly():
861      # In async remote eager, we want to sync the executors before exiting the
862      # program.
863      def async_wait():
864        if context.context()._context_handle is not None:  # pylint: disable=protected-access
865          context.async_wait()
866      atexit.register(async_wait)
867
868    # Flag to turn on VariablePolicy
869    self._use_var_policy = True
870
871    # Flag to enable TF2 SPMD
872    self._use_spmd_for_xla_partitioning = False
873
874  def _validate_colocate_with_variable(self, colocate_with_variable):
875    distribute_utils.validate_colocate(colocate_with_variable, self)
876
877  def _make_dataset_iterator(self, dataset):
878    """Make iterators for each of the TPU hosts."""
879    input_workers = input_lib.InputWorkers(
880        tuple(self._device_input_worker_devices.items()))
881    return input_lib.DatasetIterator(
882        dataset,
883        input_workers,
884        self._container_strategy(),
885        num_replicas_in_sync=self._num_replicas_in_sync)
886
887  def _make_input_fn_iterator(
888      self,
889      input_fn,
890      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
891    input_contexts = []
892    input_workers = input_lib.InputWorkers(
893        tuple(self._device_input_worker_devices.items()))
894    num_workers = input_workers.num_workers
895    for i in range(num_workers):
896      input_contexts.append(distribute_lib.InputContext(
897          num_input_pipelines=num_workers,
898          input_pipeline_id=i,
899          num_replicas_in_sync=self._num_replicas_in_sync))
900    return input_lib.InputFunctionIterator(
901        input_fn,
902        input_workers,
903        input_contexts,
904        self._container_strategy())
905
906  def _experimental_make_numpy_dataset(self, numpy_input, session):
907    return numpy_dataset.one_host_numpy_dataset(
908        numpy_input, numpy_dataset.SingleDevice(self._host_device),
909        session)
910
911  def _get_input_workers(self, options):
912    if not options or options.experimental_prefetch_to_device:
913      return input_lib.InputWorkers(
914          tuple(self._device_input_worker_devices.items()))
915    else:
916      return input_lib.InputWorkers(
917          tuple(self._host_input_worker_devices.items()))
918
919  def _check_spec(self, element_spec):
920    if isinstance(element_spec, values.PerReplicaSpec):
921      element_spec = element_spec._component_specs  # pylint: disable=protected-access
922    specs = nest.flatten_with_joined_string_paths(element_spec)
923    for path, spec in specs:
924      if isinstance(spec, (sparse_tensor.SparseTensorSpec,
925                           ragged_tensor.RaggedTensorSpec)):
926        raise ValueError(
927            "Found tensor {} with spec {}. TPUStrategy does not support "
928            "distributed datasets with device prefetch when using sparse or "
929            "ragged tensors. If you intend to use sparse or ragged tensors, "
930            "please pass a tf.distribute.InputOptions object with "
931            "experimental_prefetch_to_device set to False to your dataset "
932            "distribution function.".format(path, type(spec)))
933
934  def _experimental_distribute_dataset(self, dataset, options):
935    if (options and options.experimental_replication_mode ==
936        distribute_lib.InputReplicationMode.PER_REPLICA):
937      raise NotImplementedError(
938          "InputReplicationMode.PER_REPLICA "
939          "is only supported in "
940          "`experimental_distribute_datasets_from_function`."
941      )
942    if options is None or options.experimental_prefetch_to_device:
943      self._check_spec(dataset.element_spec)
944
945    return input_lib.get_distributed_dataset(
946        dataset,
947        self._get_input_workers(options),
948        self._container_strategy(),
949        num_replicas_in_sync=self._num_replicas_in_sync)
950
951  def _distribute_datasets_from_function(self, dataset_fn, options):
952    if (options and options.experimental_replication_mode ==
953        distribute_lib.InputReplicationMode.PER_REPLICA):
954      raise NotImplementedError(
955          "InputReplicationMode.PER_REPLICA "
956          "is only supported in "
957          "`experimental_distribute_datasets_from_function` "
958          "of tf.distribute.MirroredStrategy")
959    input_workers = self._get_input_workers(options)
960    input_contexts = []
961    num_workers = input_workers.num_workers
962    for i in range(num_workers):
963      input_contexts.append(distribute_lib.InputContext(
964          num_input_pipelines=num_workers,
965          input_pipeline_id=i,
966          num_replicas_in_sync=self._num_replicas_in_sync))
967
968    distributed_dataset = input_lib.get_distributed_datasets_from_function(
969        dataset_fn,
970        input_workers,
971        input_contexts,
972        self._container_strategy())
973
974    # We can only check after the dataset_fn is called.
975    if options is None or options.experimental_prefetch_to_device:
976      self._check_spec(distributed_dataset.element_spec)
977    return distributed_dataset
978
979  def _experimental_distribute_values_from_function(self, value_fn):
980    per_replica_values = []
981    for replica_id in range(self._num_replicas_in_sync):
982      per_replica_values.append(
983          value_fn(distribute_lib.ValueContext(replica_id,
984                                               self._num_replicas_in_sync)))
985    return distribute_utils.regroup(per_replica_values, always_wrap=True)
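    # Illustrative note (not part of the original module): with 2 replicas and
    # value_fn = lambda ctx: ctx.replica_id_in_sync_group, the regroup above
    # returns roughly a PerReplica value holding (0, 1), one component per
    # replica.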
986
987  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
988  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
989  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
990  def _experimental_run_steps_on_iterator(
991      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
992    # Wrap `fn` for repeat.
993    if initial_loop_values is None:
994      initial_loop_values = {}
995    initial_loop_values = nest.flatten(initial_loop_values)
996    ctx = input_lib.MultiStepContext()
997
998    def run_fn(inputs):
999      """Single step on the TPU device."""
1000      fn_result = fn(ctx, inputs)
1001      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
1002      if flat_last_step_outputs:
1003        with ops.control_dependencies([fn_result]):
1004          return [array_ops.identity(f) for f in flat_last_step_outputs]
1005      else:
1006        return fn_result
1007
1008    # We capture the control_flow_context at this point, before we run `fn`
1009    # inside a while_loop and TPU replicate context. This is useful in cases
1010    # where we might need to exit these contexts and get back to the outer
1011    # context to do some things, for e.g. create an op which should be
1012    # evaluated only once at the end of the loop on the host. One such usage
1013    # is in creating metrics' value op.
1014    self._outer_control_flow_context = (
1015        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
1016
1017    def rewrite_fn(*args):
1018      """The rewritten step fn running on TPU."""
1019      del args
1020
1021      per_replica_inputs = multi_worker_iterator.get_next()
1022      replicate_inputs = []
1023      for replica_id in range(self._num_replicas_in_sync):
1024        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
1025            replica_id, x)   # pylint: disable=cell-var-from-loop
1026        replicate_inputs.append((nest.map_structure(
1027            select_replica, per_replica_inputs),))
1028
1029      replicate_outputs = tpu.replicate(
1030          run_fn,
1031          replicate_inputs,
1032          device_assignment=self._device_assignment,
1033          xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
1034                                     ._use_spmd_for_xla_partitioning))
1035      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
1036      # will flatten it in this case. If run_fn has no tensor outputs,
1037      # tpu.replicate returns a list of no_ops, we will keep the output as it
1038      # is.
1039      if isinstance(replicate_outputs[0], list):
1040        replicate_outputs = nest.flatten(replicate_outputs)
1041
1042      return replicate_outputs
1043
1044    # TODO(sourabhbajaj): The input to while loop should be based on the
1045    # output type of the step_fn
1046    assert isinstance(initial_loop_values, list)
1047    initial_loop_values = initial_loop_values * self._num_replicas_in_sync
1048
1049    # Put the while loop op on TPU host 0.
1050    with ops.device(self._host_device):
1051      if self.steps_per_run == 1:
1052        replicate_outputs = rewrite_fn()
1053      else:
1054        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
1055                                                 initial_loop_values)
1056
1057    del self._outer_control_flow_context
1058    ctx.run_op = control_flow_ops.group(replicate_outputs)
1059
1060    if isinstance(replicate_outputs, list):
1061      # Filter out any ops from the outputs, typically this would be the case
1062      # when there were no tensor outputs.
1063      last_step_tensor_outputs = [
1064          x for x in replicate_outputs if not isinstance(x, ops.Operation)
1065      ]
1066
1067      # Outputs are currently of the structure (flattened)
1068      # [output0_device0, output1_device0, output2_device0,
1069      #  output0_device1, output1_device1, output2_device1,
1070      #  ...]
1071      # Convert this to the following structure instead: (grouped by output)
1072      # [[output0_device0, output0_device1],
1073      #  [output1_device0, output1_device1],
1074      #  [output2_device0, output2_device1]]
1075      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
1076      last_step_tensor_outputs = [
1077          last_step_tensor_outputs[i::output_num] for i in range(output_num)
1078      ]
1079    else:
1080      # no tensors returned.
1081      last_step_tensor_outputs = []
1082
1083    _set_last_step_outputs(ctx, last_step_tensor_outputs)
1084    return ctx
1085
1086  def _call_for_each_replica(self, fn, args, kwargs):
1087    # TODO(jhseu): Consider making it so call_for_each_replica implies that
1088    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
1089    with _TPUReplicaContext(self._container_strategy()):
1090      return fn(*args, **kwargs)
1091
1092  @contextlib.contextmanager
1093  def experimental_logical_device(self, logical_device_id):
1094    """Places variables and ops on the specified logical device."""
1095    num_logical_devices_per_replica = self._tpu_devices.shape[1]
1096    if logical_device_id >= num_logical_devices_per_replica:
1097      raise ValueError(
1098          "`logical_device_id` not in range (was {}, but there are only {} "
1099          "logical devices per replica).".format(
1100              logical_device_id, num_logical_devices_per_replica))
1101
1102    self._logical_device_stack.append(logical_device_id)
1103    try:
1104      if tpu_util.enclosing_tpu_context() is None:
1105        yield
1106      else:
1107        with ops.device(tpu.core(logical_device_id)):
1108          yield
1109    finally:
1110      self._logical_device_stack.pop()
1111
1112  def _experimental_initialize_system(self):
1113    """Experimental method added to be used by Estimator.
1114
1115    This is a private method only to be used by Estimator. Other frameworks
1116    should directly be calling `tf.tpu.experimental.initialize_tpu_system`
1117    """
1118    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
1119
1120  def _create_variable(self, next_creator, **kwargs):
1121    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
1122    if kwargs.pop("skip_mirrored_creator", False):
1123      return next_creator(**kwargs)
1124
1125    colocate_with = kwargs.pop("colocate_with", None)
1126    if colocate_with is None:
1127      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
1128    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
1129      with ops.device(colocate_with.device):
1130        return next_creator(**kwargs)
1131    else:
1132      devices = colocate_with._devices  # pylint: disable=protected-access
1133
1134    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
1135      initial_value = None
1136      value_list = []
1137      for i, d in enumerate(devices):
1138        with ops.device(d):
1139          if i == 0:
1140            initial_value = kwargs["initial_value"]
1141            # Note: some v1 code expects variable initializer creation to happen
1142            # inside an init_scope.
1143            with maybe_init_scope():
1144              initial_value = initial_value() if callable(
1145                  initial_value) else initial_value
1146
1147          if i > 0:
1148            # Give replicas meaningful distinct names:
1149            var0name = value_list[0].name.split(":")[0]
1150            # We append a / to variable names created on replicas with id > 0 to
1151            # ensure that we ignore the name scope and instead use the given
1152            # name as the absolute name of the variable.
1153            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
1154          kwargs["initial_value"] = initial_value
1155
1156          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
1157            v = next_creator(**kwargs)
1158
1159          assert not isinstance(v, tpu_values.TPUMirroredVariable)
1160          value_list.append(v)
1161      return value_list
1162
1163    return distribute_utils.create_mirrored_variable(
1164        self._container_strategy(), _real_mirrored_creator,
1165        distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
1166        distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
1167
1168  def _gather_to_implementation(self, value, destinations, axis, options):
1169    if not isinstance(value, values.DistributedValues):
1170      return value
1171
1172    value_list = value.values
1173    # pylint: disable=protected-access
1174    if isinstance(
1175        value,
1176        values.DistributedVariable) and value._packed_variable is not None:
1177      value_list = tuple(
1178          value._packed_variable.on_device(d)
1179          for d in value._packed_variable.devices)
1180    # pylint: enable=protected-access
1181
1182    # Currently XLA op by op mode has a limit for the number of inputs for a
1183    # single op, thus we break one `add_n` op into a group of `add_n` ops to
1184    # work around the constraint.
1185    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
1186      output = array_ops.concat(value_list, axis=axis)
1187    else:
1188      output = array_ops.concat(
1189          value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis)
1190      for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list),
1191                     _XLA_OP_BY_OP_INPUTS_LIMIT - 1):
1192        output = array_ops.concat(
1193            [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1],
1194            axis=axis)
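      # Illustrative note (not part of the original module): the loop above
      # advances by _XLA_OP_BY_OP_INPUTS_LIMIT - 1 values per step because the
      # running `output` occupies one input slot, keeping every concat op at
      # or below the XLA op-by-op input limit.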
1195
1196    output = self._broadcast_output(destinations, output)
1197    return output
1198
1199  def _broadcast_output(self, destinations, output):
1200    devices = cross_device_ops_lib.get_devices_from(destinations)
1201
1202    if len(devices) == 1:
1203      # If necessary, copy to requested destination.
1204      dest_canonical = device_util.canonicalize(devices[0])
1205      host_canonical = device_util.canonicalize(self._host_device)
1206
1207      if dest_canonical != host_canonical:
1208        with ops.device(dest_canonical):
1209          output = array_ops.identity(output)
1210    else:
1211      output = cross_device_ops_lib.simple_broadcast(output, destinations)
1212
1213    return output
1214
1215  def _reduce_to(self, reduce_op, value, destinations, options):
1216    if (isinstance(value, values.DistributedValues) or
1217        tensor_util.is_tf_type(value)
1218       ) and tpu_util.enclosing_tpu_context() is not None:
1219      if reduce_op == reduce_util.ReduceOp.MEAN:
1220        # TODO(jhseu):  Revisit once we support model-parallelism.
1221        value *= (1. / self._num_replicas_in_sync)
1222      elif reduce_op != reduce_util.ReduceOp.SUM:
1223        raise NotImplementedError(
1224            "Currently only sum and mean reductions are supported in TPUStrategy.")
1225      return tpu_ops.cross_replica_sum(value)
1226
1227    if not isinstance(value, values.DistributedValues):
1228      # This function handles reducing values that are not PerReplica or
1229      # Mirrored values. For example, the same value could be present on all
1230      # replicas in which case `value` would be a single value or value could
1231      # be 0.
1232      return cross_device_ops_lib.reduce_non_distributed_value(
1233          reduce_op, value, destinations, self._num_replicas_in_sync)
1234
1235    value_list = value.values
1236    # pylint: disable=protected-access
1237    if isinstance(
1238        value,
1239        values.DistributedVariable) and value._packed_variable is not None:
1240      value_list = tuple(
1241          value._packed_variable.on_device(d)
1242          for d in value._packed_variable.devices)
1243    # pylint: enable=protected-access
1244
1245    # Currently XLA op by op mode has a limit for the number of inputs for a
1246    # single op, thus we break one `add_n` op into a group of `add_n` ops to
1247    # work around the constraint.
1248    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
      output = math_ops.add_n(value_list)
    else:
      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

    if reduce_op == reduce_util.ReduceOp.MEAN:
      output *= (1. / len(value_list))

    output = self._broadcast_output(destinations, output)
    return output

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
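    # Inside a `tpu.replicate` context, `fn` is applied to the variable once
    # and the update is replicated across cores by XLA, so no per-device loop
    # is needed.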
    if tpu_util.enclosing_tpu_context() is not None:
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update the variable
    # on each replica directly.
    updates = []
    values_and_devices = []
    packed_var = var._packed_variable  # pylint: disable=protected-access
    if packed_var is not None:
      for device in packed_var.devices:
        values_and_devices.append((packed_var, device))
    else:
      for value in var.values:
        values_and_devices.append((value, value.device))

    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
        var.aggregation != variables_lib.VariableAggregation.NONE):
      distribute_utils.assert_mirrored(args)
      distribute_utils.assert_mirrored(kwargs)
    for i, value_and_device in enumerate(values_and_devices):
      value = value_and_device[0]
      device = value_and_device[1]
      name = "update_%d" % i
      with ops.device(device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(value, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if tpu_util.enclosing_tpu_context() is not None:
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcast value from the first replica because the only
      # caller of this is ONLY_FIRST_REPLICA variable aggregation.
      return result[0]
    return tensor

  @property
  def num_hosts(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    return len(set([self._device_assignment.host_device(r)
                    for r in range(self._device_assignment.num_replicas)]))

  @property
  def num_replicas_per_host(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN, where everything is 1 in that case.
    # This method needs to take host_id as input for a correct computation.
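    # For example (hypothetical numbers): with 8 cores per host and 2 cores
    # per replica, at most 4 replicas fit on a host, capped by the total
    # number of replicas in the device assignment.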
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    return min(self._device_assignment.num_replicas, max_models_per_host)

  @property
  def _num_replicas_in_sync(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return self._device_assignment.num_replicas

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  @property
  def worker_devices(self):
    return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])

  @property
  def parameter_devices(self):
    return self.worker_devices

  def non_slot_devices(self, var_list):
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(None):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def tpu_run(self, fn, args, kwargs, options=None):
    func = self._tpu_function_creator(fn, options)
    return func(args, kwargs)
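  # A minimal usage sketch, not part of this module: user code typically
  # reaches `tpu_run` through `strategy.run`, which supplies the replicated
  # function and per-replica args. Assuming a configured `resolver`
  # (a TPUClusterResolver), something like:
  #
  #   strategy = tf.distribute.TPUStrategy(resolver)
  #
  #   @tf.function
  #   def step(x):
  #     return strategy.run(lambda v: v * 2., args=(x,))
  #
  #   per_replica_result = step(tf.constant(1.))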

  def _tpu_function_creator(self, fn, options):
    if context.executing_eagerly() and fn in self._tpu_function_cache:
      return self._tpu_function_cache[fn]

    strategy = self._container_strategy()

    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             distribute_utils.select_replica(i, args),
             distribute_utils.select_replica(i, kwargs)])

      # Construct and pass `maximum_shapes` so that we can support dynamic
      # shapes using the dynamic padder.
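      # For example (illustrative shapes): a per-replica input of shape
      # [32, 128] has rank 2 and yields a maximum shape of [None, None], so
      # every dimension may vary and be padded dynamically.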
      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tf_type(input_tensor):
            rank = input_tensor.shape.rank
          else:
            rank = np.ndim(input_tensor)
          maximum_shape = tensor_shape.TensorShape([None] * rank)
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None

      if options.experimental_bucketizing_dynamic_shape:
        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
      else:
        padding_spec = None

      with strategy.scope():
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes,
            padding_spec=padding_spec,
            xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
                                       ._use_spmd_for_xla_partitioning))

      # Remove all no-ops that may have been added during `tpu.replicate()`.
      if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if not isinstance(
                output, ops.Operation)
        ]

      # Workaround for `tpu.replicate` behavior when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]
      return distribute_utils.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)
      self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: from a
    # higher-level (e.g. Keras) library's point of view, the whole cluster
    # should be treated as a single worker.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=0):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)

  # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
  def all_gather(self, value, axis, experimental_hints=None):
    del experimental_hints
    for v in nest.flatten(value):
      if isinstance(v, ops.IndexedSlices):
        raise NotImplementedError("all_gather does not support IndexedSlices")

    def _all_to_all(value, axis):
      # The underlying AllToAllOp first splits the input value and then
      # performs cross-replica communication and concatenation of the result,
      # so we concatenate the local tensor here first.
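      # For example (illustrative shapes): with 2 replicas and a per-replica
      # `value` of shape [2, 3], `inputs` below has shape [4, 3]; all_to_all
      # splits it along dimension 0 and concatenates along `axis`, so each
      # replica ends up with every replica's values (possibly out of order,
      # which the re-ordering below corrects).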
      inputs = array_ops.concat(
          [value for _ in range(self.num_replicas_in_sync)], axis=0)
      unordered_output = tpu_ops.all_to_all(
          inputs,
          concat_dimension=axis,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      # Re-order since xla.replica_id and ReplicaContext.replica_id do not
      # match.
      # xla_id = xla.replica_id()
      concat_replica_id = array_ops.concat([
          array_ops.expand_dims_v2(self.replica_id_in_sync_group, 0)
          for _ in range(self.num_replicas_in_sync)
      ],
                                           axis=0)
      replica_ids = tpu_ops.all_to_all(
          concat_replica_id,
          concat_dimension=0,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      split_unordered = array_ops.split(
          unordered_output,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=axis)
      sorted_with_extra_dim = math_ops.unsorted_segment_sum(
          array_ops.concat([
              array_ops.expand_dims(replica, axis=0)
              for replica in split_unordered
          ],
                           axis=0),
          replica_ids,
          num_segments=self.num_replicas_in_sync)

      split_with_extra_dim = array_ops.split(
          sorted_with_extra_dim,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=0)
      squeezed = [
          array_ops.squeeze(replica, axis=0)
          for replica in split_with_extra_dim
      ]
      result = array_ops.concat(squeezed, axis=axis)
      return result

    ys = [_all_to_all(t, axis=axis) for t in nest.flatten(value)]
    return nest.pack_sequence_as(value, ys)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that aren't reduced, return a PerReplica of all values.
    # Otherwise, take the first value from the list, as each value should be
    # the same.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = values.PerReplica(output)
    else:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
