# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import copy
import functools
import weakref

from absl import logging
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_XLA_OP_BY_OP_INPUTS_LIMIT = 200


@contextlib.contextmanager
def maybe_init_scope():
  if ops.executing_eagerly_outside_functions():
    yield
  else:
    with ops.init_scope():
      yield


def validate_run_function(fn):
  """Validate the function passed into strategy.run."""

  # We allow three types of functions/objects passed into TPUStrategy
  # run in eager mode:
  #   1. a user-annotated tf.function
  #   2. a ConcreteFunction, which is mostly what you get from loading a saved
  #      model.
  #   3. a callable object whose `__call__` method is itself a tf.function.
  #
  # Otherwise we return an error, because we don't support eagerly running
  # run in TPUStrategy.
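  #
  # A minimal sketch of the supported pattern (illustrative only; assumes
  # `strategy` is a `tf.distribute.TPUStrategy` and `dist_values` comes from a
  # distributed dataset):
  #
  #   @tf.function
  #   def step_fn(inputs):
  #     return inputs * 2
  #   strategy.run(step_fn, args=(dist_values,))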

  if context.executing_eagerly() \
      and not isinstance(fn, def_function.Function) \
      and not isinstance(fn, function.ConcreteFunction) \
      and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
    raise NotImplementedError(
        "TPUStrategy.run(fn, ...) does not support pure eager "
        "execution. Please make sure the function passed into "
        "`strategy.run` is a `tf.function` or that "
        "`strategy.run` is called inside a `tf.function` if "
        "eager behavior is enabled.")


def _maybe_partial_apply_variables(fn, args, kwargs):
  """Inspects arguments to partially apply any DistributedVariable.

  This avoids an automatic cast of the current variable value to tensor.

  Note that a variable may be captured implicitly with Python scope instead of
  passing it to run(), but supporting run() keeps behavior consistent
  with MirroredStrategy.

  Since positional arguments must be applied from left to right, this function
  does some tricky function inspection to move variable positional arguments
  into kwargs. As a result of this, we can't support passing Variables as *args,
  nor as args to functions which combine both explicit positional arguments and
  *args.

  Args:
    fn: The function to run, as passed to run().
    args: Positional arguments to fn, as passed to run().
    kwargs: Keyword arguments to fn, as passed to run().

  Returns:
    A tuple of the function (possibly wrapped), args, kwargs (both
    possibly filtered, with members of args possibly moved to kwargs).
    If no variables are found, this function is a noop.

  Raises:
    ValueError: If the function signature makes unsupported use of *args, or if
      too many arguments are passed.
  """

  def is_distributed_var(x):
    flat = nest.flatten(x)
    return flat and isinstance(flat[0], values.DistributedVariable)

  # We will split kwargs into two dicts, one of which will be applied now.
  var_kwargs = {}
  nonvar_kwargs = {}

  if kwargs:
    var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
  if var_kwargs:
    nonvar_kwargs = {
        k: v for k, v in kwargs.items() if not is_distributed_var(v)
    }

  # Dump the argument names of `fn` to a list. This will include both positional
  # and keyword arguments, but since positional arguments come first we can
  # look up names of positional arguments by index.
  positional_args = []
  index_of_star_args = None
  for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
    # Class methods define "self" as first argument, but we don't pass "self".
    # Note that this is a heuristic, as a method can name its first argument
    # something else, and a function can define a first argument "self" as well.
    # In both of these cases, using a Variable will fail with an unfortunate
    # error about the number of arguments.
    # inspect.is_method() seems not to work here, possibly due to the use of
    # tf.function().
    if i == 0 and p.name == "self":
      continue

    if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
      positional_args.append(p.name)

    elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
      # We'll raise an error later if a variable is passed to *args, since we
      # can neither pass it by name nor partially apply it. This case only
      # happens once at most.
      index_of_star_args = i

    elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
      # This is a rare Python feature, indicating a / in the arg list.
      if var_kwargs or any(is_distributed_var(a) for a in args):
        raise ValueError(
            "Mixing Variables and positional-only parameters not supported by "
            "TPUStrategy.")
      return fn, args, kwargs

  star_args = []
  have_seen_var_arg = False

  for i, a in enumerate(args):
    if is_distributed_var(a):
      if index_of_star_args is not None and i >= index_of_star_args:
        raise ValueError(
            "TPUStrategy.run() cannot handle Variables passed to *args. "
            "Either name the function argument, or capture the Variable "
            "implicitly.")
      if len(positional_args) <= i:
        raise ValueError(
            "Too many positional arguments passed to call to TPUStrategy.run()."
        )
      var_kwargs[positional_args[i]] = a
      have_seen_var_arg = True
    else:
      if index_of_star_args is not None and i >= index_of_star_args:
        if have_seen_var_arg:
          raise ValueError(
              "TPUStrategy.run() cannot handle both Variables and a mix of "
              "positional args and *args. Either remove the *args, or capture "
              "the Variable implicitly.")
        else:
          star_args.append(a)
          continue

      if len(positional_args) <= i:
        raise ValueError(
            "Too many positional arguments passed to call to TPUStrategy.run()."
        )
      nonvar_kwargs[positional_args[i]] = a

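  # At this point every DistributedVariable that was passed positionally has
  # been moved into `var_kwargs`, keyed by its parameter name, so it can be
  # bound with `functools.partial` without being cast to a tensor. Illustrative
  # sketch (hypothetical `step` and `v`): for `def step(v, x)` called as
  # `run(step, args=(v, x))` with `v` a DistributedVariable, this returns
  # `(functools.partial(step, v=v), [], {"x": x})`.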
  if var_kwargs:
    return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
  return fn, args, kwargs


@tf_export("distribute.TPUStrategy", v1=[])
class TPUStrategyV2(distribute_lib.Strategy):
  """Synchronous training on TPUs and TPU Pods.

  To construct a TPUStrategy object, you need to run the
  initialization code as below:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.TPUStrategy(resolver)

  While using distribution strategies, the variables created within the
  strategy's scope will be replicated across all the replicas and can be kept in
  sync using all-reduce algorithms.

  To run TF2 programs on TPUs, you can either use `.compile` and
  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
  training loop by calling `strategy.run` directly. Note that
  TPUStrategy doesn't support pure eager execution, so please make sure the
  function passed into `strategy.run` is a `tf.function` or
  `strategy.run` is called inside a `tf.function` if eager
  behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.

  `distribute_datasets_from_function` and
  `experimental_distribute_dataset` APIs can be used to distribute the dataset
  across the TPU workers when writing your own training loop. If you are using
  `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
  handle the distribution for you.

  An example of writing a customized training loop on TPUs:

  >>> with strategy.scope():
  ...   model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(2, input_shape=(5,)),
  ...   ])
  ...   optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  >>> def dataset_fn(ctx):
  ...   x = np.random.random((2, 5)).astype(np.float32)
  ...   y = np.random.randint(2, size=(2, 1))
  ...   dataset = tf.data.Dataset.from_tensor_slices((x, y))
  ...   return dataset.repeat().batch(1, drop_remainder=True)
  >>> dist_dataset = strategy.distribute_datasets_from_function(
  ...     dataset_fn)
  >>> iterator = iter(dist_dataset)

  >>> @tf.function()
  ... def train_step(iterator):
  ...
  ...   def step_fn(inputs):
  ...     features, labels = inputs
  ...     with tf.GradientTape() as tape:
  ...       logits = model(features, training=True)
  ...       loss = tf.keras.losses.sparse_categorical_crossentropy(
  ...           labels, logits)
  ...
  ...     grads = tape.gradient(loss, model.trainable_variables)
  ...     optimizer.apply_gradients(zip(grads, model.trainable_variables))
  ...
  ...   strategy.run(step_fn, args=(next(iterator),))

  >>> train_step(iterator)

  For advanced use cases like model parallelism, you can set the
  `experimental_device_assignment` argument when creating TPUStrategy to specify
  the number of replicas and the number of logical devices. Below is an example
  of initializing the TPU system with 2 logical devices and 1 replica.

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
  ...     topology,
  ...     computation_shape=[1, 1, 1, 2],
  ...     num_replicas=1)
  >>> strategy = tf.distribute.TPUStrategy(
  ...     resolver, experimental_device_assignment=device_assignment)

  Then you can run a `tf.add` operation only on logical device 0.

  >>> @tf.function()
  ... def step_fn(inputs):
  ...   features, _ = inputs
  ...   output = tf.add(features, features)
  ...
  ...   # Add operation will be executed on logical device 0.
  ...   output = strategy.experimental_assign_to_logical_device(output, 0)
  ...   return output
  >>> dist_dataset = strategy.distribute_datasets_from_function(
  ...     dataset_fn)
  >>> iterator = iter(dist_dataset)
  >>> strategy.run(step_fn, args=(next(iterator),))
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               experimental_device_assignment=None):
    """Synchronous training in TPU donuts or Pods.

    Args:
      tpu_cluster_resolver: A
        `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
        provides information about the TPU cluster. If None, it will assume
        running on a local TPU worker.
      experimental_device_assignment: Optional
        `tf.tpu.experimental.DeviceAssignment` to specify the placement of
        replicas on the TPU cluster.
    """
    super(TPUStrategyV2, self).__init__(TPUExtended(
        self, tpu_cluster_resolver,
        device_assignment=experimental_device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    self._enable_packed_variable_in_eager_mode = True

  def run(self, fn, args=(), kwargs=None, options=None):
    """Run the computation defined by `fn` on each TPU replica.

    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
    `tf.distribute.DistributedValues`, such as those produced by a
    `tf.distribute.DistributedDataset` from
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function`,
    when `fn` is executed on a particular replica, it will be executed with the
    component of `tf.distribute.DistributedValues` that corresponds to that
    replica.

    `fn` may call `tf.distribute.get_replica_context()` to access members such
    as `all_reduce`.

    All arguments in `args` or `kwargs` should either be nests of tensors or
    `tf.distribute.DistributedValues` containing tensors or composite tensors.

    Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
    >>> strategy = tf.distribute.TPUStrategy(resolver)
    >>> @tf.function
    ... def run():
    ...   def value_fn(value_context):
    ...     return value_context.num_replicas_in_sync
    ...   distributed_values = (
    ...       strategy.experimental_distribute_values_from_function(value_fn))
    ...   def replica_fn(input):
    ...     return input * 2
    ...   return strategy.run(replica_fn, args=(distributed_values,))
    >>> result = run()

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Merged return value of `fn` across replicas. The structure of the return
      value is the same as the return value from `fn`. Each element in the
      structure can either be `tf.distribute.DistributedValues`, `Tensor`
      objects, or `Tensor`s (for example, if running on a single replica).
    """
    validate_run_function(fn)

    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)

    # Note: the target function is converted to graph even when in Eager mode,
    # so autograph is on by default here.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    options = options or distribute_lib.RunOptions()
    return self.extended.tpu_run(fn, args, kwargs, options)

  def experimental_assign_to_logical_device(self, tensor, logical_device_id):
    """Adds annotation that `tensor` will be assigned to a logical device.

    This adds an annotation to `tensor` specifying that operations on
    `tensor` will be invoked on logical core device id `logical_device_id`.
    When model parallelism is used, the default behavior is that all ops
    are placed on the zeroth logical device.

    ```python

    # Initializing TPU system with 2 logical devices and 4 replicas.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 1, 1, 2],
        num_replicas=4)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)
    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      output = tf.add(inputs, inputs)

      # Add operation will be executed on logical device 0.
      output = strategy.experimental_assign_to_logical_device(output, 0)
      return output

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.
      logical_device_id: Id of the logical core to which the tensor will be
        assigned.

    Raises:
      ValueError: The logical device id presented is not consistent with the
        total number of partitions specified by the device assignment.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
    if (logical_device_id < 0 or
        logical_device_id >= num_logical_devices_per_replica):
      raise ValueError("`logical_device_id` to assign must be lower than the "
                       "total number of logical devices per replica. Received "
                       "logical device id {} but there are only {} logical "
                       "devices per replica.".format(
                           logical_device_id, num_logical_devices_per_replica))
    return xla_sharding.assign_device(
        tensor, logical_device_id, use_sharding_op=True)

  def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
    """Adds annotation that `tensor` will be split across logical devices.

    This adds an annotation to tensor `tensor` specifying that operations on
    `tensor` will be split among multiple logical devices. Tensor `tensor` will
    be split across dimensions specified by `partition_dimensions`.
    Each dimension of `tensor` must be divisible by the corresponding value in
    `partition_dimensions`.

    For example, for a system with 8 logical devices, if `tensor` is an image
    tensor with shape (batch_size, width, height, channel) and
    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
    2 ways in the width dimension and 4 ways in the height dimension, and the
    resulting split tensor values will be fed into 8 logical devices.

    ```python
    # Initializing TPU system with 8 logical devices and 1 replica.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 2, 2, 2],
        num_replicas=1)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      inputs = strategy.experimental_split_to_logical_devices(
        inputs, [1, 2, 4, 1])

      # model() function will be executed on 8 logical devices with `inputs`
      # split 2 * 4 ways.
      output = model(inputs)
      return output

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.
      partition_dimensions: An unnested list of integers with the size equal to
        rank of `tensor` specifying how `tensor` will be partitioned. The
        product of all elements in `partition_dimensions` must be equal to the
        total number of logical devices per replica.

    Raises:
      ValueError: 1) If the size of `partition_dimensions` does not equal the
        rank of `tensor`, or 2) if the product of the elements of
        `partition_dimensions` does not match the number of logical devices per
        replica defined by the implementing DistributionStrategy's device
        specification, or 3) if a known dimension size of `tensor` is not
        divisible by the corresponding value in `partition_dimensions`.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
    num_partition_splits = np.prod(partition_dimensions)
    input_shape = tensor.shape
    tensor_rank = len(input_shape)

    if tensor_rank != len(partition_dimensions):
      raise ValueError("Length of `partition_dimensions` ({}) must be "
                       "equal to the rank of `tensor` ({}).".format(
                           len(partition_dimensions), tensor_rank))

    for dim_index, dim_size in enumerate(input_shape):
      if dim_size is None:
        continue

      split_size = partition_dimensions[dim_index]
      if dim_size % split_size != 0:
        raise ValueError("Tensor shape at dimension {} ({}) must be "
                         "divisible by corresponding value specified "
                         "by `partition_dimensions` ({}).".format(
                             dim_index, dim_size, split_size))

    if num_partition_splits != num_logical_devices_per_replica:
      raise ValueError("Number of logical devices ({}) does not match the "
                       "number of partition splits specified ({}).".format(
                           num_logical_devices_per_replica,
                           num_partition_splits))

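    # The tile assignment enumerates logical devices in row-major order over
    # `partition_dimensions`. For example, with `partition_dimensions` of
    # [1, 2, 4, 1] and 8 logical devices per replica, the assignment is
    # np.arange(8).reshape([1, 2, 4, 1]), i.e. logical device k receives the
    # k-th (width, height) tile of the input.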
    tile_assignment = np.arange(num_partition_splits).reshape(
        partition_dimensions)
    return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)

  def experimental_replicate_to_logical_devices(self, tensor):
    """Adds annotation that `tensor` will be replicated to all logical devices.

    This adds an annotation to tensor `tensor` specifying that operations on
    `tensor` will be invoked on all logical devices.

    ```python
    # Initializing TPU system with 8 logical devices and 1 replica.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 2, 2, 2],
        num_replicas=1)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      images, labels = inputs
      images = strategy.experimental_split_to_logical_devices(
        images, [1, 2, 4, 1])

      # model() function will be executed on 8 logical devices with `images`
      # split 2 * 4 ways.
      output = model(images)

      # For loss calculation, all logical devices share the same logits
      # and labels.
      labels = strategy.experimental_replicate_to_logical_devices(labels)
      output = strategy.experimental_replicate_to_logical_devices(output)
      loss = loss_fn(labels, output)

      return loss

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    return xla_sharding.replicate(tensor, use_sharding_op=True)


@tf_export("distribute.experimental.TPUStrategy", v1=[])
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.Strategy):
  """Synchronous training on TPUs and TPU Pods.

  To construct a TPUStrategy object, you need to run the
  initialization code as below:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)

  While using distribution strategies, the variables created within the
  strategy's scope will be replicated across all the replicas and can be kept in
  sync using all-reduce algorithms.

  To run TF2 programs on TPUs, you can either use `.compile` and
  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
  training loop by calling `strategy.run` directly. Note that
  TPUStrategy doesn't support pure eager execution, so please make sure the
  function passed into `strategy.run` is a `tf.function` or
  `strategy.run` is called inside a `tf.function` if eager
  behavior is enabled.
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               device_assignment=None):
    """Synchronous training in TPU donuts or Pods.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster.
    """
    logging.warning(
        "`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
        "the non-experimental symbol `tf.distribute.TPUStrategy` instead.")

    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, device_assignment=device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    self._enable_packed_variable_in_eager_mode = True

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def run(self, fn, args=(), kwargs=None, options=None):
    """See base class."""
    validate_run_function(fn)

    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)

    # Note: the target function is converted to graph even when in Eager mode,
    # so autograph is on by default here.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    options = options or distribute_lib.RunOptions()
    return self.extended.tpu_run(fn, args, kwargs, options)

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    `tf.distribute.experimental.TPUStrategy` provides the
    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
    provides one in `__init__`, that instance is returned; if the user does
    not, a default
    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
    """
    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access


@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Initializes the TPUStrategy object.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
          host. Note that this can have side-effects on performance, hooks,
          metrics, summaries etc.
          This parameter is only used when Distribution Strategy is used with
          Estimator or Keras.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
          specify the placement of replicas on the TPU cluster. Currently only
          supports the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategyV1, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    self._enable_packed_variable_in_eager_mode = True

  @property
  def steps_per_run(self):
    """DEPRECATED: use .extended.steps_per_run instead."""
    return self._extended.steps_per_run

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def run(self, fn, args=(), kwargs=None, options=None):
    """Run `fn` on each replica, with the given arguments.

    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
    "per-replica" values, such as those produced by a "distributed `Dataset`",
    when `fn` is executed on a particular replica, it will be executed with the
    component of those "per-replica" values that correspond to that replica.

    `fn` may call `tf.distribute.get_replica_context()` to access members such
    as `all_reduce`.

    All arguments in `args` or `kwargs` should either be nests of tensors or
    per-replica objects containing tensors or composite tensors.

    Users can pass strategy-specific options via the `options` argument. An
    example of enabling bucketizing of dynamic shapes in `TPUStrategy.run`
    is:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
    >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)

    >>> options = tf.distribute.RunOptions(
    ...     experimental_bucketizing_dynamic_shape=True)

    >>> dataset = tf.data.Dataset.range(
    ...    strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
    ...        strategy.num_replicas_in_sync, drop_remainder=True)
    >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    >>> @tf.function()
    ... def step_fn(inputs):
    ...  output = tf.reduce_sum(inputs)
    ...  return output

    >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Merged return value of `fn` across replicas. The structure of the return
      value is the same as the return value from `fn`. Each element in the
      structure can either be "per-replica" `Tensor` objects or `Tensor`s
      (for example, if running on a single replica).
    """
    validate_run_function(fn)

    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)

    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    options = options or distribute_lib.RunOptions()
    return self.extended.tpu_run(fn, args, kwargs, options)


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class TPUExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of TPUStrategy."""

  def __init__(self,
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = TPUClusterResolver("")

    if steps_per_run is None:
      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
      # not specified.
      steps_per_run = 1

    # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a
    # `tf.function` is passed into `strategy.run` in eager mode, the
    # `tf.function` won't get retraced.
    self._tpu_function_cache = weakref.WeakKeyDictionary()

    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata()
    self._device_assignment = device_assignment

    tpu_devices_flat = [
        d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]

    # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
    # indexed using `[replica_id][logical_device_id]`.
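    # For example, with 4 replicas and 2 logical devices per replica the array
    # has shape (4, 2), and `self._tpu_devices[r, l]` is the device string for
    # logical device `l` of replica `r`.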
    if device_assignment is None:
      self._tpu_devices = np.array(
          [[d] for d in tpu_devices_flat], dtype=object)
    else:
      job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job

      tpu_devices = []
      for replica_id in range(device_assignment.num_replicas):
        replica_devices = []

        for logical_core in range(device_assignment.num_cores_per_replica):
          replica_devices.append(
              device_util.canonicalize(
                  device_assignment.tpu_device(
                      replica=replica_id,
                      logical_core=logical_core,
                      job=job_name)))

        tpu_devices.append(replica_devices)
      self._tpu_devices = np.array(tpu_devices, dtype=object)

    self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])

    # Preload the data onto the TPUs. Currently we always preload onto logical
    # device 0 for each replica.
    # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
    # input onto a different logical device?
    self._device_input_worker_devices = collections.OrderedDict()
    self._host_input_worker_devices = collections.OrderedDict()
    for tpu_device in self._tpu_devices[:, 0]:
      host_device = device_util.get_host_for_device(tpu_device)
      self._device_input_worker_devices.setdefault(host_device, [])
      self._device_input_worker_devices[host_device].append(tpu_device)
      self._host_input_worker_devices.setdefault(host_device, [])
      self._host_input_worker_devices[host_device].append(host_device)

    # TODO(sourabhbajaj): Remove this once performance of running one step
    # at a time is comparable to multiple steps.
    self.steps_per_run = steps_per_run
    self._require_static_shapes = True

    self.experimental_enable_get_next_as_optional = True

    self._logical_device_stack = [0]

    if context.executing_eagerly():
      # In async remote eager, we want to sync the executors before exiting the
      # program.
      atexit.register(context.async_wait)

    # Flag to turn on VariablePolicy.
    self._use_var_policy = True

    # Flag to enable XLA SPMD partitioning.
    # TODO(b/170873313): Enable XLA SPMD partitioning in TPUStrategy.
    self._use_spmd_for_xla_partitioning = False

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    input_workers = input_lib.InputWorkers(
        tuple(self._device_input_worker_devices.items()))
    return input_lib.DatasetIterator(
        dataset,
        input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    input_workers = input_lib.InputWorkers(
        tuple(self._device_input_worker_devices.items()))
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn,
        input_workers,
        input_contexts,
        self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  def _get_input_workers(self, options):
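    # With `experimental_fetch_to_device` enabled (the default), the input
    # workers map each host to the TPU devices it feeds, so batches are
    # prefetched onto the accelerators; otherwise the data is kept on the host
    # devices themselves.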
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          tuple(self._device_input_worker_devices.items()))
    else:
      return input_lib.InputWorkers(
          tuple(self._host_input_worker_devices.items()))

  def _check_spec(self, element_spec):
    if isinstance(element_spec, values.PerReplicaSpec):
      element_spec = element_spec._component_specs  # pylint: disable=protected-access
    specs = nest.flatten_with_joined_string_paths(element_spec)
    for path, spec in specs:
      if isinstance(spec, (sparse_tensor.SparseTensorSpec,
                           ragged_tensor.RaggedTensorSpec)):
        raise ValueError(
            "Found tensor {} with spec {}. TPUStrategy does not support "
            "distributed datasets with device prefetch when using sparse or "
            "ragged tensors. If you intend to use sparse or ragged tensors, "
            "please pass a tf.distribute.InputOptions object with "
            "experimental_fetch_to_device set to False to your dataset "
            "distribution function.".format(path, type(spec)))

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    if options is None or options.experimental_fetch_to_device:
      self._check_spec(dataset.element_spec)

    return input_lib.get_distributed_dataset(
        dataset,
        self._get_input_workers(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_workers = self._get_input_workers(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    distributed_dataset = input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        input_workers,
        input_contexts,
        self._container_strategy(),
        options=options)

    # We can only check after the dataset_fn is called.
    if options is None or options.experimental_fetch_to_device:
      self._check_spec(distributed_dataset.element_spec)
    return distributed_dataset

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
            replica_id, x)   # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(
          run_fn,
          replicate_inputs,
          device_assignment=self._device_assignment,
          xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
                                     ._use_spmd_for_xla_partitioning))
      # If run_fn has tensor outputs, tpu.replicate returns a list of lists.
      # We will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, and we will keep the output as
      # it is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  @contextlib.contextmanager
  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    if logical_device_id >= num_logical_devices_per_replica:
      raise ValueError(
          "`logical_device_id` not in range (was {}, but there are only {} "
          "logical devices per replica).".format(
              logical_device_id, num_logical_devices_per_replica))

    self._logical_device_stack.append(logical_device_id)
    try:
      if tpu_util.enclosing_tpu_context() is None:
        yield
      else:
        with ops.device(tpu.core(logical_device_id)):
          yield
    finally:
      self._logical_device_stack.pop()

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should directly be calling `tf.tpu.experimental.initialize_tpu_system`
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    if kwargs.pop("skip_mirrored_creator", False):
      return next_creator(**kwargs)

    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      initial_value = None
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i == 0:
            initial_value = kwargs["initial_value"]
            # Note: some v1 code expects variable initializer creation to happen
            # inside an init_scope.
            with maybe_init_scope():
              initial_value = initial_value() if callable(
                  initial_value) else initial_value

          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          kwargs["initial_value"] = initial_value

          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(**kwargs)

          assert not isinstance(v, tpu_values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
        distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)

  def _resource_creator_scope(self):

    def lookup_creator(next_creator, *args, **kwargs):
      host_to_table = collections.OrderedDict()
      for host_device in self._device_input_worker_devices.keys():
        with ops.device(host_device):
          host_to_table[host_device] = next_creator(*args, **kwargs)

      return values.PerWorkerResource(self._container_strategy(), host_to_table)

    # TODO(b/194362531): Define creator(s) for other resources.
    return ops.resource_creator_scope("StaticHashTable", lookup_creator)

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      return value

    value_list = list(value.values)
    # pylint: disable=protected-access
    if isinstance(
        value,
        values.DistributedVariable) and value._packed_variable is not None:
      value_list = list(
          value._packed_variable.on_device(d)
          for d in value._packed_variable.devices)
    # pylint: enable=protected-access

    # Currently XLA op by op mode has a limit on the number of inputs for a
    # single op, thus we break one `concat` op into a group of `concat` ops to
    # work around the constraint.
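    # The first concat takes the first `_XLA_OP_BY_OP_INPUTS_LIMIT` values;
    # each subsequent concat combines the running output with up to
    # `_XLA_OP_BY_OP_INPUTS_LIMIT - 1` more values, so no single op exceeds the
    # input limit.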
    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
      output = array_ops.concat(value_list, axis=axis)
    else:
      output = array_ops.concat(
          value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis)
      for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list),
                     _XLA_OP_BY_OP_INPUTS_LIMIT - 1):
        output = array_ops.concat(
            [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1],
            axis=axis)

    output = self._broadcast_output(destinations, output)
    return output

  def _broadcast_output(self, destinations, output):
    devices = cross_device_ops_lib.get_devices_from(destinations)

    if len(devices) == 1:
      # If necessary, copy to requested destination.
      dest_canonical = device_util.canonicalize(devices[0])
      host_canonical = device_util.canonicalize(self._host_device)

      if dest_canonical != host_canonical:
        with ops.device(dest_canonical):
          output = array_ops.identity(output)
    else:
      output = cross_device_ops_lib.simple_broadcast(output, destinations)

    return output

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (isinstance(value, values.DistributedValues) or
        tensor_util.is_tf_type(value)
       ) and tpu_util.enclosing_tpu_context() is not None:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        # scalar_mul maintains the type of value: tensor or IndexedSlices.
        value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or the value
      # could be 0.
1247      return cross_device_ops_lib.reduce_non_distributed_value(
1248          reduce_op, value, destinations, self._num_replicas_in_sync)
1249
1250    value_list = value.values
1251    # pylint: disable=protected-access
1252    if isinstance(
1253        value,
1254        values.DistributedVariable) and value._packed_variable is not None:
1255      value_list = tuple(
1256          value._packed_variable.on_device(d)
1257          for d in value._packed_variable.devices)
1258    # pylint: enable=protected-access
1259
1260    # Currently XLA op by op mode has a limit for the number of inputs for a
1261    # single op, thus we break one `add_n` op into a group of `add_n` ops to
1262    # work around the constraint.
1263    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
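    # For example, with the limit at 200 and 500 per-replica values (a
    # hypothetical count), the loop accumulates add_n(value_list[0:200]) +
    # add_n(value_list[200:400]) + add_n(value_list[400:500]) into `output`,
    # so no single `add_n` exceeds the input limit.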
    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
      output = math_ops.add_n(value_list)
    else:
      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

    if reduce_op == reduce_util.ReduceOp.MEAN:
      output *= (1. / len(value_list))

    output = self._broadcast_output(destinations, output)
    return output

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    if tpu_util.enclosing_tpu_context() is not None:
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update the variable
    # on each replica directly.
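    # Each (value, device) pair below is updated under its own device scope
    # and UpdateContext, with the per-replica `args`/`kwargs` selected via
    # `select_replica`. A packed variable contributes one entry per device it
    # spans; otherwise each component variable is updated on its own device.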
    updates = []
    values_and_devices = []
    packed_var = var._packed_variable  # pylint: disable=protected-access
    if packed_var is not None:
      for device in packed_var.devices:
        values_and_devices.append((packed_var, device))
    else:
      for value in var.values:
        values_and_devices.append((value, value.device))

    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
        var.aggregation != variables_lib.VariableAggregation.NONE):
      distribute_utils.assert_mirrored(args)
      distribute_utils.assert_mirrored(kwargs)
    for i, value_and_device in enumerate(values_and_devices):
      value = value_and_device[0]
      device = value_and_device[1]
      name = "update_%d" % i
      with ops.device(device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(value, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
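    # Inside a TPU context the value is exchanged with an all-to-all: the
    # tensor is stacked `num_replicas_in_sync` times and split across
    # replicas, so after the exchange entry `i` on every replica holds
    # replica `i`'s value.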
    if tpu_util.enclosing_tpu_context() is not None:
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcast value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variable aggregation.
      return result[0]
    return tensor

  @property
  def num_hosts(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    return len(set([self._device_assignment.host_device(r)
                    for r in range(self._device_assignment.num_replicas)]))

  @property
  def num_replicas_per_host(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN, since everything is 1 in that case.
    # For a correct computation this method needs to take host_id as input.
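    # Illustrative example (hypothetical topology): with 8 cores per host and
    # a device assignment using 2 cores per replica, at most 4 replicas fit on
    # a host; the result is further capped by the total number of replicas.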
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    return min(self._device_assignment.num_replicas, max_models_per_host)

  @property
  def _num_replicas_in_sync(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return self._device_assignment.num_replicas

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  @property
  def worker_devices(self):
    return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])

  @property
  def parameter_devices(self):
    return self.worker_devices

  def non_slot_devices(self, var_list):
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(None):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

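  # Rough usage sketch (assuming the usual TPUStrategy entry point): a call
  # such as `strategy.run(step_fn, args=(x,))` is expected to land here with
  # `fn` being the user step function and `args`/`kwargs` possibly containing
  # distributed values that are unpacked per replica below.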
  def tpu_run(self, fn, args, kwargs, options=None):
    func = self._tpu_function_creator(fn, options)
    return func(args, kwargs)

  def _tpu_function_creator(self, fn, options):
    if context.executing_eagerly() and fn in self._tpu_function_cache:
      return self._tpu_function_cache[fn]

    strategy = self._container_strategy()

    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             distribute_utils.select_replica(i, args),
             distribute_utils.select_replica(i, kwargs)])
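      # For example, with 2 replicas and `args=(x,)` where `x` is a PerReplica
      # value, `replicate_inputs` looks like
      #   [[0, (x_on_replica_0,), {}], [1, (x_on_replica_1,), {}]],
      # i.e. one [replica_id, args, kwargs] triple per replica.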

      # Construct and pass `maximum_shapes` so that we can support dynamic
      # shapes using the dynamic padder.
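      # Each maximum shape mirrors the rank of the corresponding input with
      # every dimension left unknown, e.g. a rank-2 input contributes
      # TensorShape([None, None]); inputs of unknown rank are rejected below.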
      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tf_type(input_tensor):
            rank = input_tensor.shape.rank
          else:
            rank = np.ndim(input_tensor)
          if rank is None:
            raise ValueError(
                "input tensor {} to TPUStrategy.run() has unknown rank, "
                "which is not allowed".format(input_tensor))
          maximum_shape = tensor_shape.TensorShape([None] * rank)
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None

      if options.experimental_bucketizing_dynamic_shape:
        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
      else:
        padding_spec = None

      with strategy.scope():
        xla_options = options.experimental_xla_options or tpu.XLAOptions(
            use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning)
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes,
            padding_spec=padding_spec,
            xla_options=xla_options)

      # Remove any no-op `Operation` outputs that may have been added during
      # `tpu.replicate()`.
      filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)]
      if isinstance(result[0], list):
        result[0] = filter_ops(result[0])

      # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output)))
            for output in replicate_outputs
        ]
      return distribute_utils.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)
      self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: the whole
    # cluster should be treated as a single worker from a higher-level
    # (e.g. Keras) library's point of view.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=0):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)

  # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
  def all_gather(self, value, axis, experimental_hints=None):
    del experimental_hints
    for v in nest.flatten(value):
      if isinstance(v, ops.IndexedSlices):
        raise NotImplementedError("all_gather does not support IndexedSlices")

    def _all_to_all(value, axis):
      # The underlying AllToAllOp first splits the input value and then
      # performs cross-replica communication and concatenation of the result,
      # so we concatenate the local tensor here first.
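      # For example, with 2 replicas each local tensor is tiled twice along
      # dimension 0; the all-to-all splits that tiled tensor across replicas
      # and concatenates the received pieces along `axis`, so every replica
      # ends up with both replicas' tensors (possibly out of order).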
      inputs = array_ops.concat(
          [value for _ in range(self.num_replicas_in_sync)], axis=0)
      unordered_output = tpu_ops.all_to_all(
          inputs,
          concat_dimension=axis,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      # Re-order because xla.replica_id and ReplicaContext.replica_id may not
      # match.
      # xla_id = xla.replica_id()
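      # The replica ids are exchanged with a matching all-to-all so each
      # replica learns which source produced each received piece; the
      # unsorted_segment_sum below then acts as a permutation that puts the
      # pieces back into ReplicaContext replica-id order.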
      concat_replica_id = array_ops.concat([
          array_ops.expand_dims_v2(self.replica_id_in_sync_group, 0)
          for _ in range(self.num_replicas_in_sync)
      ],
                                           axis=0)
      replica_ids = tpu_ops.all_to_all(
          concat_replica_id,
          concat_dimension=0,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      split_unordered = array_ops.split(
          unordered_output,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=axis)
      sorted_with_extra_dim = math_ops.unsorted_segment_sum(
          array_ops.concat([
              array_ops.expand_dims(replica, axis=0)
              for replica in split_unordered
          ],
                           axis=0),
          replica_ids,
          num_segments=self.num_replicas_in_sync)

      split_with_extra_dim = array_ops.split(
          sorted_with_extra_dim,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=0)
      squeezed = [
          array_ops.squeeze(replica, axis=0)
          for replica in split_with_extra_dim
      ]
      result = array_ops.concat(squeezed, axis=axis)
      return result

    ys = [_all_to_all(t, axis=axis) for t in nest.flatten(value)]
    return nest.pack_sequence_as(value, ys)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that aren't reduced, return a PerReplica of all values.
    # Otherwise take the first value from the list, as each value should be
    # the same.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = values.PerReplica(output)
    else:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access