# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library of TPU helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum
import typing
from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

ops.NotDifferentiable("TPUReplicatedInput")

# Operations that indicate some error in the user's graph, e.g. a placeholder
# that's introduced outside of the infeed.
_DENYLISTED_OPS = set([
    "Placeholder",
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
    "AudioSummary",
    "AudioSummaryV2",
    "HistogramSummary",
    "ImageSummary",
    "MergeSummary",
    "Print",
    "ScalarSummary",
    "TensorSummary",
    "TensorSummaryV2",
    ])

# Ops which can be safely pruned from XLA compile if they have no consumers.
# These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])

_MAX_WARNING_LINES = 5

_TPU_REPLICATE_ATTR = "_tpu_replicate"
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"


core = tpu_name_util.core


def _tpu_system_device_name(job: Optional[Text]) -> Text:
  """Returns the device name for the TPU_SYSTEM device of `job`."""
  if job is None:
    return "/device:TPU_SYSTEM:0"
  else:
    return "/job:%s/device:TPU_SYSTEM:0" % job


@tf_export(v1=["tpu.initialize_system"])
def initialize_system(
    embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
    job: Optional[Text] = None,
    compilation_failure_closes_chips: bool = True
) -> core_types.Tensor:
  """Initializes a distributed TPU system for use with TensorFlow.

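  Example usage (a minimal sketch, assuming a reachable TPU worker and a TF1
  `Session`; the worker address is a placeholder):

  ```python
  topology_op = tf.compat.v1.tpu.initialize_system()
  with tf.compat.v1.Session("grpc://<tpu-worker-address>") as sess:
    serialized_topology = sess.run(topology_op)
  ```
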
  Args:
    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
      describing the desired configuration of the hardware embedding lookup
      tables. If embedding_config is None, no hardware embeddings can be used.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
    compilation_failure_closes_chips: Set whether to close TPU chips when there
      is a compilation failure.

  Returns:
    A serialized `TopologyProto` that describes the TPU system. Note:
      the topology must be evaluated using `Session.run` before it can be used.
  """
  config_string = ("" if embedding_config is None else
                   embedding_config.SerializeToString())
  with ops.device(_tpu_system_device_name(job)):
    topology = tpu_ops.configure_distributed_tpu(
        compilation_failure_closes_chips=compilation_failure_closes_chips)

    if embedding_config is None:
      return topology

    # This set of control dependencies is needed as this function is expected
    # to return an op which will return the topology when executed, but we
    # need to call the embedding initialization op between initializing the
    # TPU and returning the topology.
    with ops.control_dependencies([topology]):
      embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
    with ops.control_dependencies([embedding_init]):
      return array_ops.identity(topology, name="tpu_init_identity")


def initialize_system_for_tpu_embedding(
    embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
    job: Optional[Text] = None,
) -> ops.Operation:
  """Initializes a distributed TPU Embedding system for use with TensorFlow.

  The following two are equivalent:
  1. initialize_system() with embedding_config.
  2. initialize_system() without embedding_config, then
     initialize_system_for_tpu_embedding().
  initialize_system() should not be called with embedding_config if
  initialize_system_for_tpu_embedding() is meant to be called later.

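  For illustration, a minimal sketch of the two-step path (assuming `config` is
  a populated `TPUEmbeddingConfiguration` proto created by the caller, and a
  TF1 `Session`; the worker address is a placeholder):

  ```python
  topology = tf.compat.v1.tpu.initialize_system()  # no embedding_config here
  embedding_init = initialize_system_for_tpu_embedding(config, job=None)
  with tf.compat.v1.Session("grpc://<tpu-worker-address>") as sess:
    sess.run(topology)
    sess.run(embedding_init)
  ```
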
  Args:
    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.

  Returns:
    A no-op.
  """
  config_string = embedding_config.SerializeToString()
  with ops.device(_tpu_system_device_name(job)):
    return tpu_ops.configure_tpu_embedding(config=config_string)


@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
  """Shuts down a running distributed TPU system.

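  Example usage (a sketch mirroring `initialize_system` above; assumes a TF1
  `Session` connected to the TPU worker, whose address is a placeholder):

  ```python
  shutdown_op = tf.compat.v1.tpu.shutdown_system()
  with tf.compat.v1.Session("grpc://<tpu-worker-address>") as sess:
    sess.run(shutdown_op)
  ```
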
  Args:
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be shut down. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
  """
  with ops.device(_tpu_system_device_name(job)):
    shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
  return shutdown_distributed_tpu


def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
  """Returns the TPUReplicateContext and its associated graph."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, TPUReplicateContext):
        return context_, graph
      context_ = context_.outer_context
    graph = getattr(graph, "outer_graph", None)
  raise ValueError("get_replicated_var_handle() called without "
                   "TPUReplicateContext. This shouldn't happen. Please file "
                   "a bug.")


def is_tpu_strategy(strategy: Any) -> bool:
  is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
  clz = strategy.__class__
  return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))


def _enclosing_tpu_device_assignment(
) -> Optional[device_assignment_lib.DeviceAssignment]:
  if not distribution_strategy_context.has_strategy():
    return None
  strategy = distribution_strategy_context.get_strategy()
  if not is_tpu_strategy(strategy):
    return None
  return strategy.extended._device_assignment  # pylint: disable=protected-access


@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(
    op: ops.Operation,
    resource_reads: object_identity.ObjectIdentitySet,
    resource_writes: object_identity.ObjectIdentitySet) -> bool:
  """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
  # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
  # control deps on the replicated inputs.
  if op.type == "TPUReplicatedInput":
    if resource_reads or resource_writes:
      resource_reads.clear()
      resource_writes.clear()
      return True
    else:
      return False
  # Replace tensors in `resource_inputs` which are outputs of
  # TPUReplicatedInput with the actual replicated inputs. This allows ACD to
  # correctly add control deps when there are multiple calls to `run` in a
  # `tf.function`.
  def replace_with_unreplicated_resources(resource_inputs):
    """Replaces handles in `resource_inputs` with their unreplicated inputs."""
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "TPUReplicatedInput":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  return bool(replace_with_unreplicated_resources(resource_reads) or
              replace_with_unreplicated_resources(resource_writes))


class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU computation.

  The primary role of `TPUReplicateContext` is to mark operators inside a
  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
  is a unique name.

  We use a `ControlFlowContext` to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a tpu.replicate() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the replicated computation.
  """

  def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
    """Builds a new TPUReplicateContext.

    Args:
      name: a unique name for the context, used to populate the
        `_tpu_replicate` attribute.
      num_replicas: an integer that gives the number of replicas for the
        computation.
      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have
        any inputs will have a control dependency on the pivot node. This
        ensures that nodes are correctly included in any enclosing control
        flow contexts.
    """
    super(TPUReplicateContext, self).__init__()
    self._num_replicas = num_replicas
    self._outer_device_function_stack = None
    self._oc_dev_fn_stack = None
    self._outside_compilation_cluster = None
    self._outside_compilation_v2_context = None
    self._outside_compilation_counter = 0
    self._in_gradient_colocation = None
    self._gradient_colocation_stack = []
    self._host_compute_core = []
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
    self._unsupported_ops = []
    self._pivot = pivot
    self._replicated_vars = {}

  def get_replicated_var_handle(
      self,
      name: Text,
      vars_: List[variables.Variable],
      is_mirrored: bool = False,
      is_packed: bool = False) -> core_types.Tensor:
    """Returns a variable handle for the replicated TPU variables `vars_`.

    This is a method used by an experimental replicated variable implementation
    and is not intended as a public API.

    Args:
      name: The common name of the variable.
      vars_: The replicated TPU variables.
      is_mirrored: Whether the variables are mirrored, which guarantees the
        values in each replica are always the same.
      is_packed: Whether the replicated variables are packed into one variable.

    Returns:
      The handle of the TPU replicated input node.
    """
    device_assignment = _enclosing_tpu_device_assignment()
    # We don't need to put device assignment as part of the replicated_vars key
    # because each TPUReplicateContext will only have one device assignment.
    handle = self._replicated_vars.get(name)
    if handle is not None:
      return handle

    if device_assignment is not None and not is_packed:
      # Find a variable copy for each replica in the device assignment.
      # Note that the order of devices for replicas for the variable and the
      # device assignment might not match.
      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
      devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_}
      replicated_vars = []
      for replica_id in range(device_assignment.num_replicas):
        for logical_core in range(device_assignment.num_cores_per_replica):
          device = device_util.canonicalize(
              device_assignment.tpu_device(
                  replica=replica_id, logical_core=logical_core, job=job_name))
          if device in devices_to_vars:
            replicated_vars.append(devices_to_vars[device])
            break
        else:
          raise ValueError(
              "Failed to find a variable on any device in replica {} for "
              "current device assignment".format(replica_id))
    else:
      replicated_vars = vars_

    # Builds a TPUReplicatedInput node for the variable, if one does not
    # already exist. The TPUReplicatedInput node must belong to the enclosing
    # control-flow scope of the TPUReplicateContext.
    # TODO(phawkins): consider changing the contract of the TPU encapsulation
    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
    # instead.

    _, graph = _enclosing_tpu_context_and_graph()
    with graph.as_default():
      # pylint: disable=protected-access
      saved_context = graph._get_control_flow_context()
      graph._set_control_flow_context(self.outer_context)
      handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
                                            name=name + "/handle",
                                            is_mirrored_variable=is_mirrored,
                                            is_packed=is_packed)
      graph._set_control_flow_context(saved_context)
      # pylint: enable=protected-access
    self._replicated_vars[name] = handle
    return handle

  def report_unsupported_operations(self) -> None:
    if self._unsupported_ops:
      op_str = "\n".join("  %s (%s)" % (op.type, op.name)
                         for op in self._unsupported_ops[:_MAX_WARNING_LINES])
      logging.warning("%d unsupported operations found: \n%s",
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning("... and %d more" %
                        (len(self._unsupported_ops) - _MAX_WARNING_LINES))

  def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
    if op is not None:
      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
        # If we are in TF 2 functions (control flow V2 functions, or
        # tf.function()), we need to attach _xla_outside_compilation attribute
        # directly because we are not in TPUReplicateContext.
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
        except ValueError:
          # The attr was not present: do nothing.
          return
        parts = outside_attr.split(".")
        cluster = parts[0] + "." + gradient_uid
        self._outside_compilation_v2_context = OutsideCompilationV2Context(
            cluster)
        self._outside_compilation_v2_context.Enter()
        return
      self._gradient_colocation_stack.append(op)
      if not self._outside_compilation_cluster:
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
          if self._in_gradient_colocation:
            raise NotImplementedError(
                "Cannot nest gradient colocation operations outside compilation"
            )
          if gradient_uid == "__unsupported__":
            raise NotImplementedError(
                "No gradient_uid calling gradient within outside_compilation")
          # When we take the gradient of an op X in an outside_compilation
          # cluster C in a forward computation we would like to put the ops
          # corresponding to the gradient of X into a new outside_compilation
          # cluster C'. However, if we take the gradient of X twice, the second
          # one should get yet another new outside_compilation cluster C''.
          #
          # The mechanism we adopt is to use a 'root_cluster' which is the
          # cluster that X was in before we took gradients, and a 'gradient_uid'
          # which is different for every invocation of gradients, and put the
          # gradient of X in cluster 'root_cluster.gradient_uid'.
          #
          # When taking a gradient of a gradient, some ops will be colocated
          # with Op in the forward pass (e.g., cluster root_cluster) and some in
          # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
          # We need all of the grad-of-grad ops to be in the same cluster to
          # avoid cyclic dependencies between clusters. We adopt a heuristic
          # that puts any op clustered with root_cluster.<xxx> in
          # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
          self._in_gradient_colocation = op
          parts = outside_attr.split(".")
          cluster = parts[0] + "." + gradient_uid
          self._EnterOutsideCompilationScope(cluster=cluster)
        except ValueError:
          # The attr was not present: do nothing.
          pass

  def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
    if op is not None:
      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
        # Inside a TF2 tf.function or control flow graph and `op` was not
        # marked to be outside compiled.
        assert self._outside_compilation_v2_context is None
        return
      if self._outside_compilation_v2_context is not None:
        # Inside a TF2 tf.function or control flow graph and `op` was
        # marked to be outside compiled.
        self._outside_compilation_v2_context.Exit()
        self._outside_compilation_v2_context = None
        return
      if not self._gradient_colocation_stack:
        raise errors.InternalError(
            op.node_def, op,
            "Badly nested gradient colocation: empty stack when popping Op " +
            op.name)
      last_op = self._gradient_colocation_stack.pop()
      if op is last_op:
        if op is self._in_gradient_colocation:
          self._in_gradient_colocation = None
          self._ExitOutsideCompilationScope()
      else:
        raise errors.InternalError(
            op.node_def, op, "Badly nested gradient colocation, expected " +
            last_op.name + ", got " + op.name)

  def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):

    class FakeOp(object):
      """A helper class to determine the current device.

      Supports only the type and device set/get methods needed to run the
      graph's _apply_device_functions method.
      """

      def __init__(self):
        self._device = ""

      @property
      def type(self):
        return "FakeOp"

      @property
      def device(self):
        return self._device

      def _set_device(self, device):
        if isinstance(device, pydev.DeviceSpec):
          self._device = device.to_string()
        else:
          self._device = device

      def _set_device_from_string(self, device_str):
        self._device = device_str

    if self._outside_compilation_cluster:
      raise NotImplementedError("Cannot nest outside_compilation clusters")
    if cluster:
      self._outside_compilation_cluster = cluster
    else:
      self._outside_compilation_cluster = str(self._outside_compilation_counter)
      self._outside_compilation_counter += 1
    graph = ops.get_default_graph()
    fake_op = FakeOp()
    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
    device = pydev.DeviceSpec.from_string(fake_op.device)
    if (device.device_type == "TPU_REPLICATED_CORE" and
        device.device_index is not None):
      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
                                     str(device.device_index))
    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access

  def _ExitOutsideCompilationScope(self):
    if not self._outside_compilation_cluster:
      raise NotImplementedError(
          "Attempted to exit outside_compilation scope when not in scope")
    self._outside_compilation_cluster = None
    graph = ops.get_default_graph()
    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access

  def Enter(self) -> None:
    if not self._outer_device_function_stack:
      # Capture the device function stack at the time of first entry
      # since that is the stack that will be used outside_compilation.
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self._outer_device_function_stack = graph._device_function_stack.copy()
      # pylint: enable=protected-access
    super(TPUReplicateContext, self).Enter()

  def HostComputeCore(self) -> List[Text]:
    return self._host_compute_core

  def _RemoveExternalControlEdges(
      self, op: ops.Operation
      ) -> Tuple[List[ops.Operation], List[ops.Operation]]:
    """Remove any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op: ops.Operation) -> None:
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error("Operation of type %s (%s) is not supported on the TPU. "
                    "Execution will fail if this op is used in the graph. ",
                    op.type, op.name)

    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          "Non-resource Variables are not supported inside TPU computations "
          "(operator name: %s)" % op.name)

    # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add
    # the "_cloned" attribute and we should continue in that case.
    if (_TPU_REPLICATE_ATTR in op.node_def.attr and
        "_cloned" not in op.node_def.attr):
      raise ValueError("TPU computations cannot be nested on op (%s)" %
                       op)
    op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
                          self._tpu_relicate_attr_buf.buffer)
    if self._outside_compilation_cluster:
      op._set_attr(
          _OUTSIDE_COMPILATION_ATTR,
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._outside_compilation_cluster)))
    if self._num_replicas > 1 or not self._outside_compilation_cluster:
      # Prevent feeding or fetching anything that is being compiled,
      # and any replicated outside_compilation Op.
      op.graph.prevent_feeding(op)
      op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may
    # cause mismatched frame errors.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self.GetControlPivot())
        # pylint: enable=protected-access
    else:
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
    """Add `val` to the current context and its outer context recursively."""
    if not self._outer_context:
      return val

    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op: ops.Operation):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the TPUReplicateContext
    # to be None, as the TPUReplicateContext does not get nested, nor does the
    # grad_state outside the TPUReplicateContext affect the graph inside, so
    # the grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self) -> ops.Operation:
    return self._pivot

  def RequiresUniqueFunctionRetracing(self):
    # More context: b/158152827. TPU stack uses the TPUReplicateContext to
    # create replicated variable handles and cluster TPU computations, thus we
    # always retrace a tf.function when the wrapped TPUReplicateContext changes.
    return True


class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
  """The context for outside compilation in TensorFlow 2.0.

  Every op added in this context will be assigned an _xla_outside_compilation
  attribute.
  """

  def __init__(self, name: Text):
    control_flow_ops.ControlFlowContext.__init__(self)
    self._name = name

  def AddOp(self, op: ops.Operation) -> None:
    if self._outer_context:
      self._outer_context.AddOp(op)
    # pylint: disable=protected-access
    op._set_attr("_xla_outside_compilation",
                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
    # pylint: enable=protected-access

  def AddInnerOp(self, op: ops.Operation) -> None:
    if self._outer_context:
      self._outer_context.AddInnerOp(op)
    # pylint: disable=protected-access
    op._set_attr("_xla_outside_compilation",
                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
    # pylint: enable=protected-access

  def to_control_flow_context_def(self, context_def, export_scope=None):
    raise NotImplementedError("to_control_flow_context_def not implemented")


@tf_export(v1=["tpu.outside_compilation"])
def outside_compilation(
    computation: Callable[..., Any], *args, **kwargs
    ) -> Any:
  """Builds part of a computation outside any current TPU replicate scope.

  `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
  instead of running on TPU. For example, users can run ops that are not
  supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
  ops on CPUs. The usage below places the ops in
  `computation_with_string_ops` on CPU.

  Example usage:

  ```python
  def computation_with_string_ops(x):
    # String types are not supported on TPUs, so the ops below must
    # run on CPU instead.
    output = tf.strings.format('1{}', x)
    return tf.strings.to_number(output)

  def tpu_computation():
    # Expected output is 11.
    output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
  ```

  Outside compilation should be called inside a TPUReplicateContext. That is,
  `tf.tpu.outside_compilation()` should be called inside a function that is
  passed to `tpu.split_compile_and_replicate()` -- this is implied when
  outside compilation is invoked inside a function passed to TPUStrategy
  `run()`. If invoked outside of a TPUReplicateContext,
  this simply returns the result of `computation` and is therefore a no-op.
  Note that outside compilation is different from
  `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
  outside compilation is replicated and executed separately for each
  replica. On the other hand, `merge_call()` requires a `merge_fn`
  to aggregate the inputs from different replicas and is executed only
  once.

  For variables placed on a TPU device, which includes variables created
  inside a TPUStrategy scope, outside compilation logic must not include
  variable reads/writes. For variables placed on the host, which is the case
  when variables are created via TPUEstimator, variable reads/writes are only
  allowed if the variable is not accessed by any other ops in the TPU
  computation. Variable reads/writes from the outside compilation cluster are
  not visible from the TPU computation and vice versa. Therefore, if the
  outside compilation logic contains such host variable reads/writes and the
  variables are also accessed by the TPU computation, this may lead to
  deadlock.

  Internally, `tf.tpu.outside_compilation()` adds outside compilation
  attributes to all ops in `computation`. During a later graph pass, the
  ops with the outside compilation attribute are extracted and replicated
  into a host-side graph. Inputs to this extracted host-side graph are sent
  from the TPU computation graph to the host graph via a pair of XlaSendToHost
  and XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()`
  may result in tensor transfers between TPU and CPU, leading to non-trivial
  performance impact.

  Args:
    computation: A Python function that builds the computation to
      place on the host.
    *args: the positional arguments for the computation.
    **kwargs: the keyword arguments for the computation.

  Returns:
    The Tensors returned by computation.
  """
  args = [] if args is None else args
  graph = ops.get_default_graph()

  # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
  # we need to attach _xla_outside_compilation attribute directly because we
  # are not in TPUReplicateContext.
  if isinstance(graph, func_graph.FuncGraph):
    try:
      tpu_context, _ = _enclosing_tpu_context_and_graph()
    except ValueError:
      logging.warning(
          "Outside compilation attempted outside TPUReplicateContext "
          "scope. As no enclosing TPUReplicateContext can be found, "
          "returning the result of `computation` as is.")
      return computation(*args, **kwargs)

    # pylint: disable=protected-access
    outside_compilation_name = str(tpu_context._outside_compilation_counter)
    tpu_context._outside_compilation_counter = (
        tpu_context._outside_compilation_counter + 1)
    # pylint: enable=protected-access

    outside_compilation_context = OutsideCompilationV2Context(
        outside_compilation_name)
    outside_compilation_context.Enter()
    args = [] if args is None else args
    retval = computation(*args, **kwargs)
    outside_compilation_context.Exit()
    return retval

  # If we are in a TPUReplicateContext, signal that we are now
  # outside_compilation
  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  retval = computation(*args, **kwargs)

  # If we are in a TPUReplicateContext, signal that we are no longer
  # outside_compilation
  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  if initial_context is not final_context:
    raise NotImplementedError(
        "Control-flow context cannot be different at start and end of an "
        "outside_compilation scope")
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  return retval


@tf_export(v1=["tpu.PaddingSpec"])
class PaddingSpec(enum.IntEnum):
  """Represents the type of padding policies for tpu.replicate."""
  # By default the policy is set to AUTO: a dynamic input shape dimension will
  # be padded to the maximum of that dimension over all the replicas.
  AUTO = 0
  # Bucketize the dynamic input shape dimension into a power of 2.
  POWER_OF_TWO = 1


@tf_export(v1=["tpu.XLAOptions"])
class XLAOptions(
    collections.namedtuple("XLAOptions", [
        "use_spmd_for_xla_partitioning",
    ])):
  """XLA compilation options.

  Attributes:
    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
      partitioner instead of MPMD partitioner when compiler partitioning is
      requested.
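
  For example, a minimal sketch of opting out of SPMD partitioning (the
  `computation` and `x` names are placeholders, not defined in this module):

  ```python
  options = tf.compat.v1.tpu.XLAOptions(use_spmd_for_xla_partitioning=False)
  tf.compat.v1.tpu.replicate(computation, inputs=[[x]], xla_options=options)
  ```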
874  """
875
876  def __new__(cls, use_spmd_for_xla_partitioning=True):
877    return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning)
878
879
@tf_export(v1=["tpu.replicate"])
def replicate(
    computation: Callable[..., Any],
    inputs: Optional[List[List[core_types.Tensor]]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    maximum_shapes: Any = None,
    padding_spec: Optional[PaddingSpec] = None,
    xla_options: Optional[XLAOptions] = None) -> List[Any]:
  """Builds a graph operator that runs a replicated TPU computation.

  Example of basic usage where `inputs` has a static shape:

  ```python

  def computation(x):
    x = x + 1
    return tf.math.reduce_mean(x)

  x = tf.convert_to_tensor([1., 2., 3.])
  y = tf.convert_to_tensor([4., 5., 6.])
  tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
  ```

  If `inputs` has dynamic shapes and you would like to automatically
  bucketize the inputs to avoid XLA recompilation, see the advanced example
  below:

  ```python

  def computation(x):
    x = x + 1
    return tf.math.reduce_mean(x)

  # Assume input tensors in two replicas `x` and `y` both have dynamic shape
  # ([None, 2]).
  tf.compat.v1.tpu.replicate(
    computation,
    inputs=[x, y],
    maximum_shapes=[tf.TensorShape([None, None])],
    padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
  ```

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list
      of scalar tensors rather than a single Rank-N tensor. If you need
      different behavior, convert parts of inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g.
      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
      object) will be padded to the maximum size of that dimension over all
      replicas. The structure of `maximum_shapes` needs to be the same as
      `inputs[0]`.
    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.

  Returns:
    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
    structure the same as what computation() returns, with a few exceptions.

    Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: A tuple containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  return split_compile_and_replicate(
      computation,
      inputs,
      infeed_queue,
      device_assignment,
      name,
      maximum_shapes=maximum_shapes,
      padding_spec=padding_spec,
      xla_options=xla_options)[1]


def _ceil_to_pow_of_n(x, n):
  """Ceils input `x` up to a power of `n`."""
  x = math_ops.cast(x, dtypes.float32)
  lognx = math_ops.log(x) / math_ops.log(n * 1.0)
  lognx = math_ops.ceil(lognx)
  result = math_ops.pow(n * 1.0, lognx)
  result = math_ops.cast(result, dtypes.int32)
  return result


def _pad_all_input(
    inputs: Iterable[core_types.Tensor],
    padded_shapes: List[Optional[tensor_shape.TensorShape]],
    padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
  """Pad all input tensors given padded_shapes.

  The real shape tensors will be concatenated with the padded original inputs.

  Args:
    inputs: The original inputs.
    padded_shapes: A list of padded shapes for each input. If an entry is None,
      no padding is performed.
    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.

  Returns:
    The padded inputs and a PaddingMap list which maps the padded input
    dimension to the real shape argument index.
  """
  # maximum_static_shapes[idx][i] indicates the maximum static size of ith
  # dimension of the idx input among all the replicas.
  maximum_static_shapes = []
  # need_padding[idx][i] indicates whether the ith dimension of the idx input
  # needs padding.
  need_padding = []
  input_shape_tensors = []
  for core_idx, inputs_per_core in enumerate(inputs):
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape = input_tensor.get_shape().as_list()
      if core_idx == 0:
        input_shape_tensors.append([])
        maximum_static_shapes.append(input_shape)
        need_padding.append(np.full_like(input_shape, False, dtype=bool))
      else:
        for i, s in enumerate(input_shape):
          if s is None or s != maximum_static_shapes[idx][i]:
            need_padding[idx][i] = True
        maximum_static_shapes[idx] = max(input_shape,
                                         maximum_static_shapes[idx])

      # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
      real_input_shape = array_ops.shape(input_tensor)
      real_input_shape.op._set_attr(  # pylint: disable=protected-access
          _POST_DEVICE_REWRITE_ATTR,
          attr_value_pb2.AttrValue(b=True))
      input_shape_tensors[idx].append(real_input_shape)

  maximum_shapes = []
  for shapes_per_input in input_shape_tensors:
    maximum_shapes.append(
        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))

  padded_inputs = []
  real_shapes = []
  padding_maps = []
  for core_idx, inputs_per_core in enumerate(inputs):
    padded_inputs.append([])
    real_shapes.append([])
    real_shape_idx = len(inputs_per_core) - 1
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape_tensor = input_shape_tensors[idx][core_idx]
      input_shape = input_tensor.get_shape().as_list()
      padded_shape = padded_shapes[idx]

      # If we have no padded_shape, then skip padding.
      if any(need_padding[idx]) and padded_shape is not None:
        for i, s in enumerate(input_shape):
          if need_padding[idx][i]:
            if core_idx == 0:
              real_shape_idx += 1
              padding_map = dynamic_padding.PaddingMap()
              padding_map.arg_index = idx
              padding_map.shape_index = i
              padding_map.padding_arg_index = real_shape_idx
              padding_maps.append(padding_map)
            real_shapes[core_idx].append(
                math_ops.cast(input_shape_tensor[i], dtypes.int32))

        paddings = []
        for i, s in enumerate(padded_shape.dims):
          if need_padding[idx][i]:
            # The minimum padded dimension size is 2, as XLA doesn't support a
            # dynamic dimension of size 1.
            minimum_dynamic_dim_size = 2
            if s.value is not None:
              # Pad to the given maximum value.
              max_dim_size = max(s.value, minimum_dynamic_dim_size)
            else:
              # If the maximum value is not given, then pad to the maximum
              # dimension among all the cores.
              max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
                                              minimum_dynamic_dim_size)
              if padding_spec == PaddingSpec.POWER_OF_TWO:
                max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
            # Pad to the given maximum value.
            padding = [0, max_dim_size - input_shape_tensor[i]]
          else:
            padding = [0, 0]
          paddings.append(padding)

        if input_tensor.get_shape().is_fully_defined():
          # TODO(rxsang): This is a hack to make sure padded_input has dynamic
          # shapes, so any tf.size/tf.shape op performed on it won't be
          # constant folded. Do we have better ways to do it?
          padded_input = control_flow_ops.cond(
              array_ops.constant(True),
              lambda: array_ops.pad(input_tensor, paddings),  # pylint: disable=cell-var-from-loop
              lambda: input_tensor)
        else:
          padded_input = array_ops.pad(input_tensor, paddings)

        # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
        padded_input.op._set_attr(  # pylint: disable=protected-access
            _POST_DEVICE_REWRITE_ATTR,
            attr_value_pb2.AttrValue(b=True))

        padded_inputs[core_idx].append(padded_input)
      else:
        padded_inputs[core_idx].append(input_tensor)

  num_replicas = len(padded_inputs)
  for i in range(num_replicas):
    padded_inputs[i].extend(real_shapes[i])

  return padded_inputs, padding_maps


def _flatten_and_filter_composite(maybe_composite, non_composite_output,
                                  composite_output=None):
  """For an input, replaces the input with a tuple if the input is composite.

  If `maybe_composite` is not composite, return the parameter
  `non_composite_output`; otherwise return a tuple consisting of the value of
  the parameter `composite_output` repeated once for each component of the
  composite tensor.

  This is useful for computing a mask when flattening nested data with
  `expand_composites=True`. For example

  ```python
  nest.flatten(data, expand_composites=True)
  ```

  and

  ```python
  nest.flatten(nest.map_structure(
      lambda x: _flatten_and_filter_composite(x, False, True), data))
  ```

  will have the same length, and the second will be True where the
  corresponding tensor in the first is derived from expanding a composite
  tensor.

  Args:
    maybe_composite: A value to test for being a composite tensor.
    non_composite_output: The value to return when `maybe_composite` is not a
      composite.
    composite_output: the value to fill the output tuple with if
      `maybe_composite` is a composite.

  Returns:
    `non_composite_output` or a tuple with multiple copies of
    `composite_output`.
  """

  if isinstance(maybe_composite, composite_tensor.CompositeTensor):
    num_components = len(
        maybe_composite._type_spec._to_components(maybe_composite))  # pylint: disable=protected-access
    return (composite_output,) * num_components
  return non_composite_output


def split_compile_and_replicate(
    computation: Callable[..., Any],
    inputs: Optional[List[List[core_types.Tensor]]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    use_tpu: bool = True,
    maximum_shapes: Any = None,
    padding_spec: Optional[PaddingSpec] = None,
    xla_options: Optional[XLAOptions] = None,
) -> List[List[core_types.Tensor]]:
  """Builds graph operators that run compilation and replicated computation.

  This is a lower-level interface than `replicate` that returns separate
  compile and execute outputs. In the generated graph the compile op feeds
  into the execute op and no additional compilation is incurred when running
  the compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

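  For example, a minimal sketch of running the compile op before the execute
  op (reusing `computation`, `x` and `y` from the `replicate` example; the
  session setup and worker address are assumptions, not part of this module):

  ```python
  compile_op, replica_outputs = tpu.split_compile_and_replicate(
      computation, inputs=[[x], [y]])
  with tf.compat.v1.Session("grpc://<tpu-worker-address>") as sess:
    sess.run(tf.compat.v1.tpu.initialize_system())
    sess.run(compile_op)       # surfaces compilation problems early
    sess.run(replica_outputs)  # runs the replicated computation
  ```
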
  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list
      of scalar tensors rather than a single Rank-N tensor. If you need
      different behavior, convert parts of inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g.
      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
      object) will be padded to the maximum size of that dimension over all
      replicas. The structure of `maximum_shapes` needs to be the same as
      `inputs[0]`.
    padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.

  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  del name
  inputs = [[]] if inputs is None else inputs
  xla_options = xla_options or XLAOptions()

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    metadata_kwargs["num_cores_per_replica"] = (
        device_assignment.num_cores_per_replica)

  # This entry is used for enabling automatic outside compilation.
  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
  if config.get_soft_device_placement():
    logging.info("Automatic outside compilation is enabled. "
                 "Ops without XLA kernels will be automatically "
                 "placed on CPU.")

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Checks all replicas have the same structure.
  for i in xrange(1, num_replicas):
    nest.assert_same_structure(inputs[0], inputs[i])

  # Flatten inputs. This structure may contain None values, which will be
  # handled later.
  flat_inputs_with_nones = [
      nest.flatten(per_replica_input, expand_composites=True)
      for per_replica_input in inputs
  ]
  # Mask parallel to one replica's inputs with True for tensors coming from
  # composites.
  is_composite = nest.flatten(nest.map_structure(
      lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))

  # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
  # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
  flat_inputs = []
  for inp in flat_inputs_with_nones:
    flat_inputs.append([
        constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
        for x in inp
    ])

  # Verifies that all replicas have matching numbers and types of inputs
  flat_input_types = [x.dtype for x in flat_inputs[0]]
  input_arity = len(inputs[0])
  flat_input_arity = len(flat_input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in flat_inputs[i]]
    if types != flat_input_types:
      raise ValueError("Replicas must have matching input types. Replica 0 had "
                       "input types {}, replica {} had input types {}".format(
                           flat_input_types, i, types))

  arg_error = xla.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  dynamic_shape_inputs = False
  if maximum_shapes:
    if infeed_queue:
      raise ValueError(
          "Dynamic input shapes are not supported with infeed queues")

    # Make sure maximum_shapes has the same structure as inputs.
    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
1340
1341    # Flatten padded shapes:
1342    # For composite tensor components, we don't want to pad them. For each
1343    # entry of maximum_shapes that corresponds to a composite tensor, replace it
1344    # by a tuple of Nones of the same length as the number of components of the
1345    # composite tensor. When we flatten a second time, this makes
1346    # flat_maximum_shapes have the same length as flat_inputs[i]. We can then
1347    # avoid padding these tensors. The assumption is that they will be used by
1348    # outside compilation or that the components are statically shaped and will
1349    # be used by tpu compatible ops.
1350    flat_maximum_shapes = nest.flatten(
1351        [_flatten_and_filter_composite(x, y)
1352         for x, y in zip(nest.flatten(inputs[0]),
1353                         nest.flatten(maximum_shapes))])
1354    flat_maximum_shapes = [
1355        tensor_shape.TensorShape(s) if s is not None else None
1356        for s in flat_maximum_shapes
1357    ]
1358    nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
1359                               check_types=False)
1360
1361    unpadded_inputs = flat_inputs
1362    flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
1363                                               flat_maximum_shapes,
1364                                               padding_spec)
1365    if padding_maps:
1366      dynamic_shape_inputs = True
1367      logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0])
1368
1369  metadata_kwargs["step_marker_location"] = getattr(
1370      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
1371  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
1372      xla_options.use_spmd_for_xla_partitioning
1373
1374  graph = ops.get_default_graph()
1375
1376  # Fan-in: Builds a TPUReplicatedInput node for each input.
1377  flat_replicated_inputs = []
1378  for i in range(0, len(flat_inputs[0])):
1379    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
1380    flat_replicated_inputs.append(
1381        tpu_ops.tpu_replicated_input(
1382            replicas, name="input{}".format(i), index=i))
1383  if isinstance(graph, func_graph.FuncGraph):
    # When we are in a TensorFlow 2.0 function, 'graph' will be a FuncGraph
    # object. If both the outer graph and this function have a TPU cluster,
    # they will have the same cluster name, which causes problems (because
    # we lower functional ops in TensorFlow 2.0). Append the function name to
    # 'cluster_name' to avoid cluster name collisions.
1389    cluster_name = graph.unique_name("cluster_" + graph.name)
1390  else:
1391    cluster_name = graph.unique_name("cluster")
1392  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
1393  pivot._set_attr(_PIVOT_FOR_CLUSTER,  # pylint: disable=protected-access
1394                  attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
1395  context = TPUReplicateContext(
1396      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
1397  try:
1398    context.Enter()
1399
1400    metadata = tpu_ops.tpu_replicate_metadata(
1401        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
1402
1403    with tpu_function.tpu_shard_context(
1404        num_replicas), ops.control_dependencies([metadata]):
1405
1406      if dynamic_shape_inputs:
1407        for padding_map in padding_maps:
1408          input_shape = flat_replicated_inputs[padding_map.arg_index].shape
1409          flat_replicated_inputs[
1410              padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
1411                  flat_replicated_inputs[padding_map.arg_index],
1412                  padding_map.shape_index,
1413                  flat_replicated_inputs[padding_map.padding_arg_index])
1414          flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)
1415
1416      # Add identity ops so even unused inputs are "consumed" by the
1417      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
1418      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
1419      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
1420      flat_replicated_inputs = [
1421          array_ops.identity(x, name="replicated_input_{}".format(i))
1422          for i, x in enumerate(flat_replicated_inputs)
1423      ]
1424      for i, composite in zip(flat_replicated_inputs, is_composite):
1425        # pylint: disable=protected-access
        # Add an attribute to the identity node so that it can be removed in
        # the encapsulate TPU computation pass if unused. However, we don't
        # remove inputs when dynamic padding is enabled.
1429        # TODO(rxsang): Use other ways except argument index in padding_map so
1430        # outside compilation can work with dynamic padding correctly.
1431        if not dynamic_shape_inputs or composite:
1432          i.op._set_attr("_tpu_input_identity",
1433                         attr_value_pb2.AttrValue(b=True))
1434        # pylint: enable=protected-access
1435
1436      # Clobber replicated placeholders with Nones.
1437      computation_inputs = [
1438          None if inp is None else replicated for replicated, inp in zip(
1439              flat_replicated_inputs, flat_inputs_with_nones[0])
1440      ]
1441
1442      # Unflatten the computation inputs to match original input structure.
1443      computation_inputs = nest.pack_sequence_as(
1444          structure=inputs[0],
1445          flat_sequence=computation_inputs[:flat_input_arity],
1446          expand_composites=True)
1447
1448      # If there is an infeed queue, adds the dequeued values to the
1449      # computation's inputs.
1450      if infeed_queue is not None:
1451        infeed_queue.set_number_of_shards(num_replicas)
1452        for t in infeed_queue.generate_dequeue_op():
1453          computation_inputs.append(t)
1454
1455      # Only resource variables work inside a TPU computation, so turn on
1456      # resource variables for the computation.
1457      # TODO(phawkins): consider removing this code. It will
1458      # be less confusing to clients if they knowingly choose to use resource
1459      # variables.
      # Partitioned variables are not supported (b/112311320).
1461      vscope = variable_scope.get_variable_scope()
1462      saved_use_resource = vscope.use_resource
1463      saved_custom_getter = vscope.custom_getter
1464
1465      def custom_getter(getter, name, *args, **kwargs):
1466        """Variables on TPU have a few restrictions."""
1467        partitioner = kwargs.get("partitioner", None)
1468        if partitioner is not None:
1469          kwargs["partitioner"] = None
1470          logging.warning(
1471              "Partitioned variables are not supported on TPU. Got "
1472              "`partitioner` that is %s for variable %s. "
1473              "Setting `partitioner` to `None`.", partitioner, name)
1474        if saved_custom_getter is None:
1475          return getter(name, *args, **kwargs)
1476        else:
1477          return saved_custom_getter(getter, name, *args, **kwargs)
1478
1479      vscope.set_use_resource(True)
1480      vscope.set_custom_getter(custom_getter)
1481
1482      outputs = computation(*computation_inputs)
1483
1484      vscope.set_use_resource(saved_use_resource)
1485      vscope.set_custom_getter(saved_custom_getter)
1486
1487    outputs_is_flat = xla.is_flat(outputs)
1488    if outputs_is_flat:
1489      output_tensors, control_deps, pack_template = _postprocess_flat_outputs(
1490          outputs)
1491    else:
1492      output_tensors, control_deps, pack_template = (
1493          _postprocess_non_flat_outputs(outputs))
1494
    # tensor_tracer imports tpu.py, so import it locally here to avoid an
    # import cycle.
1497    if typing.TYPE_CHECKING:
1498      tensor_tracer = Any
1499    else:
1500      # pylint: disable=g-import-not-at-top
1501      from tensorflow.python.tpu import tensor_tracer
1502      # pylint: enable=g-import-not-at-top
1503    if tensor_tracer.TensorTracer.is_enabled():
1504      tt = tensor_tracer.TensorTracer()
1505      output_tensors = tt.trace_tpu(ops.get_default_graph(),
1506                                    output_tensors, control_deps,
1507                                    num_replicas)
1508
1509    context.ExitResult(output_tensors)
1510  finally:
1511    context.report_unsupported_operations()
1512    context.Exit()
1513    host_compute_core = context.HostComputeCore()
1514
1515  if host_compute_core:
1516    attr_value = attr_value_pb2.AttrValue()
1517    attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
1518    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
1519
1520  with ops.control_dependencies([metadata]):
1521    if use_tpu:
1522      compile_status = tpu_ops.tpu_compilation_result()
1523      op = compile_status.op
1524      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
1525      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
1526    else:
1527      compile_status = control_flow_ops.no_op(name="compilation_status")
1528
1529  if not output_tensors:
    # Returns the compile op and a list of NoOps dependent on the replication
    # Op, indexed by [replica_num].
1532    return [
1533        compile_status,
1534        [
1535            control_flow_ops.group(control_deps, name="shard_%d" % i)
1536            for i in range(num_replicas)
1537        ]
1538    ]
1539
1540  # Fan-out: Builds a TPUReplicatedOutput node for each output.
1541  replicated_outputs = [[] for i in range(num_replicas)]
1542  for i, t in enumerate(output_tensors):
1543
    # None values returned by the computation can't be sent to
    # tpu_ops.tpu_replicated_output(), so we handle them specially here. We
    # can avoid the placeholder 0 routine required on the inputs since outputs
    # are replicated per-tensor, not per-replica, so we can simply skip
    # replication for None values.
1548    if t is None:
1549      for replica in range(num_replicas):
1550        replicated_outputs[replica].append(None)
1551      continue
1552
    # Builds a TPUReplicatedOutput node for this output.
1554    ys = tpu_ops.tpu_replicated_output(
1555        t, num_replicas, name="output{}".format(i))
1556
1557    # Wraps the outputs in identity operators so the names of any possible
1558    # `fetch` nodes are preserved by the replication rewrite.
1559    with ops.control_dependencies(control_deps):
1560      for replica in range(num_replicas):
1561        replicated_outputs[replica].append(
1562            array_ops.identity(
1563                ys[replica], name="output_%d_shard_%d" % (i, replica)))
1564
1565  replicated_outputs = [
1566      nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True)
1567      for replica_outs in replicated_outputs
1568  ]
1569
1570  return [compile_status, replicated_outputs]
1571
1572
1573def _postprocess_flat_outputs(
1574    outputs: Any
1575) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
1576  """Validates non-flat outputs, add backs device assignments and other attrs.
1577
1578  Args:
1579    outputs: Output from `computation` inside `tpu.rewrite`.
1580
1581  Returns:
1582    - Tensors extracted from outputs.
1583    - Operations extracted from outputs.
1584    - A pack template for use with nest.pack_sequence_as to pack the tensors.
1585  """
  # The following code segment preserves legacy behavior: previously we only
  # supported flat outputs, and for consistency it was convenient to convert
  # even a single element into a tuple. Now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
1593  if outputs is None:
1594    outputs = tuple()
1595
1596  # For legacy / backwards compatibility reasons we return a list for "flat"
1597  # output values (even if the user's flat return value was a different type or
1598  # even just a scalar value) so use nest.flatten to compute a flat list pack
1599  # template.
1600  pack_template = nest.flatten(outputs, expand_composites=False)
1601
1602  # Even though outputs is already "flat", we flatten any composites so their
1603  # component tensors can be tagged and replicated. The pack_template will be
1604  # used by the caller to repack the composite tensors.
1605  outputs = nest.flatten(outputs, expand_composites=True)
1606
  # Append a `no_op` here so that fetching any return value of this function
  # will trigger the TPUExecute node.
1609  outputs += (control_flow_ops.no_op(),)
1610
1611  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
1612  try:
1613    with ops.device(core(0)):
1614      outputs = [
1615          o if isinstance(o, ops.Operation) else maybe_convert(o)
1616          for o in outputs
1617      ]
1618  except Exception as e:
1619    raise ValueError(
1620        "TPU function return values must all either be Operations or "
1621        "convertible to Tensors. Got '%s'" % str(e))
1622
1623  # Separates the returned Operations and Tensors.
1624  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
1625  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
1626
1627  if outputs != output_tensors + output_operations:
1628    raise ValueError(
1629        "TPU functions must return zero-or more Tensor values followed by "
1630        "zero or more Operations.")
1631
1632  # Trim operations off the end of the pack template. output_operations has 1
1633  # extra element due to the no-op that is added.
1634  if len(output_operations) > 1:
1635    pack_template = pack_template[:1 - len(output_operations)]
1636
1637  # Wraps outputs in Identity ops. Otherwise a replicated input copied
1638  # straight to an output would bypass the replicate(). This would be bad
1639  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1640  # be rewritten away, leading to a runtime error.
1641  # TODO(phawkins): extend the rewrite to elide these nodes instead.
1642  new_output_tensors = []
1643  for t in output_tensors:
    if t is None:
      new_output_tensors.append(None)
      continue
1646    with ops.device(t.device if t.device else core(0)):
1647      o = array_ops.identity(t)
1648      # pylint: disable=protected-access
1649      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1650      # pylint: enable=protected-access
1651      new_output_tensors.append(o)
1652  return new_output_tensors, output_operations, pack_template
1653
1654
1655def _postprocess_non_flat_outputs(
1656    outputs: Any
1657) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
1658  """Validates non-flat outputs, add backs device assignments and other attrs.
1659
1660  Args:
1661    outputs: Output from `computation` inside `tpu.rewrite`.
1662
1663  Returns:
1664    - Tensors extracted from outputs.
1665    - An empty Operations list because Operations are not allowed in non-flat
1666      outputs.
1667    - A pack template for use with nest.pack_sequence_as to pack the tensors.
1668  """
1669
1670  # Flatten output items.
1671  flat_outputs = nest.flatten(outputs, expand_composites=True)
1672
1673  # Convert all non-None non-Operation outputs to Tensors.
1674  for i, o in enumerate(flat_outputs):
1675    if o is None:
1676      flat_outputs[i] = None
1677      continue
1678
1679    if isinstance(o, ops.Operation):
1680      raise ValueError(
1681          "tpu.rewrite does not support Operation as return value in non-flat "
1682          "output structure. You can set returned Operations as control "
1683          "dependencies of returned Tensors so Operations are triggered when "
1684          'Tensors are evaluated. Operation found: "%s"' % o.name)
1685
1686    try:
1687      o = ops.convert_to_tensor(o)
1688    except Exception as e:
1689      raise ValueError(
1690          "TPU function return values must all either be Operations or "
1691          'convertible to Tensors. Got error: "%s"' % str(e))
1692
1693    # Wraps outputs in Identity ops. Otherwise a replicated input copied
1694    # straight to an output would bypass the replicate(). This would be bad
1695    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1696    # be rewritten away, leading to a runtime error.
1697    # TODO(phawkins): extend the rewrite to elide these nodes instead.
1698    with ops.device(o.device if o.device else core(0)):
1699      o = array_ops.identity(o)
1700      # pylint: disable=protected-access
1701      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1702      # pylint: enable=protected-access
1703      flat_outputs[i] = array_ops.identity(o)
1704
  # At this point flat_outputs contains only Tensors (or None), no Operations.
1706  return flat_outputs, [], outputs
1707
1708
1709def split_compile_and_shard(
1710    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
1712    num_shards: int = 1,
1713    input_shard_axes: Optional[List[int]] = None,
1714    outputs_from_all_shards: Union[bool, List[bool]] = True,
1715    output_shard_axes: Optional[List[int]] = None,
1716    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1717    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1718    name: Optional[Text] = None,
1719    xla_options: Optional[XLAOptions] = None,
1720    ) -> Tuple[ops.Operation, List[core_types.Tensor]]:
1721  """Shards `computation` for parallel execution.
1722
1723  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1724  of which has a corresponding split axis (from `input_shard_axes`). Each input
1725  is split into `num_shards` pieces along the corresponding axis, and
1726  computation is applied to each shard in parallel.
1727
1728  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. For example:
1730
1731  x = tf.constant(7)
1732  def computation():
1733    return x + 3
1734  ... = shard(computation, ...)
1735
1736  If `outputs_from_all_shards` is true, the outputs from all shards of
1737  `computation` are concatenated back together along their `output_shard_axes`.
1738  Otherwise, each output is taken from an arbitrary shard.
1739
1740  Inputs and outputs of the computation must be at least rank-1 Tensors.
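
  For example, a minimal sketch that shards a batch of 8 across 2 shards
  (`model_fn` is a placeholder for user code):

    def model_fn(x):
      # Each shard receives a [4, 16] slice of `x`.
      return tf.reduce_sum(x, axis=1, keepdims=True)

    x = tf.ones([8, 16])
    compile_op, outputs = split_compile_and_shard(
        model_fn, inputs=[x], num_shards=2, input_shard_axes=[0])
    # outputs[0] has shape [8, 1]: the per-shard [4, 1] results concatenated
    # along `output_shard_axes` (dimension 0 by default).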
1741
1742  Args:
1743    computation: A Python function that builds a computation to apply to each
1744      shard of the input.
    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axis, given by `input_shard_axes`;
      the size of each tensor along that axis must be divisible by
      `num_shards`.
1748    num_shards: The number of shards.
1749    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1750      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1751      there must be one dimension per input.
1752    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1753      `True`, outputs from all shards are concatenated along the corresponding
1754      `output_shard_axes` entry. Otherwise, each output is taken
1755      from an arbitrary shard. If the argument is a boolean, the argument's
1756      value is used for each output.
1757    output_shard_axes: A list of dimensions along which to concatenate the
1758      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1759      along dimension 0". If not `None`, there must be one dimension per output.
1760      Ignored if `outputs_from_all_shards` is False.
1761    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1762      of `computation`.
1763    device_assignment: If not `None`, a `DeviceAssignment` describing the
1764      mapping between logical cores in the computation with physical cores in
1765      the TPU topology. Uses a default device assignment if `None`. The
1766      `DeviceAssignment` may be omitted if each shard of the computation uses
1767      only one core, and there is either only one shard, or the number of shards
1768      is equal to the number of cores in the TPU system.
1769    name: (Deprecated) Does nothing.
1770    xla_options: An instance of `tpu.XLAOptions` which indicates the options
1771      passed to XLA compiler. Use `None` for default options.
1772  Returns:
1773    A tuple of (compile op, [output tensors]).
1774  Raises:
1775    ValueError: If num_shards <= 0
1776    ValueError: If len(input_shard_axes) != len(inputs)
1777    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1778  """
1779  # TODO(phawkins): consider adding support for broadcasting Tensors passed as
1780  # inputs.
1781
1782  if num_shards <= 0:
1783    raise ValueError("num_shards must be a positive integer.")
1784
1785  inputs = [] if inputs is None else inputs
1786  if not isinstance(inputs, list):
1787    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
1788
1789  # Converts inputs to Tensors.
1790  inputs = [ops.convert_to_tensor(x) for x in inputs]
1791
1792  if input_shard_axes is None:
1793    input_shard_axes = [0] * len(inputs)
1794  if len(inputs) != len(input_shard_axes):
1795    raise ValueError("Length of input_shard_axes must be equal to the number "
1796                     "of inputs.")
1797
1798  if inputs:
1799    # Splits the `inputs` along the corresponding `input_shard_axes`, giving
1800    # lists with layout [input][shard]
1801    split_inputs = [
1802        array_ops.split(x, num_shards, axis=axis)
1803        for (axis, x) in zip(input_shard_axes, inputs)]
1804
1805    # Transposes the input lists to have layout [shard][input]
1806    transposed_inputs = [list(i) for i in zip(*split_inputs)]
1807  else:
1808    transposed_inputs = [[]] * num_shards
1809
1810  compile_op, outputs = split_compile_and_replicate(
1811      computation,
1812      transposed_inputs,
1813      infeed_queue=infeed_queue,
1814      device_assignment=device_assignment,
1815      name=name,
1816      xla_options=xla_options)
1817
1818  # There must be at least one shard since num_shards > 0.
1819  # TODO(b/36647078) remove disable when pylint bug is fixed.
1820  # pylint: disable=indexing-exception
1821  if isinstance(outputs[0], ops.Operation):
1822    # pylint: enable=indexing-exception
1823    # There were no outputs from the computation and replicate returned a list
1824    # of NoOps with control dependencies on the computation. Return the first
1825    # one so it can be used as a control dependency or fetch node.
1826    # TODO(b/36647078) remove disable when pylint bug is fixed.
1827    # pylint: disable=indexing-exception
1828    return compile_op, [outputs[0]]
1829    # pylint: enable=indexing-exception
1830
1831  # TODO(b/36647078) remove disable when pylint bug is fixed.
1832  # pylint: disable=indexing-exception
1833  num_outputs = len(outputs[0])
1834  # pylint: enable=indexing-exception
1835
1836  if output_shard_axes is None:
1837    output_shard_axes = [0] * num_outputs
1838  if num_outputs != len(output_shard_axes):
1839    raise ValueError("Length of output_shard_axes must be equal to the number "
1840                     "of outputs.")
1841
1842  if isinstance(outputs_from_all_shards, bool):
1843    outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
1844
1845  if num_outputs != len(outputs_from_all_shards):
1846    raise ValueError("Length of outputs_from_all_shards must be equal to the "
1847                     "number of outputs.")
1848
1849  results = []
1850  for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
1851                                   zip(*outputs)):
1852    if all_shards:
1853      # Concatenate all of the outputs together (use stack for scalars).
1854      shape = x[0].shape
1855      is_scalar = shape is not None and (shape.ndims == 0)
1856      results.append((array_ops.stack(list(x)) if is_scalar
1857                      else array_ops.concat(list(x), axis=axis)))
1858    else:
1859      # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
1860      results.append(x[0])
1861
1862  return compile_op, results
1863
1864
1865@tf_export(v1=["tpu.shard"])
1866def shard(
1867    computation: Callable[..., Any],
1868    inputs: Optional[List[core_types.Tensor]] = None,
1869    num_shards: int = 1,
1870    input_shard_axes: Optional[List[int]] = None,
1871    outputs_from_all_shards: Union[bool, List[bool]] = True,
1872    output_shard_axes: Optional[List[int]] = None,
1873    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1874    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1875    name: Optional[Text] = None,
1876    xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
1877  """Shards `computation` for parallel execution.
1878
1879  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1880  of which has a corresponding split axis (from `input_shard_axes`). Each input
1881  is split into `num_shards` pieces along the corresponding axis, and
1882  computation is applied to each shard in parallel.
1883
1884  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. For example:
1886
1887  x = tf.constant(7)
1888  def computation():
1889    return x + 3
1890  ... = shard(computation, ...)
1891
1892  TODO(phawkins): consider adding support for broadcasting Tensors passed
1893  as inputs.
1894
1895  If `outputs_from_all_shards` is true, the outputs from all shards of
1896  `computation` are concatenated back together along their `output_shard_axes`.
1897  Otherwise, each output is taken from an arbitrary shard.
1898
1899  Inputs and outputs of the computation must be at least rank-1 Tensors.
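
  For example, a minimal sketch of sharding a batch of 8 across 2 shards
  (`model_fn` is a placeholder for user code):

    def model_fn(x):
      return x + 1.0

    x = tf.ones([8, 16])
    # Each of the 2 shards receives a [4, 16] slice; the results are
    # concatenated back into a single [8, 16] tensor.
    y = tf.compat.v1.tpu.shard(model_fn, inputs=[x], num_shards=2)[0]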
1900
1901  Args:
1902    computation: A Python function that builds a computation to apply to each
1903      shard of the input.
    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axis, given by `input_shard_axes`;
      the size of each tensor along that axis must be divisible by
      `num_shards`.
1907    num_shards: The number of shards.
1908    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1909      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1910      there must be one dimension per input.
1911    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1912      `True`, outputs from all shards are concatenated along the corresponding
1913      `output_shard_axes` entry. Otherwise, each output is taken
1914      from an arbitrary shard. If the argument is a boolean, the argument's
1915      value is used for each output.
1916    output_shard_axes: A list of dimensions along which to concatenate the
1917      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1918      along dimension 0". If not `None`, there must be one dimension per output.
1919      Ignored if `outputs_from_all_shards` is False.
1920    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1921      of `computation`.
1922    device_assignment: If not `None`, a `DeviceAssignment` describing the
1923      mapping between logical cores in the computation with physical cores in
1924      the TPU topology. Uses a default device assignment if `None`. The
1925      `DeviceAssignment` may be omitted if each shard of the computation uses
1926      only one core, and there is either only one shard, or the number of shards
1927      is equal to the number of cores in the TPU system.
1928    name: (Deprecated) Does nothing.
1929    xla_options: An instance of `tpu.XLAOptions` which indicates the options
1930      passed to XLA compiler. Use `None` for default options.
1931  Returns:
1932    A list of output tensors.
1933  Raises:
1934    ValueError: If num_shards <= 0
1935    ValueError: If len(input_shard_axes) != len(inputs)
1936    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1937  """
1938  return split_compile_and_shard(
1939      computation,
1940      inputs=inputs,
1941      num_shards=num_shards,
1942      input_shard_axes=input_shard_axes,
1943      outputs_from_all_shards=outputs_from_all_shards,
1944      output_shard_axes=output_shard_axes,
1945      infeed_queue=infeed_queue,
1946      device_assignment=device_assignment,
1947      name=name,
1948      xla_options=xla_options)[1]
1949
1950
1951@tf_export(v1=["tpu.batch_parallel"])
1952def batch_parallel(
1953    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
1955    num_shards: int = 1,
1956    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1957    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1958    name: Optional[Text] = None,
1959    xla_options: Optional[XLAOptions] = None):
1960  """Shards `computation` along the batch dimension for parallel execution.
1961
1962  Convenience wrapper around shard().
1963
1964  `inputs` must be a list of Tensors or None (equivalent to an empty list).
1965  Each input is split into `num_shards` pieces along the 0-th dimension, and
1966  computation is applied to each shard in parallel.
1967
1968  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. For example:
1970
1971  x = tf.constant(7)
1972  def computation():
1973    return x + 3
1974  ... = shard(computation, ...)
1975
1976  The outputs from all shards are concatenated back together along their 0-th
1977  dimension.
1978
1979  Inputs and outputs of the computation must be at least rank-1 Tensors.
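
  For example, a minimal sketch (`model_fn` is a placeholder for user code):

    def model_fn(x):
      return 2.0 * x

    x = tf.ones([8, 16])
    # Equivalent to shard(model_fn, inputs=[x], num_shards=2): the split and
    # the concatenation both happen along dimension 0.
    y = tf.compat.v1.tpu.batch_parallel(model_fn, inputs=[x], num_shards=2)[0]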
1980
1981  Args:
1982    computation: A Python function that builds a computation to apply to each
1983      shard of the input.
1984    inputs: A list of input tensors or None (equivalent to an empty list). The
1985      0-th dimension of each Tensor must have size divisible by `num_shards`.
1986    num_shards: The number of shards.
1987    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1988      of arguments as inputs to `computation`.
1989    device_assignment: If not `None`, a `DeviceAssignment` describing the
1990      mapping between logical cores in the computation with physical cores in
1991      the TPU topology. Uses a default device assignment if `None`. The
1992      `DeviceAssignment` may be omitted if each shard of the computation uses
1993      only one core, and there is either only one shard, or the number of shards
1994      is equal to the number of cores in the TPU system.
1995    name: (Deprecated) Does nothing.
1996    xla_options: An instance of `tpu.XLAOptions` which indicates the options
1997      passed to XLA compiler. Use `None` for default options.
1998  Returns:
1999    A list of output tensors.
2000  Raises:
2001    ValueError: If `num_shards <= 0`
2002  """
2003  return shard(
2004      computation,
2005      inputs,
2006      num_shards=num_shards,
2007      infeed_queue=infeed_queue,
2008      device_assignment=device_assignment,
2009      name=name,
2010      xla_options=xla_options)
2011
2012
2013@tf_export(v1=["tpu.rewrite"])
2014def rewrite(
2015    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
2017    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
2018    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
2019    name: Optional[Text] = None,
2020    xla_options: Optional[XLAOptions] = None) -> Any:
2021  """Rewrites `computation` for execution on a TPU system.
2022
2023  Args:
2024    computation: A Python function that builds a computation to apply to the
2025      input. If the function takes n inputs, 'inputs' should be a list of n
2026      tensors.
2027
2028      `computation` may return a list of operations and tensors. Tensors must
2029      come before operations in the returned list.  The return value of
2030      `rewrite` is a list of tensors corresponding to the tensors from the
2031      output of `computation`.
2032
      All `Operation`s constructed during `computation` (not just the ones
      returned) will be executed when evaluating any of the returned output
      tensors.
2035    inputs: A list of input tensors or `None` (equivalent to an empty list).
2036      Each input can be a nested structure containing values that are
2037      convertible to tensors. Note that passing an N-dimension list of
      compatible values will result in an N-dimension list of scalar tensors
      rather than a single rank-N tensor. If you need different behavior,
      convert part of inputs to tensors with `tf.convert_to_tensor`.
2041    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
2042      of arguments as inputs to `computation`.
2043    device_assignment: if not `None`, a `DeviceAssignment` describing the
2044      mapping between logical cores in the computation with physical cores in
2045      the TPU topology. May be omitted for a single-core computation, in which
2046      case the core attached to task 0, TPU device 0 is used.
2047    name: (Deprecated) Does nothing.
2048    xla_options: An instance of `tpu.XLAOptions` which indicates the options
2049      passed to XLA compiler. Use `None` for default options.
2050  Returns:
2051    Same data structure as if computation(*inputs) is called directly with some
2052    exceptions for correctness. Exceptions include:
2053      1) None output: a NoOp would be returned which control-depends on
2054         computation.
2055      2) Single value output: A tuple containing the value would be returned.
2056      3) Operation-only outputs: a NoOp would be returned which
2057         control-depends on computation.
2058      TODO(b/121383831): Investigate into removing these special cases.
2059  """
2060  # TODO(b/36647078) remove disable when pylint bug is fixed.
2061  # pylint: disable=indexing-exception
2062  return replicate(
2063      computation,
2064      None if inputs is None else [inputs],
2065      infeed_queue=infeed_queue,
2066      device_assignment=device_assignment,
2067      name=name,
2068      xla_options=xla_options)[0]
2069  # pylint: enable=indexing-exception


# Operations that indicate some error in the user's inference graph.
_DENYLISTED_INFERENCE_OPS = set([
2075    "ReadVariableOp",
2076    "AssignVariableOp",
2077    "AssignAddVariableOp",
2078    "AssignSubVariableOp",
2079    "VarHandleOp",
2080    "Variable",
2081    "VariableV2",
2082])
2083
2084
2085def under_tpu_inference_context() -> bool:
2086  """Check if it is currently under `_TPUInferenceContext`."""
2087  graph = ops.get_default_graph()
2088  while graph:
2089    context = graph._get_control_flow_context()  # pylint: disable=protected-access
2090    while context:
2091      if isinstance(context, _TPUInferenceContext):
2092        return True
2093      context = context.outer_context
2094    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
2095      graph = graph._outer_graph  # pylint: disable=protected-access
2096    elif isinstance(graph, func_graph.FuncGraph):
2097      graph = graph.outer_graph
2098    else:
2099      return False
2100
2101
2102class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
2103  """A `ControlFlowContext` for nodes inside a TPU inference computation.
2104
2105  The primary role of `_TPUInferenceContext` is to indicate the mode of
2106  operation and possibly sanity check operators inside a
2107  tpu.rewrite_for_inference() computation.
2108  """
2109
2110  def __init__(self, name: Text, check_ops: bool = True):
2111    super(_TPUInferenceContext, self).__init__()
2112    self._name = name
2113    self._check_ops = check_ops
2114
2115  def AddOp(self, op):
2116    self._AddOpInternal(op)
2117
2118  def _AddOpInternal(self, op):
2119    # pylint: disable=protected-access
2120    if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS:
2121      raise NotImplementedError(
2122          "Operation of type %s (%s) is not supported on the TPU for inference."
2123          " Execution will fail if this op is used in the graph. Make sure your"
2124          " variables are using variable_scope." % (op.type, op.name))
2125    if self._outer_context:
2126      self._outer_context.AddInnerOp(op)
2127
2128  def AddValue(self, val):
2129    result = val
2130    if self._outer_context:
2131      result = self._outer_context.AddValue(val)
2132    return result
2133
2134  def AddInnerOp(self, op):
2135    self._AddOpInternal(op)
2136
2137  @property
2138  def grad_state(self):
2139    return None
2140
2141
2142def validate_inference_rewrite_for_variables(graph: ops.Graph):
2143  """Validates whether rewrite_for_inference() 'worked' for variables.
2144
2145     The rewrite_for_inference() method is supposed to append GuaranteeConstOps
2146     after ReadVariableOps, but this mechanism works only if you are using
2147     tf.compat.v1.get_variable() to create and access variables in your tpu
2148     computation. This validation method can be called immediately after calling
     tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
2150     to the graph.
2151
2152     Typical usages:
2153       tpu.validate_inference_rewrite_for_variables(
2154           tf.compat.v1.get_default_graph())
2155
2156       tpu.validate_inference_rewrite_for_variables(sess.graph)
2157
2158  Args:
2159    graph: The graph which needs to be validated.
2160  Raises:
2161    RuntimeError: if validation failed.
2162  """
2163  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
2164    raise RuntimeError(
2165        "No GuaranteeConst ops found in the graph after running "
2166        "tpu.rewrite_for_inference(...). Please check that you are using "
2167        "tf.get_variable() to create and access variables in your tpu "
2168        "computation.")
2169
2170
2171def rewrite_for_inference(
2172    computation: Callable[..., Any],
2173    inputs: Optional[List[core_types.Tensor]] = None,
2174    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
2175    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
2176    name: Optional[Text] = None) -> List[core_types.Tensor]:
2177  """Rewrites `computation` for inference on a TPU system.
2178
     In addition to 'rewriting' the computation to run on a TPU, if the
     computation uses variables, this moves the ReadVariableOps outside the
     TPU computation and adds GuaranteeConst ops just after the
     ReadVariableOps. This mechanism works only if you are using
     tf.compat.v1.get_variable() to create and access variables in your tpu
     computation. You can validate whether this worked by calling the
     validate_inference_rewrite_for_variables() method immediately after this
     method to check whether GuaranteeConstOps were added to the graph.
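
     For example, a minimal sketch (`inference_fn` and `features` are
     placeholder names):

       features = tf.compat.v1.placeholder(tf.float32, shape=[1, 64])
       outputs = rewrite_for_inference(inference_fn, [features])
       validate_inference_rewrite_for_variables(
           tf.compat.v1.get_default_graph())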
2187
2188  Args:
2189    computation: A Python function that builds a computation to apply to the
2190      input. If the function takes n inputs, 'inputs' should be a list of n
2191      tensors. If the function returns m outputs, rewrite will return a list of
2192      m tensors.
2193    inputs: A list of input tensors or `None` (equivalent to an empty list).
2194    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
2195      of arguments as inputs to `computation`.
2196    device_assignment: if not `None`, a `DeviceAssignment` describing the
2197      mapping between logical cores in the computation with physical cores in
2198      the TPU topology. May be omitted for a single-core computation, in which
2199      case the core attached to task 0, TPU device 0 is used.
2200    name: The name of the operator.
2201  Returns:
2202    A list of output tensors.
2203  """
2204
2205  def guarantee_const_getter(getter, name, *args, **kwargs):
2206    with ops.control_dependencies(None):
2207      return array_ops.guarantee_const(
2208          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
2209
2210  def wrapped_computation(*args, **kwargs):
2211    """Execute computation under `_TPUInferenceContext`."""
2212    context = _TPUInferenceContext(
2213        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
2214    try:
2215      context.Enter()
2216
2217      vscope = variable_scope.get_variable_scope()
2218      prev_custom_getter = vscope.custom_getter
2219      prev_caching_device = vscope.caching_device
2220      vscope.set_custom_getter(guarantee_const_getter)
2221      vscope.set_caching_device(lambda op: op.device)
2222
2223      result = computation(*args, **kwargs)
2224
2225      vscope.set_custom_getter(prev_custom_getter)
2226      vscope.set_caching_device(prev_caching_device)
2227    finally:
2228      context.Exit()
2229    return result
2230
2231  # pylint: disable=undefined-variable
2232  return rewrite(
2233      wrapped_computation,
2234      inputs=inputs,
2235      infeed_queue=infeed_queue,
2236      device_assignment=device_assignment,
2237      name=name)
2238  # pylint: enable=undefined-variable
2239
2240
2241def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
2242  """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
2243
2244  Args:
2245    prune_graph: A tensorflow graph from which we wish to prune unconnected ops
2246      as listed in _UNCONNECTED_OPS_TO_PRUNE.  In general, these ops should have
2247      no inputs and no consumers. These can often be left behind due to graph
      construction rewiring (for instance TF-Hub). While they never execute,
      they will cause XLA compilation to fail, so we exclude them from
      compilation by removing the tpu_replicate attribute.
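
  For example, a minimal usage sketch:

    prune_unconnected_ops_from_xla(tf.compat.v1.get_default_graph())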
2251  """
2252  # Scan over the top level graph and all function graphs.
2253  for graph in [prune_graph] + [
2254      f for f in prune_graph._functions.values()  # pylint: disable=protected-access
2255  ]:
2256    if not isinstance(graph, ops.Graph):
2257      continue
2258    for op in graph.get_operations():
2259      if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
2260        continue
2261      outputs_consumed = False
2262      for output in op.outputs:
2263        if output.consumers():
2264          outputs_consumed = True
2265          break
2266      if not outputs_consumed:
        logging.info(
            "Pruning op %s of type %s from XLA compilation because it is "
            "disconnected.", op.name, op.type)
2270        op._clear_attr(_TPU_REPLICATE_ATTR)  # pylint: disable=protected-access
2271