# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import weakref

import numpy as np

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_tpu_system_metadata(tpu_cluster_resolver):
  """Retrieves TPU system metadata given a TPUClusterResolver."""
  master = tpu_cluster_resolver.master()

  # pylint: disable=protected-access
  cluster_spec = tpu_cluster_resolver.cluster_spec()
  cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
  tpu_system_metadata = (
      tpu_system_metadata_lib._query_tpu_system_metadata(
          master,
          cluster_def=cluster_def,
          query_topology=False))

  return tpu_system_metadata


@contextlib.contextmanager
def maybe_init_scope():
  if ops.executing_eagerly_outside_functions():
    yield
  else:
    with ops.init_scope():
      yield


def validate_experimental_run_function(fn):
  """Validates the function passed into `strategy.experimental_run_v2`."""

  # We allow three types of functions/objects passed into TPUStrategy
  # experimental_run_v2 in eager mode:
  #   1. a user-annotated tf.function
  #   2. a ConcreteFunction, which is mostly what you get from loading a saved
  #      model.
  #   3. a callable object whose `__call__` method is a tf.function.
  #
  # Otherwise we raise an error, because we don't support eagerly running
  # experimental_run_v2 in TPUStrategy.
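  #
  # As an illustrative sketch only (the names below are hypothetical and not
  # part of this module), the accepted forms look like:
  #
  #   @tf.function
  #   def train_step(inputs):  # case 1: a user-annotated tf.function
  #     ...
  #
  #   concrete_fn = tf.saved_model.load(path).signatures["serving_default"]
  #                                            # case 2: a ConcreteFunction
  #
  #   class Stepper:
  #     @tf.function
  #     def __call__(self, inputs):  # case 3: `__call__` is a tf.function
  #       ...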
  if context.executing_eagerly() and not isinstance(
      fn, def_function.Function) and not isinstance(
          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
              fn.__call__, def_function.Function)):
    raise NotImplementedError(
        "TPUStrategy.experimental_run_v2(fn, ...) does not support pure eager "
        "execution. Please make sure the function passed into "
        "`strategy.experimental_run_v2` is a `tf.function` or that "
        "`strategy.experimental_run_v2` is called inside a `tf.function` if "
        "eager behavior is enabled.")


@tf_export("distribute.experimental.TPUStrategy", v1=[])
class TPUStrategy(distribute_lib.Strategy):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               device_assignment=None):
    """Synchronous training on TPUs and TPU Pods.

    To construct a TPUStrategy object, you need to run the
    initialization code below:

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    ```

    While using distribution strategies, the variables created within the
    strategy's scope will be replicated across all the replicas and can be
    kept in sync using all-reduce algorithms.

    To run TF2 programs on TPUs, you can either use the `.compile` and
    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
    training loop by calling `strategy.experimental_run_v2` directly. Note that
    TPUStrategy doesn't support pure eager execution, so please make sure the
    function passed into `strategy.experimental_run_v2` is a `tf.function` or
    that `strategy.experimental_run_v2` is called inside a `tf.function` if
    eager behavior is enabled.
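
    For example, a bare-bones custom training loop could look like the sketch
    below (illustrative only; `dataset`, `num_steps`, and the body of `step_fn`
    are assumed to be supplied by the caller):

    ```python
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    @tf.function
    def train_step(iterator):
      def step_fn(inputs):
        ...  # per-replica forward pass, loss computation and optimizer update
      strategy.experimental_run_v2(step_fn, args=(next(iterator),))

    for _ in range(num_steps):
      train_step(iterator)
    ```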

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster. Currently only
        supports the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, device_assignment=device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    validate_experimental_run_function(fn)

    # Note: the target function is converted to graph even when in Eager mode,
    # so autograph is on by default here.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return self.extended.tpu_run(fn, args, kwargs)


@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Initializes the TPUStrategy object.

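    For example, a TPUStrategyV1 intended for use with Estimator or Keras can
    be set up as in the sketch below (illustrative only; `FLAGS.tpu` is assumed
    to be supplied by the caller and `steps_per_run=100` is an arbitrary
    choice):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(
        resolver, steps_per_run=100)
    ```
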
    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
          host. Note that this can have side effects on performance, hooks,
          metrics, summaries, etc.
          This parameter is only used when Distribution Strategy is used with
          Estimator or Keras.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
          specify the placement of replicas on the TPU cluster. Currently only
          supports the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategyV1, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)

  @property
  def steps_per_run(self):
    """DEPRECATED: use .extended.steps_per_run instead."""
    return self._extended.steps_per_run

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    validate_experimental_run_function(fn)

    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return self.extended.tpu_run(fn, args, kwargs)


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class TPUExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of TPUStrategy."""

  def __init__(self,
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = TPUClusterResolver("")

    if steps_per_run is None:
      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
      # not specified.
      steps_per_run = 1

    self._tpu_function_cache = weakref.WeakKeyDictionary()
    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
    self._device_assignment = device_assignment

    tpu_devices_flat = [
        d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]

    # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
    # indexed using `[replica_id][logical_device_id]`.
    if device_assignment is None:
      self._tpu_devices = np.array(
          [[d] for d in tpu_devices_flat], dtype=object)
    else:
      job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job

      tpu_devices = []
      for replica_id in range(device_assignment.num_replicas):
        replica_devices = []

        for logical_core in range(device_assignment.num_cores_per_replica):
          replica_devices.append(
              device_util.canonicalize(
                  device_assignment.tpu_device(
                      replica=replica_id,
                      logical_core=logical_core,
                      job=job_name)))

        tpu_devices.append(replica_devices)
      self._tpu_devices = np.array(tpu_devices, dtype=object)

    self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])

    # Preload the data onto the TPUs. Currently we always preload onto logical
    # device 0 for each replica.
    # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
    # input onto a different logical device?
    input_worker_devices = collections.OrderedDict()
    for tpu_device in self._tpu_devices[:, 0]:
      host_device = device_util.get_host_for_device(tpu_device)
      input_worker_devices.setdefault(host_device, [])
      input_worker_devices[host_device].append(tpu_device)
    self._input_worker_devices = tuple(input_worker_devices.items())
    self._input_workers_obj = None

    # TODO(sourabhbajaj): Remove this once performance of running one step
    # at a time is comparable to multiple steps.
    self.steps_per_run = steps_per_run
    self._require_static_shapes = True

    # TPUStrategy handles the graph replication in TF-XLA bridge, so we don't
    # need to retrace functions for each device.
    self._retrace_functions_for_each_device = False

    self.experimental_enable_get_next_as_optional = True
    self.experimental_enable_dynamic_batch_size = True
    self._prefetch_on_host = False

    self._logical_device_stack = [0]

  # TODO(bfontain): Remove once a proper dataset API for prefetching a dataset
  # to multiple devices exists.
  # If value is true, this forces prefetching of data to the host's memory
  # rather than the individual TPU device's memory. This is needed when using
  # TPU Embeddings, as a) sparse tensors cannot be prefetched to TPU device
  # memory and b) TPU Embedding enqueue operations are CPU ops, so this avoids
  # a copy back to the host for dense tensors.
  def _set_prefetch_on_host(self, value):
    if self._prefetch_on_host == value:
      return
    if self._input_workers_obj is not None:
      raise RuntimeError("Unable to change prefetch on host behavior as "
                         "InputWorkers are already created.")
    self._prefetch_on_host = value
    if value:
      # To prefetch on the host, we must set all the input worker devices to
      # the corresponding host devices.
      self._input_worker_devices = tuple([
          tuple([host,
                 [device_util.get_host_for_device(d) for d in devices]])
          for host, devices in self._input_worker_devices])
      # Force creation of the workers.
      workers = self._input_workers
      del workers

  @property
  def _input_workers(self):
    if self._input_workers_obj is None:
      self._input_workers_obj = input_lib.InputWorkers(
          self._input_worker_devices)
    return self._input_workers_obj

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        split_batch_by=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn,
        self._input_workers,
        input_contexts,
        self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers,
        self._container_strategy(),
        split_batch_by=self._num_replicas_in_sync)

  def _experimental_distribute_datasets_from_function(self, dataset_fn):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers,
        input_contexts,
        self._container_strategy())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g. to create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(
          run_fn, replicate_inputs, device_assignment=self._device_assignment)

      # If run_fn has tensor outputs, tpu.replicate returns a list of lists,
      # which we flatten here. If run_fn has no tensor outputs, tpu.replicate
      # returns a list of no_ops, and we keep the output as it is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to the while loop should be based on the
    # output type of the step_fn.
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs; typically this is the case when
      # there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # No tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  @contextlib.contextmanager
  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    if logical_device_id >= num_logical_devices_per_replica:
      raise ValueError(
          "`logical_device_id` not in range (was {}, but there are only {} "
          "logical devices per replica).".format(
              logical_device_id, num_logical_devices_per_replica))

    self._logical_device_stack.append(logical_device_id)
    try:
      if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
        yield
      else:
        with ops.device(tpu.core(logical_device_id)):
          yield
    finally:
      self._logical_device_stack.pop()

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should call `tf.tpu.experimental.initialize_tpu_system` directly.
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    if kwargs.pop("skip_mirrored_creator", False):
      return next_creator(**kwargs)

    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with.devices

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      initial_value = None
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i == 0:
            initial_value = kwargs["initial_value"]
            # Note: some v1 code expects variable initializer creation to
            # happen inside an init_scope.
            with maybe_init_scope():
              initial_value = initial_value() if callable(
                  initial_value) else initial_value

          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          kwargs["initial_value"] = initial_value

          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(**kwargs)

          assert not isinstance(v, values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return values.create_mirrored_variable(self._container_strategy(),
                                           _real_mirrored_creator,
                                           values.TPUMirroredVariable,
                                           values.TPUSyncOnReadVariable,
                                           **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.DistributedValues) or
        tensor_util.is_tensor(value)
       ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only sum and mean are supported in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or `value`
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)

    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    devices = cross_device_ops_lib.get_devices_from(destinations)

    if len(devices) == 1:
      # If necessary, copy to requested destination.
      dest_canonical = device_util.canonicalize(devices[0])
      host_canonical = device_util.canonicalize(self._host_device)

      if dest_canonical != host_canonical:
        with ops.device(dest_canonical):
          output = array_ops.identity(output)
    else:
      output = cross_device_ops_lib.simple_broadcast(output, destinations)

    return output

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_replica_mirrored(i, args),
                          **values.select_replica_mirrored(i, kwargs)))
    return values.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variable aggregation.
      return result[0]
    return tensor

  @property
  def num_hosts(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    return len(set([self._device_assignment.host_device(r)
                    for r in range(self._device_assignment.num_replicas)]))

  @property
  def num_replicas_per_host(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN as everything is 1 in that case.
    # This method needs to take host_id as input for correct computation.
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    return min(self._device_assignment.num_replicas, max_models_per_host)

  @property
  def _num_replicas_in_sync(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return self._device_assignment.num_replicas

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  @property
  def worker_devices(self):
    return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])

  @property
  def parameter_devices(self):
    return self.worker_devices

  def non_slot_devices(self, var_list):
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(None):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def tpu_run(self, fn, args, kwargs):
    func = self._tpu_function_creator(fn)
    return func(args, kwargs)

  def _tpu_function_creator(self, fn):
    if fn in self._tpu_function_cache:
      return self._tpu_function_cache[fn]

    strategy = self._container_strategy()

    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Remove None at the end of args as they are not replicatable.
      # If there are None in the middle, we can't do anything about it, so let
      # those cases fail.
      # For example, when Keras model predict is used, the targets are passed
      # as None. We want to handle it here so all client libraries don't have
      # to do this, as other strategies can handle None values better.
      while args and args[-1] is None:
        args = args[:-1]

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             values.select_replica(i, args),
             values.select_replica(i, kwargs)])

      # Construct and pass `maximum_shapes` so that we can support dynamic
      # shapes using the dynamic padder.
      if self.experimental_enable_dynamic_batch_size and replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tensor(input_tensor):
            rank = input_tensor.get_shape().rank
          else:
            rank = np.ndim(input_tensor)
          maximum_shape = tensor_shape.TensorShape([None] * rank)
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None

      with strategy.scope():
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes)

      # Remove all no-ops that may have been added during `tpu.replicate()`.
      if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if not isinstance(
                output, ops.Operation)
        ]

      # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]
      return values.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)

    self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: the whole
    # cluster should be treated as a single worker from a higher-level (e.g.
    # Keras) library's point of view.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=None):
    if replica_id_in_sync_group is None:
      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access