• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15
16"""Library of TPU helper functions."""
17
18import collections
19import enum
20import typing
21from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union
22
23from absl import logging
24import numpy as np
25
26from tensorflow.compiler.tf2xla.python import xla as tf2xla
27from tensorflow.core.framework import attr_value_pb2
28from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
29from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
30from tensorflow.python import tf2
31from tensorflow.python.compiler.xla import xla
32from tensorflow.python.distribute import device_util
33from tensorflow.python.distribute import distribution_strategy_context
34from tensorflow.python.framework import auto_control_deps
35from tensorflow.python.framework import c_api_util
36from tensorflow.python.framework import composite_tensor
37from tensorflow.python.framework import config
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import device as pydev
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import func_graph
43from tensorflow.python.framework import function
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import tensor_shape
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import control_flow_ops
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops import variable_scope
50from tensorflow.python.ops import variables
51from tensorflow.python.tpu import device_assignment as device_assignment_lib
52from tensorflow.python.tpu import tpu_feed
53from tensorflow.python.tpu import tpu_function
54from tensorflow.python.tpu import tpu_name_util
55from tensorflow.python.tpu.ops import tpu_ops
56from tensorflow.python.types import core as core_types
57from tensorflow.python.util import compat
58from tensorflow.python.util import nest
59from tensorflow.python.util import object_identity
60from tensorflow.python.util import traceback_utils
61from tensorflow.python.util import variable_utils
62from tensorflow.python.util.tf_export import tf_export
63
64
# No gradient is defined for TPUReplicatedInput.
ops.NotDifferentiable("TPUReplicatedInput")

# Op types whose presence inside a TPU computation indicates an error in the
# user's graph, e.g. a placeholder that was introduced outside of the infeed.
_DENYLISTED_OPS = {
    "Placeholder",
}

# XLA does not currently support reading intermediate tensors, so ops of the
# following types are not supported inside a TPU computation.
_UNSUPPORTED_OPS = {
    "AudioSummary",
    "AudioSummaryV2",
    "HistogramSummary",
    "ImageSummary",
    "MergeSummary",
    "Print",
    "ScalarSummary",
    "TensorSummary",
    "TensorSummaryV2",
}

# Op types that can be safely pruned from an XLA compilation when nothing
# consumes them. These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = {"Placeholder", "VarHandleOp"}

# Maximum number of unsupported ops listed individually in one warning.
_MAX_WARNING_LINES = 5

# Attribute / marker names used when annotating ops for TPU compilation.
_TPU_REPLICATE_ATTR = "_tpu_replicate"
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"


# Module-level alias of `tpu_name_util.core`.
core = tpu_name_util.core
101
102
103def _tpu_system_device_name(job: Optional[Text]) -> Text:
104  """Returns the device name for the TPU_SYSTEM device of `job`."""
105  if job is None:
106    return "/device:TPU_SYSTEM:0"
107  else:
108    return "/job:%s/device:TPU_SYSTEM:0" % job
109
110
@tf_export(v1=["tpu.initialize_system"])
def initialize_system(
    embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
    job: Optional[Text] = None,
    compilation_failure_closes_chips: bool = True,
    tpu_cancellation_closes_chips: Optional[bool] = None,
) -> core_types.Tensor:
  """Initializes a distributed TPU system for use with TensorFlow.

  Args:
    embedding_config: Optional `TPUEmbeddingConfiguration` proto describing the
      desired configuration of the hardware embedding lookup tables. When it
      is None, no hardware embeddings can be used.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
    compilation_failure_closes_chips: Whether TPU chips should be closed when
      a compilation failure occurs.
    tpu_cancellation_closes_chips: Whether TPU chips should be closed when a
      TPU execution is cancelled. If None, the behavior is determined by the
      command line flag `tpu_cancellation_closes_chips` of the TPU worker.
      WARNING: this argument only applies to the TFRT TPU runtime.

  Returns:
    A serialized `TopologyProto` that describes the TPU system. Note: the
    topology must be evaluated using `Session.run` before it can be used.
  """
  if embedding_config is None:
    serialized_config = ""
  else:
    serialized_config = embedding_config.SerializeToString()

  # Tri-state value for the op; the enum is defined in
  # core/tpu/kernels/tpu_execute_op_options.h.
  #   0 -> unspecified (defer to the worker's flag), 1 -> True, 2 -> False.
  if tpu_cancellation_closes_chips is None:
    cancellation_mode = 0
  elif tpu_cancellation_closes_chips:
    cancellation_mode = 1
  else:
    cancellation_mode = 2

  with ops.device(_tpu_system_device_name(job)):
    topology = tpu_ops.configure_distributed_tpu(
        compilation_failure_closes_chips=compilation_failure_closes_chips,
        tpu_cancellation_closes_chips=cancellation_mode,
    )
    if embedding_config is None:
      return topology

    # Callers expect a single tensor that yields the topology when run, but the
    # embedding initialization op must run after the TPU is configured and
    # before the topology is returned. Chain them with control dependencies:
    # topology -> embedding init -> identity(topology).
    with ops.control_dependencies([topology]):
      embedding_init = tpu_ops.configure_tpu_embedding(config=serialized_config)
    with ops.control_dependencies([embedding_init]):
      return array_ops.identity(topology, name="tpu_init_identity")
167
168
def initialize_system_for_tpu_embedding(
    embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
    job: Optional[Text] = None,
) -> ops.Operation:
  """Initializes a distributed TPU Embedding system for use with TensorFlow.

  Calling `initialize_system()` with `embedding_config` is equivalent to
  calling `initialize_system()` without `embedding_config` and then calling
  `initialize_system_for_tpu_embedding()`. Do not pass `embedding_config` to
  `initialize_system()` if this function is meant to be called afterwards.

  Args:
    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.

  Returns:
    A no-op: the operation that applies the embedding configuration.
  """
  serialized_config = embedding_config.SerializeToString()
  system_device = _tpu_system_device_name(job)
  with ops.device(system_device):
    return tpu_ops.configure_tpu_embedding(config=serialized_config)
196
197
@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
  """Shuts down a running distributed TPU system.

  Args:
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be shut down. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.

  Returns:
    The operation that shuts down the distributed TPU system, placed on the
    TPU_SYSTEM device of `job`.
  """
  system_device = _tpu_system_device_name(job)
  with ops.device(system_device):
    return tpu_ops.shutdown_distributed_tpu()
211
212
def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
  """Finds the innermost enclosing TPUReplicateContext and its graph.

  Starting at the default graph, walks that graph's control-flow context
  chain (via `outer_context`). If no `TPUReplicateContext` is found there,
  moves on to the enclosing graph (via `outer_graph`) and repeats.

  Returns:
    A `(context, graph)` tuple with the first `TPUReplicateContext` found and
    the graph whose control-flow context chain contained it.

  Raises:
    ValueError: If no enclosing `TPUReplicateContext` exists.
  """
  current_graph = ops.get_default_graph()
  while current_graph is not None:
    # pylint: disable=protected-access
    ctxt = current_graph._get_control_flow_context()
    # pylint: enable=protected-access
    while ctxt is not None and not isinstance(ctxt, TPUReplicateContext):
      ctxt = ctxt.outer_context
    if ctxt is not None:
      return ctxt, current_graph
    current_graph = getattr(current_graph, "outer_graph", None)
  raise ValueError("get_replicated_var_handle() called without "
                   "TPUReplicateContext. This shouldn't happen. Please file "
                   "a bug.")
228
229
def is_tpu_strategy(strategy: Any) -> bool:
  """Returns whether `strategy` is a TPU distribution strategy.

  A strategy counts as a TPU strategy when its class, or any class it inherits
  from directly or indirectly, has a name starting with "TPUStrategy". The
  whole MRO is checked, not just the direct bases, so user subclasses that
  sit more than one level below a TPUStrategy class are also recognized.

  Args:
    strategy: A distribution strategy instance.

  Returns:
    True if `strategy` is an instance of a TPUStrategy class or of a subclass
    of one, False otherwise.
  """
  # `__mro__` starts with the class itself and includes every direct base, so
  # this accepts everything the class-plus-direct-bases check accepted.
  return any(
      klass.__name__.startswith("TPUStrategy")
      for klass in strategy.__class__.__mro__)
234
235
def _enclosing_tpu_device_assignment(
) -> Optional[device_assignment_lib.DeviceAssignment]:
  """Returns the device assignment of the active TPU strategy, if any.

  Returns:
    `strategy.extended._device_assignment` when a distribution strategy is
    active and `is_tpu_strategy` accepts it; None otherwise, including when
    no strategy is active.
  """
  if distribution_strategy_context.has_strategy():
    strategy = distribution_strategy_context.get_strategy()
    if is_tpu_strategy(strategy):
      return strategy.extended._device_assignment  # pylint: disable=protected-access
  return None
244
245
@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(
    op: ops.Operation,
    resource_reads: object_identity.ObjectIdentitySet,
    resource_writes: object_identity.ObjectIdentitySet) -> bool:
  """Replaces TPUReplicatedInput outputs with their inputs for ACD purposes.

  Automatic control dependencies (ACD) calls this resolver so it can adjust
  the resource tensors an op is considered to read and write:

  * For a TPUReplicatedInput op itself, both sets are cleared. Control deps
    are added directly on the replicated inputs, so the op is ignored.
  * For any other op, every resource tensor produced by a TPUReplicatedInput
    op is replaced by that op's inputs, in both `resource_reads` and
    `resource_writes`. This lets ACD add correct control deps when there are
    multiple calls to `run` in a `tf.function`.

  Args:
    op: The op whose resource inputs are being resolved.
    resource_reads: Resource tensors read by `op`; updated in place.
    resource_writes: Resource tensors written by `op`; updated in place.

  Returns:
    True if either set was modified, False otherwise.
  """
  if op.type == "TPUReplicatedInput":
    if resource_reads or resource_writes:
      resource_reads.clear()
      resource_writes.clear()
      return True
    return False

  def replace_with_unreplicated_resources(resource_inputs):
    """Replaces handles in `resource_inputs` with their unreplicated inputs.

    Returns a truthy value iff a TPUReplicatedInput output was found.
    """
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "TPUReplicatedInput":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    # Mutate only after iterating so the set is not changed mid-iteration.
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  # Run both replacements unconditionally. Joining the two calls with `or`
  # would short-circuit and leave `resource_writes` un-rewritten whenever
  # `resource_reads` had already been modified.
  reads_changed = replace_with_unreplicated_resources(resource_reads)
  writes_changed = replace_with_unreplicated_resources(resource_writes)
  return bool(reads_changed or writes_changed)
280
281
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU computation.

  The primary role of `TPUReplicateContext` is to mark operators inside a
  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
  is a unique name.

  We use a `ControlFlowContext` to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a tpu.replicate() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the replicated computation.
  """

  def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
    """Builds a new TPUReplicateContext.

    Args:
      name: a unique name for the context, used to populate the `_tpu_replicate`
        attribute.
      num_replicas: an integer that gives the number of replicas for the
        computation.
      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(TPUReplicateContext, self).__init__()
    self._num_replicas = num_replicas
    # Graph device function stack captured on the first Enter(); it is
    # installed on the graph while an outside_compilation scope is active.
    self._outer_device_function_stack = None
    # Graph device function stack saved on entering an outside_compilation
    # scope and restored on exiting it.
    self._oc_dev_fn_stack = None
    # Name of the active outside_compilation cluster, or None.
    self._outside_compilation_cluster = None
    # Active TF2 outside-compilation context entered for gradient
    # colocation, or None.
    self._outside_compilation_v2_context = None
    # Used to name outside_compilation clusters that have no explicit name.
    self._outside_compilation_counter = 0
    # Op whose gradient colocation opened the current outside_compilation
    # scope, or None.
    self._in_gradient_colocation = None
    self._gradient_colocation_stack = []
    # "cluster:core_index" strings for outside_compilation clusters entered
    # while on a specific TPU_REPLICATED_CORE device.
    self._host_compute_core = []
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    # Serialized `_tpu_replicate` AttrValue, reused for every op in AddOp.
    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
    self._unsupported_ops = []
    self._pivot = pivot
    # Cache of TPUReplicatedInput handles keyed by handle_id.
    self._replicated_vars = {}

  def get_replicated_var_handle(self,
                                name: Text,
                                handle_id: Text,
                                vars_: Union[List[core_types.Tensor],
                                             List[variables.Variable]],
                                is_mirrored: bool = False,
                                is_packed: bool = False) -> core_types.Tensor:
    """Returns a variable handle for replicated TPU variable 'var'.

    This is a method used by an experimental replicated variable implementation
    and is not intended as a public API.

    Args:
      name: The common name of the variable.
      handle_id: Unique ID of the variable handle, used as the cache key.
      vars_: The replicated TPU variables or handles.
      is_mirrored: Whether the variables are mirrored, which guarantees the
        values in each replica are always the same.
      is_packed: Whether the replicated variables are packed into one variable.

    Returns:
      The handle of the TPU replicated input node.

    Raises:
      ValueError: If a device assignment is active, the variables are not
        packed, and no variable is found on any logical core of some replica.
    """
    device_assignment = _enclosing_tpu_device_assignment()
    # We don't need to put device assignment as part of the replicated_vars key
    # because each TPUReplicateContext will only have one device assignment.
    handle = self._replicated_vars.get(handle_id)
    if handle is not None:
      return handle

    if device_assignment is not None and not is_packed:
      # Find a variable copy for each replica in the device assignment.
      # Note that the order of devices for replicas for the variable and the
      # device assignment might not match.
      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
      devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_}
      replicated_vars = []
      for replica_id in range(device_assignment.num_replicas):
        for logical_core in range(device_assignment.num_cores_per_replica):
          device = device_util.canonicalize(
              device_assignment.tpu_device(
                  replica=replica_id, logical_core=logical_core, job=job_name))
          if device in devices_to_vars:
            replicated_vars.append(devices_to_vars[device])
            break
        else:
          raise ValueError(
              "Failed to find a variable on any device in replica {} for "
              "current device assignment".format(replica_id))
    else:
      replicated_vars = vars_

    # Builds a TPUReplicatedInput node for the variable, if one does not already
    # exist. The TPUReplicatedInput node must belong to the enclosing
    # control-flow scope of the TPUReplicateContext.
    # TODO(phawkins): consider changing the contract of the TPU encapsulation
    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
    # instead.

    _, graph = _enclosing_tpu_context_and_graph()
    with graph.as_default():
      # If replicated_vars are variables, get the handles. This happens before
      # switching to the outer control-flow context, because reading
      # `.handle` may create new ops.
      if isinstance(replicated_vars[0], variables.Variable):
        replicated_vars = [v.handle for v in replicated_vars]
      # pylint: disable=protected-access
      saved_context = graph._get_control_flow_context()
      graph._set_control_flow_context(self.outer_context)
      try:
        handle = tpu_ops.tpu_replicated_input(replicated_vars,
                                              name=name + "/handle",
                                              is_mirrored_variable=is_mirrored,
                                              is_packed=is_packed)
      finally:
        # Restore the caller's context even if op construction fails, so the
        # graph is not left in the outer context.
        graph._set_control_flow_context(saved_context)
      # pylint: enable=protected-access
    self._replicated_vars[handle_id] = handle
    return handle

  def report_unsupported_operations(self) -> None:
    """Logs a warning listing the unsupported ops added to this context.

    Ops whose type is in `_UNSUPPORTED_OPS` are recorded by `AddOp`. At most
    `_MAX_WARNING_LINES` of them are listed individually; the rest are
    summarized as a count.
    """
    if self._unsupported_ops:
      op_str = "\n".join("  %s (%s)" % (op.type, op.name)
                         for op in self._unsupported_ops[:_MAX_WARNING_LINES])
      logging.warning("%d unsupported operations found: \n%s",
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning("... and %d more",
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
    """Starts colocating gradient ops with `op`.

    Does nothing when `op` is None. When the default graph has no control-flow
    context (TF2 functions) and `op` has the `_xla_outside_compilation`
    attribute, an `OutsideCompilationV2Context` named
    "<root_cluster>.<gradient_uid>" is entered. Otherwise `op` is pushed onto
    the gradient colocation stack. If no outside_compilation cluster is
    active and `op` has the attribute, an outside_compilation scope for
    cluster "<root_cluster>.<gradient_uid>" is also entered. <root_cluster>
    is the first "."-separated component of `op`'s cluster name.

    Args:
      op: The forward op the gradient ops are colocated with, or None.
      gradient_uid: Unique id of the current gradient computation.

    Raises:
      NotImplementedError: For nested gradient colocation inside outside
        compilation, or when `gradient_uid` is "__unsupported__" for an
        outside-compiled `op`.
    """
    if op is not None:
      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
        # If we are in TF 2 functions (control flow V2 functions, or
        # tf.function()), we need to attach _xla_outside_compilation attribute
        # directly because we are not in TPUReplicateContext.
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
        except ValueError:
          # The attr was not present: do nothing.
          return
        parts = outside_attr.split(".")
        cluster = parts[0] + "." + gradient_uid
        self._outside_compilation_v2_context = OutsideCompilationV2Context(
            cluster)
        self._outside_compilation_v2_context.Enter()
        return
      self._gradient_colocation_stack.append(op)
      if not self._outside_compilation_cluster:
        # Only the attribute lookup is guarded: a ValueError here means the
        # attr is absent. Errors raised while entering the scope below must
        # not be mistaken for that and silently swallowed.
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
        except ValueError:
          # The attr was not present: do nothing.
          return
        if self._in_gradient_colocation:
          raise NotImplementedError(
              "Cannot nest gradient colocation operations outside compilation"
          )
        if gradient_uid == "__unsupported__":
          raise NotImplementedError(
              "No gradient_uid calling gradient within outside_compilation")
        # When we take the gradient of an op X in an outside_compilation
        # cluster C in a forward computation we would like to put the ops
        # corresponding to the gradient of X into a new outside_compilation
        # cluster C'. However, if we take the gradient of X twice, the second
        # one should get yet another new outside_compilation cluster C''.
        #
        # The mechanism we adopt is to use a 'root_cluster' which is the
        # cluster that X was in before we took gradients, and a 'gradient_uid'
        # which is different for every invocation of gradients, and put the
        # gradient of X in cluster 'root_cluster.gradient_uid'.
        #
        # When taking a gradient of a gradient, some ops will be colocated
        # with Op in the forward pass (e.g., cluster root_cluster) and some in
        # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
        # We need all of the grad-of-grad ops to be in the same cluster to
        # avoid cyclic dependencies between clusters. We adopt a heuristic
        # that puts any op clustered with root_cluster.<xxx> in
        # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
        self._in_gradient_colocation = op
        parts = outside_attr.split(".")
        cluster = parts[0] + "." + gradient_uid
        self._EnterOutsideCompilationScope(cluster=cluster)

  def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
    """Ends the colocation started by the matching EnterGradientColocation.

    Args:
      op: The op passed to the matching `EnterGradientColocation`, or None.
      gradient_uid: Unused here.

    Raises:
      errors.InternalError: If Enter/Exit calls are not properly nested.
    """
    if op is not None:
      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
        # Inside a TF2 tf.function or control flow graph and `op` was not
        # marked to be outside compiled.
        assert self._outside_compilation_v2_context is None
        return
      if self._outside_compilation_v2_context is not None:
        # Inside a TF2 tf.function or control flow graph and `op` was
        # marked to be outside compiled.
        self._outside_compilation_v2_context.Exit()
        self._outside_compilation_v2_context = None
        return
      if not self._gradient_colocation_stack:
        raise errors.InternalError(
            op.node_def, op,
            f"Badly nested gradient colocation: empty stack when popping Op {op.name}"
        )
      last_op = self._gradient_colocation_stack.pop()
      if op is last_op:
        if op is self._in_gradient_colocation:
          self._in_gradient_colocation = None
          self._ExitOutsideCompilationScope()
      else:
        # Report both ops by name so the message is consistent and readable.
        raise errors.InternalError(
            op.node_def, op,
            f"Badly nested gradient colocation, expected {last_op.name}, "
            f"got {op.name}")

  def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):
    """Enters an outside_compilation scope.

    While the scope is active, `AddOp` tags ops with the
    `_xla_outside_compilation` attribute. The graph's device function stack
    is also swapped for the one captured on the first `Enter()`.

    Args:
      cluster: Name of the outside_compilation cluster. If empty or None, a
        name is generated from an internal counter.

    Raises:
      NotImplementedError: If an outside_compilation scope is already active.
    """

    class FakeOp(object):
      """A helper class to determine the current device.

      Supports only the type and device set/get methods needed to run the
      graph's _apply_device_function method.
      """

      def __init__(self):
        self._device = ""

      @property
      def type(self):
        return "FakeOp"

      @property
      def device(self):
        return self._device

      def _set_device(self, device):
        if isinstance(device, pydev.DeviceSpec):
          self._device = device.to_string()
        else:
          self._device = device

      def _set_device_from_string(self, device_str):
        self._device = device_str

    if self._outside_compilation_cluster:
      raise NotImplementedError("Cannot nest outside_compilation clusters")
    if cluster:
      self._outside_compilation_cluster = cluster
    else:
      self._outside_compilation_cluster = str(self._outside_compilation_counter)
      self._outside_compilation_counter += 1
    graph = ops.get_default_graph()
    fake_op = FakeOp()
    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
    device = pydev.DeviceSpec.from_string(fake_op.device)
    if (device.device_type == "TPU_REPLICATED_CORE" and
        device.device_index is not None):
      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
                                     str(device.device_index))
    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access

  def _ExitOutsideCompilationScope(self):
    """Leaves the active outside_compilation scope.

    Restores the graph device function stack saved on entry.

    Raises:
      ValueError: If no outside_compilation scope is active.
    """
    if not self._outside_compilation_cluster:
      raise ValueError(
          "Attempted to exit outside_compilation scope when not in scope")
    self._outside_compilation_cluster = None
    graph = ops.get_default_graph()
    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access

  def Enter(self) -> None:
    """Enters this context.

    The first time this is called (more precisely, while the captured stack
    is falsy), a copy of the graph's device function stack is saved. That
    copy is used inside outside_compilation scopes.
    """
    if not self._outer_device_function_stack:
      # Capture the device function stack at the time of first entry
      # since that is the stack that will be used outside_compilation.
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self._outer_device_function_stack = graph._device_function_stack.copy()
      # pylint: enable=protected-access
    super(TPUReplicateContext, self).Enter()

  def HostComputeCore(self) -> List[Text]:
    """Returns the "cluster:core_index" strings recorded for host compute."""
    return self._host_compute_core

  def _RemoveExternalControlEdges(
      self, op: ops.Operation
      ) -> Tuple[List[ops.Operation], List[ops.Operation]]:
    """Remove any external control dependency on this op.

    A control input is internal when this context appears in its chain of
    control-flow contexts. Only internal control inputs are kept on `op`.

    Returns:
      A `(internal_control_inputs, external_control_inputs)` tuple.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op: ops.Operation) -> None:
    """Adds `op` to this TPU computation.

    Logs an error for denylisted op types and records unsupported op types.
    Sets the `_tpu_replicate` attribute, plus `_xla_outside_compilation`
    inside an outside_compilation scope. Control inputs from outside this
    context are replaced by identity ops created inside it. Ops with no data
    inputs and no internal control inputs get a control edge from the pivot.

    Raises:
      NotImplementedError: If any input has a ref (non-resource variable)
        dtype.
      ValueError: If `op` already has `_tpu_replicate` and is not `_cloned`.
    """
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error("Operation of type %s (%s) is not supported on the TPU. "
                    "Execution will fail if this op is used in the graph. ",
                    op.type, op.name)

    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          f"Non-resource Variables are not supported inside TPU computations "
          f"(operator name: {op.name})")

    # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add
    # the "_cloned" attribute and we should continue in that case.
    if (_TPU_REPLICATE_ATTR in op.node_def.attr and
        "_cloned" not in op.node_def.attr):
      raise ValueError(f"TPU computations cannot be nested on op ({op})")
    op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
                          self._tpu_relicate_attr_buf.buffer)
    if self._outside_compilation_cluster:
      op._set_attr(
          _OUTSIDE_COMPILATION_ATTR,
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._outside_compilation_cluster)))
    if self._num_replicas > 1 or not self._outside_compilation_cluster:
      # Prevent feeding or fetching anything that is being compiled,
      # and any replicated outside_compilation Op.
      op.graph.prevent_feeding(op)
      op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self.GetControlPivot())
        # pylint: enable=protected-access
    else:
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
    """Add `val` to the current context and its outer context recursively."""
    if not self._outer_context:
      return val

    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    self._values.add(val.name)
    # An outer context is guaranteed to exist here (see the guard above).
    result = self._outer_context.AddValue(val)
    self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op: ops.Operation):
    """Adds `op` here, then forwards it to the outer context, if any."""
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the TPUReplicateContext to
    # be None as the TPUReplicateContext does not get nested nor does the
    # grad_state outside the TPUReplicateContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self) -> ops.Operation:
    """Returns the pivot op passed to the constructor."""
    return self._pivot

  def RequiresUniqueFunctionRetracing(self):
    # More context: b/158152827. TPU stack uses the TPUReplicateContext to
    # create replicated variable handles and cluster TPU computations, thus we
    # always retrace a tf.function when the wrapped TPUReplicateContext changes.
    return True
719
720
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
  """The context for outside compilation in TensorFlow 2.

  Every op added to this context, directly or from an inner context, is
  tagged with an `_xla_outside_compilation` attribute whose value is this
  context's name.
  """

  def __init__(self, name: Text):
    super().__init__()
    # Written into each op's `_xla_outside_compilation` attribute.
    self._name = name

  def _tag_outside_compilation(self, op: ops.Operation) -> None:
    """Sets `op`'s `_xla_outside_compilation` attr to this context's name."""
    cluster_attr = attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))
    op._set_attr("_xla_outside_compilation", cluster_attr)  # pylint: disable=protected-access

  def AddOp(self, op: ops.Operation) -> None:
    """Forwards `op` to the outer context (if any), then tags it."""
    if self._outer_context:
      self._outer_context.AddOp(op)
    self._tag_outside_compilation(op)

  def AddInnerOp(self, op: ops.Operation) -> None:
    """Forwards an inner-context `op` outward (if possible), then tags it."""
    if self._outer_context:
      self._outer_context.AddInnerOp(op)
    self._tag_outside_compilation(op)

  def to_control_flow_context_def(self, context_def, export_scope=None):
    # Serialization of this context is not supported.
    raise NotImplementedError
750
751
@tf_export(v1=["tpu.outside_compilation"])
def outside_compilation(
    computation: Callable[..., Any], *args, **kwargs
    ) -> Any:
  """Builds part of a computation outside any current TPU replicate scope.

  `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
  instead of running on TPU. For example, users can run ops that are not
  supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
  ops on CPUs. The usage of outside compilation below places the ops in
  `computation_with_string_ops` on CPU.

  Example usage:

  ```python
  def computation_with_string_ops(x):
    # String types are not supported on TPUs, so the ops below must
    # run on CPU instead.
    output = tf.strings.format('1{}', x)
    return tf.strings.to_number(output)

  def tpu_computation():
    # Expected output is 11.
    output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
    return output
  ```

  Outside compilation should be called inside TPUReplicateContext. That is,
  `tf.tpu.outside_compilation()` should be called inside a function that is
  passed to `tpu.split_compile_and_replicate()` -- this is implied when
  outside compilation is invoked inside a function passed to TPUStrategy
  `run()`. If invoked outside of TPUReplicateContext,
  then this simply returns the result of `computation`, and therefore,
  would be a no-op. Note that outside compilation is different from
  `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
  outside compilation is replicated and executed separately for each
  replica. On the other hand, `merge_call()` requires a `merge_fn`
  to aggregate the inputs from different replicas and is executed only
  once.

  For variables placed on a TPU device, which includes variables created
  inside TPUStrategy scope, outside compilation logic must not include
  variable read/write. For variables placed on host, which is the case when
  variables are created via TPUEstimator, variable read/write is only allowed
  if the variable is not accessed by any other ops in the TPU computation.
  Variable read/write from an outside compilation cluster is not visible from
  the TPU computation and vice versa. Therefore, if outside compilation logic
  contains such host variable read/write ops and the variables are accessed
  by the TPU computation as well, this may lead to deadlock.

  Internally, `tf.tpu.outside_compilation()` adds outside compilation
  attributes to all ops in `computation`. During a later graph pass, the ops
  carrying the outside compilation attribute are extracted and replicated
  into a host-side graph. Inputs to this extracted host-side graph are sent
  from the TPU computation graph to the host graph via a pair of
  XlaSendToHost and XlaRecvFromHost ops. Note that using
  `tf.tpu.outside_compilation()` may result in tensor transfer between TPU
  and CPU, leading to non-trivial performance impact.

  Args:
    computation: A Python function that builds the computation to
      place on the host.
    *args: the positional arguments for the computation.
    **kwargs: the keyword arguments for the computation.

  Returns:
    The Tensors returned by computation.

  Raises:
    NotImplementedError: If the graph's control-flow context differs before
      and after calling `computation` (non-FuncGraph path).
    Any exception raised by `computation` propagates unchanged. The outside
    compilation context/scope is exited before it does, so ops created
    afterwards are not tagged for outside compilation.
  """
  graph = ops.get_default_graph()

  # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
  # we need to attach _xla_outside_compilation attribute directly because we are
  # not in TPUReplicateContext.
  if isinstance(graph, func_graph.FuncGraph):
    try:
      tpu_context, _ = _enclosing_tpu_context_and_graph()
    except ValueError:
      logging.warning(
          "Outside compilation attempted outside TPUReplicateContext "
          "scope. As no enclosing TPUReplicateContext can be found, "
          "returning the result of `computation` as is.")
      return computation(*args, **kwargs)

    # pylint: disable=protected-access
    outside_compilation_name = str(tpu_context._outside_compilation_counter)
    tpu_context._outside_compilation_counter = (
        tpu_context._outside_compilation_counter + 1)
    # pylint: enable=protected-access

    outside_compilation_context = OutsideCompilationV2Context(
        outside_compilation_name)
    outside_compilation_context.Enter()
    try:
      return computation(*args, **kwargs)
    finally:
      # Always restore the previous control-flow context. Otherwise a failing
      # `computation` would leave every later op tagged for outside
      # compilation.
      outside_compilation_context.Exit()

  # If we are in a TPUReplicateContext, signal that we are now
  # outside_compilation
  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  try:
    retval = computation(*args, **kwargs)

    final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
    if initial_context is not final_context:
      raise NotImplementedError(
          "Control-flow context cannot be different at start and end of an "
          "outside_compilation scope")
  finally:
    # If we are in a TPUReplicateContext, signal that we are no longer
    # outside_compilation. This runs in `finally` so an exception does not
    # leave the enclosing contexts stuck in outside-compilation mode. The
    # exits mirror the enters above on the same context chain.
    context = initial_context
    while context:
      if isinstance(context, TPUReplicateContext):
        context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
      context = context.outer_context

  return retval
874
875
@tf_export(v1=["tpu.PaddingSpec"])
class PaddingSpec(enum.IntEnum):
  """Represents the type of padding policies for tpu.replicate."""
  # Default policy: a dynamic input dimension with no explicit maximum is
  # padded to the largest size of that dimension across all replicas.
  AUTO = 0
  # Like AUTO, but the padded size is additionally rounded up to a power of
  # two (bucketizing), which reduces XLA recompilation.
  POWER_OF_TWO = 1
884
885
@tf_export("tpu.XLAOptions")
class XLAOptions(
    collections.namedtuple("XLAOptions", [
        "use_spmd_for_xla_partitioning",
        "enable_xla_dynamic_padder",
    ])):
  """XLA compilation options.

  Attributes:
    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
      partitioner instead of the MPMD partitioner when compiler partitioning
      is requested. True by default.
    enable_xla_dynamic_padder: Boolean. Whether to enable the XLA dynamic
      padder infrastructure to handle dynamic shape inputs inside XLA. True by
      default. Disabling this may cause correctness issues with dynamic shape
      inputs, as XLA will just assume the inputs have padded shapes. However,
      users can optionally set it to False to improve device time if masking
      is already handled on the user side.
  """

  def __new__(cls,
              use_spmd_for_xla_partitioning=True,
              enable_xla_dynamic_padder=True):
    return super().__new__(
        cls,
        use_spmd_for_xla_partitioning=use_spmd_for_xla_partitioning,
        enable_xla_dynamic_padder=enable_xla_dynamic_padder)
911
912
@tf_export(v1=["tpu.replicate"])
@traceback_utils.filter_traceback
def replicate(
    computation: Callable[..., Any],
    inputs: Optional[List[List[core_types.Tensor]]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    maximum_shapes: Optional[Any] = None,
    padding_spec: Optional[PaddingSpec] = None,
    xla_options: Optional[XLAOptions] = None) -> List[Any]:
  """Builds a graph operator that runs a replicated TPU computation.

  Example for the basic usage where `inputs` has static shapes:

  ```python

  def computation(x):
    x = x + 1
    return tf.math.reduce_mean(x)

  x = tf.convert_to_tensor([1., 2., 3.])
  y = tf.convert_to_tensor([4., 5., 6.])
  tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
  ```

  If `inputs` has dynamic shapes and you would like to automatically
  bucketize the inputs to avoid XLA recompilation, see the advanced example
  below:

  ```python

  def computation(x):
    x = x + 1
    return tf.math.reduce_mean(x)

  # Assume input tensors in two replicas `x` and `y` both have dynamic shape
  # ([None, 2]).
  tf.compat.v1.tpu.replicate(
    computation,
    inputs=[[x], [y]],
    maximum_shapes=[tf.TensorShape([None, None])],
    padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
  ```

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list
      of scalar tensors rather than a single Rank-N tensor. If you need
      different behavior, convert part of the inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g.
      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
      object) will be padded to the maximum size of that dimension over all
      replicas. The structure of `maximum_shapes` needs to be the same as
      `inputs[0]`.
    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A list of outputs, indexed by `[replica_num]`; each output can be a nested
    structure same as what computation() returns with a few exceptions.

    Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: A tuple containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  # Element [1] of split_compile_and_replicate's result holds the per-replica
  # outputs; element [0] is the compile op, which is not exposed here.
  return split_compile_and_replicate(
      computation,
      inputs,
      infeed_queue,
      device_assignment,
      name,
      maximum_shapes=maximum_shapes,
      padding_spec=padding_spec,
      xla_options=xla_options)[1]
1020
1021
def _ceil_to_pow_of_n(x, n):
  """Rounds `x` up to an integer power of `n`.

  Computes `n ** ceil(log(x) / log(n))` in float32 and casts the result to
  int32.

  Args:
    x: A numeric tensor or tensor-convertible value; it is cast to float32.
    n: The base, a Python number (converted to float via `n * 1.0`).

  Returns:
    An int32 tensor holding the rounded-up value.
  """
  base = n * 1.0
  exponent = math_ops.ceil(
      math_ops.log(math_ops.cast(x, dtypes.float32)) / math_ops.log(base))
  # NOTE(review): the exponent comes from a float32 log ratio. If rounding
  # nudges an exact power of `n` slightly above an integer, `ceil` overshoots
  # by one power. Confirm this is acceptable for the sizes in use.
  return math_ops.cast(math_ops.pow(base, exponent), dtypes.int32)
1030
1031
def _pad_all_input(
    inputs: Iterable[List[core_types.Tensor]],
    padded_shapes: List[Optional[tensor_shape.TensorShape]],
    padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
  """Pad all input tensors given padded_shapes.

  The real shape tensors are concatenated with the padded original inputs.
  For every input dimension that needs padding, the int32 size of that
  dimension in the unpadded tensor is appended to the end of each replica's
  input list.

  Args:
    inputs: The original inputs, indexed by `[replica_num][input_num]`.
      Replicas are assumed to have the same number of inputs with matching
      ranks, because the per-dimension indexing below relies on it. The
      argument is traversed twice, so it is materialized into a list first.
    padded_shapes: A list of padded shapes for each input. If an entry is None,
      no padding is performed.
    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.

  Returns:
    The padded inputs and a PaddingMap list. Each PaddingMap maps a padded
    input dimension (`arg_index`, `shape_index`) to the index of its
    real-shape argument (`padding_arg_index`) in the extended input list.
  """
  # `inputs` is iterated twice below; materialize it so a one-shot iterable
  # is not silently exhausted by the first pass.
  inputs = list(inputs)

  # maximum_static_shapes[idx][i] indicates the maximum static size of ith
  # dimension of the idx input among all the replicas (None if unknown in any).
  maximum_static_shapes = []
  # need_padding[idx][i] indicates whether the ith dimension of the idx input
  # needs padding.
  need_padding = []
  input_shape_tensors = []
  for core_idx, inputs_per_core in enumerate(inputs):
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape = input_tensor.get_shape().as_list()
      if core_idx == 0:
        input_shape_tensors.append([])
        maximum_static_shapes.append(input_shape)
        need_padding.append(np.full_like(input_shape, False, dtype=bool))
      else:
        for i, s in enumerate(input_shape):
          if s is None or s != maximum_static_shapes[idx][i]:
            need_padding[idx][i] = True
        # Element-wise maximum. A dimension unknown (None) in any replica
        # stays unknown. Calling max() on the two shape lists would compare
        # them lexicographically, and that raises TypeError in Python 3 as
        # soon as a None is compared with an int.
        maximum_static_shapes[idx] = [
            None if s is None or m is None else max(s, m)
            for s, m in zip(input_shape, maximum_static_shapes[idx])
        ]

      # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
      real_input_shape = array_ops.shape(input_tensor)
      real_input_shape.op._set_attr(  # pylint: disable=protected-access
          _POST_DEVICE_REWRITE_ATTR,
          attr_value_pb2.AttrValue(b=True))
      input_shape_tensors[idx].append(real_input_shape)

  # Runtime (dynamic) maximum of each input's shape across replicas.
  maximum_shapes = []
  for shapes_per_input in input_shape_tensors:
    maximum_shapes.append(
        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))

  padded_inputs = []
  real_shapes = []
  padding_maps = []
  for core_idx, inputs_per_core in enumerate(inputs):
    padded_inputs.append([])
    real_shapes.append([])
    # Real-shape arguments are appended after the original inputs, so their
    # indices start right after the last original input index.
    real_shape_idx = len(inputs_per_core) - 1
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape_tensor = input_shape_tensors[idx][core_idx]
      input_shape = input_tensor.get_shape().as_list()
      padded_shape = padded_shapes[idx]

      # If we have no padded_shape, then skip padding.
      if any(need_padding[idx]) and padded_shape is not None:
        for i, s in enumerate(input_shape):
          if need_padding[idx][i]:
            # The padding map layout is shared by all replicas, so it is
            # recorded once, while processing replica 0.
            if core_idx == 0:
              real_shape_idx += 1
              padding_map = dynamic_padding.PaddingMap()
              padding_map.arg_index = idx
              padding_map.shape_index = i
              padding_map.padding_arg_index = real_shape_idx
              padding_maps.append(padding_map)
            real_shapes[core_idx].append(
                math_ops.cast(input_shape_tensor[i], dtypes.int32))

        paddings = []
        for i, s in enumerate(padded_shape.dims):
          if need_padding[idx][i]:
            # The minimum padded dimension size is 2 as XLA doesn't support size
            # 1 dynamic size.
            minimum_dynamic_dim_size = 2
            if s.value is not None:
              # Pad to the given maximum value.
              max_dim_size = max(s.value, minimum_dynamic_dim_size)
            else:
              # If maximum value is not given, then pad to the maximum dimension
              # among all the cores.
              max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
                                              minimum_dynamic_dim_size)
              if padding_spec == PaddingSpec.POWER_OF_TWO:
                max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
            # Pad to the given maximum value.
            padding = [0, max_dim_size - input_shape_tensor[i]]
          else:
            padding = [0, 0]
          paddings.append(padding)

        if input_tensor.get_shape().is_fully_defined():
          # TODO(rxsang): This is a hack to make sure padded_input has dynamic
          # shapes, so any tf.size/tf.shape op performed on it won't be constant
          # folded. Do we have better ways to do it?
          padded_input = control_flow_ops.cond(
              array_ops.constant(True),
              lambda: array_ops.pad(input_tensor, paddings),  # pylint: disable=cell-var-from-loop
              lambda: input_tensor)
        else:
          padded_input = array_ops.pad(input_tensor, paddings)

        # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
        padded_input.op._set_attr(  # pylint: disable=protected-access
            _POST_DEVICE_REWRITE_ATTR,
            attr_value_pb2.AttrValue(b=True))

        padded_inputs[core_idx].append(padded_input)
      else:
        padded_inputs[core_idx].append(input_tensor)

  num_replicas = len(padded_inputs)
  for i in range(num_replicas):
    padded_inputs[i].extend(real_shapes[i])

  return padded_inputs, padding_maps
1161
1162
def _flatten_and_filter_composite(maybe_composite, non_composite_output,
                                  composite_output=None):
  """Replaces a composite input by a tuple with one entry per component.

  If `maybe_composite` is not a `CompositeTensor`, the parameter
  `non_composite_output` is returned. Otherwise the result is a tuple that
  repeats `composite_output` once for every component the composite tensor
  flattens into.

  This is useful for computing a mask when flattening nested data with
  `expand_composites=True`. For example

  ```python
  nest.flatten(data, expand_composites=True)
  ```

  and

  ```python
  nest.flatten(nest.map(
      data, lambda x: _flatten_and_filter_composite(x, False, True)))
  ```

  will have the same length, and an entry of the second will be True exactly
  when the corresponding tensor in the first came from expanding a composite
  tensor.

  Args:
    maybe_composite: A value to test for being a composite tensor.
    non_composite_output: The value to return when `maybe_composite` is not a
      composite.
    composite_output: The value to fill the output tuple with if
      `maybe_composite` is a composite.

  Returns:
    `non_composite_output`, or a tuple with multiple copies of
    `composite_output`.
  """
  if not isinstance(maybe_composite, composite_tensor.CompositeTensor):
    return non_composite_output
  components = nest.flatten(maybe_composite, expand_composites=True)
  return tuple(composite_output for _ in components)
1205
1206
1207def split_compile_and_replicate(
1208    computation: Callable[..., Any],
1209    inputs: Optional[List[List[core_types.Tensor]]] = None,
1210    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1211    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1212    name: Optional[Text] = None,
1213    use_tpu: bool = True,
1214    maximum_shapes: Optional[Any] = None,
1215    padding_spec: Optional[PaddingSpec] = None,
1216    xla_options: Optional[XLAOptions] = None,
1217) -> List[List[core_types.Tensor]]:
1218  """Builds graph operators that runs compilation and replicated computation.
1219
1220  This is a lower level interface than replicate that returns a separate compile
1221  and execute output tensor. In the generated graph the compile op feeds into
1222  the execute op and no additional compilation is incurred when running the
1223  compile op before the execute op. The compile op returns additional
1224  information about the compilation but does not return the compiled program.
1225
1226  Args:
1227    computation: A Python function that builds the computation to replicate.
1228    inputs: A list of lists of input tensors or `None` (equivalent to
1229      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
1230      have the same number of inputs. Each input can be a nested structure
1231      containing values that are convertible to tensors. Note that passing an
1232      N-dimension list of compatible values will result in a N-dimension list of
1233      scalar tensors rather than a single Rank-N tensors. If you need different
1234      behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
1235    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1236      of arguments as inputs to computation.
1237    device_assignment: If not `None`, a `DeviceAssignment` describing the
1238      mapping between logical cores in the computation with physical cores in
1239      the TPU topology. Uses a default device assignment if `None`. The
1240      `DeviceAssignment` may be omitted if each replica of the computation uses
1241      only one core, and there is either only one replica, or the number of
1242      replicas is equal to the number of cores in the TPU system.
1243    name: (Deprecated) Does nothing.
1244    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
1245      backends. Currently, only supports a default placement (computation is
1246      placed on GPU if one is available, and on CPU if not).
1247    maximum_shapes: A nested structure of tf.TensorShape representing the shape
1248      to which the respective component of each input element in each replica
1249      should be padded. Any unknown dimensions (e.g.
1250      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
1251      object) will be padded to the maximum size of that dimension over all
1252      replicas. The structure of `maximum_shapes` needs to be the same as
1253      `inputs[0]`.
1254    padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
1255      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
1256      One usage is to enable automatic bucketizing on the inputs by setting the
1257      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
1258      recompilation in the XLA side.
1259    xla_options: An instance of `tpu.XLAOptions` which indicates the options
1260      passed to XLA compiler. Use `None` for default options.
1261
1262  Returns:
1263    A list of lists with the first list corresponding to the compile op and the
1264    second a list of output tensors, indexed by `[replica_num][output_num]`.
1265  Raises:
1266    ValueError: If all replicas do not have equal numbers of input tensors.
1267    ValueError: If the number of inputs per replica does not match
1268      the number of formal parameters to `computation`.
1269    ValueError: If the static `inputs` dimensions don't match with the values
1270      given in `maximum_shapes`.
1271    ValueError: If the structure of inputs per replica does not match
1272      the structure of `maximum_shapes`.
1273  """
1274  del name
1275  inputs = [[]] if inputs is None else inputs
1276  xla_options = xla_options or XLAOptions()
1277
1278  metadata_kwargs = {}
1279  if device_assignment is not None:
1280    # Turn the Numpy array into a flattened list so we can pass it as an
1281    # operator attribute.
1282    metadata_kwargs = {
1283        "topology":
1284            device_assignment.topology.serialized(),
1285        "device_assignment":
1286            device_assignment.core_assignment.flatten().tolist()
1287    }
1288    metadata_kwargs["num_cores_per_replica"] = (
1289        device_assignment.num_cores_per_replica)
1290
1291  # This entry is used for enabling automatic outside compilation.
1292  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
1293  if config.get_soft_device_placement():
1294    logging.info("Automatic outside compilation is enabled. "
1295                 "Ops without XLA kernels will be automatically "
1296                 "placed on CPU.")
1297
1298  if not isinstance(inputs, list):
1299    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, "
1300                    f"received {type(inputs)}")
1301  if any(not isinstance(inp, (list, tuple)) for inp in inputs):
1302    raise TypeError(
1303        "tpu.replicate() inputs must be a list of lists/tuples, "
1304        f"received types: {[type(inp) for inp in inputs]}")
1305
1306  num_replicas = len(inputs)
1307
1308  # No replicas? Nothing to do.
1309  if num_replicas == 0:
1310    return []
1311
1312  # Checks all replicas have the same structure.
1313  for i in range(1, num_replicas):
1314    nest.assert_same_structure(inputs[0], inputs[i])
1315
1316  # Explicitly read variables.
1317  inputs = variable_utils.convert_variables_to_tensors(inputs)
1318  # Flatten inputs. This structure may contain None values, which will be
1319  # handled later.
1320  flat_inputs_with_nones = [
1321      nest.flatten(per_replica_input, expand_composites=True)
1322      for per_replica_input in inputs
1323  ]
1324  # Mask parallel to one replica's inputs with True for tensors coming from
1325  # composites.
1326  is_composite = nest.flatten(nest.map_structure(
1327      lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
1328
1329  # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
1330  # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
1331  flat_inputs = []
1332  for inp in flat_inputs_with_nones:
1333    flat_inputs.append([
1334        constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
1335        for x in inp
1336    ])
1337
1338  # Verifies that all replicas have matching numbers and types of inputs
1339  flat_input_types = [x.dtype for x in flat_inputs[0]]
1340  input_arity = len(inputs[0])
1341  flat_input_arity = len(flat_input_types)
1342  for i in range(num_replicas):
1343    if len(inputs[i]) != input_arity:
1344      raise ValueError("Replicas must have the same number of inputs. "
1345                       "Replica 0 had {} inputs, replica {} had {} "
1346                       "inputs.".format(input_arity, i, len(inputs[i])))
1347
1348    types = [x.dtype for x in flat_inputs[i]]
1349    if types != flat_input_types:
1350      raise ValueError("Replicas must have matching input types. Replica 0 had "
1351                       "input types {}, replica {} had input types {}".format(
1352                           flat_input_types, i, types))
1353
1354  arg_error = xla.check_function_argument_count(
1355      computation, input_arity, infeed_queue)
1356  if arg_error is not None:
1357    if infeed_queue is None:
1358      raise TypeError(
1359          "Supplied computation cannot be called with the specified inputs. "
1360          f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, "
1361          f"but the computation needs {arg_error}")
1362    else:
1363      raise TypeError(
1364          "Supplied computation cannot be called with the specified inputs. "
1365          f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} ",
1366          f"and {infeed_queue.number_of_tuple_elements} additional inputs "
1367          f"from infeed, but the computation needs {arg_error}")
1368
1369  dynamic_shape_inputs = False
1370  if maximum_shapes:
1371    if infeed_queue:
1372      raise ValueError(
1373          "Dynamic input shapes are not supported with infeed queues")
1374
1375    # Make sure maximum_shapes has the same structure as inputs.
1376    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
1377
1378    # Flatten padded shapes:
1379    # For composite tensor components, we don't want to pad them. For each
1380    # entry of maximum_shapes that corresponds to a composite tensor, replace it
1381    # by a tuple of Nones of the same length as the number of components of the
1382    # composite tensor. When we flatten a second time, this makes
1383    # flat_maximum_shapes have the same length as flat_inputs[i]. We can then
1384    # avoid padding these tensors. The assumption is that they will be used by
1385    # outside compilation or that the components are statically shaped and will
1386    # be used by tpu compatible ops.
1387    flat_maximum_shapes = nest.flatten(
1388        [_flatten_and_filter_composite(x, y)
1389         for x, y in zip(nest.flatten(inputs[0]),
1390                         nest.flatten(maximum_shapes))])
1391    flat_maximum_shapes = [
1392        tensor_shape.TensorShape(s) if s is not None else None
1393        for s in flat_maximum_shapes
1394    ]
1395    nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
1396                               check_types=False)
1397
1398    unpadded_inputs = flat_inputs
1399    flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
1400                                               flat_maximum_shapes,
1401                                               padding_spec)
1402    if padding_maps:
1403      dynamic_shape_inputs = True
1404      logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0])
1405
1406  metadata_kwargs["step_marker_location"] = getattr(
1407      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
1408  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
1409      xla_options.use_spmd_for_xla_partitioning
1410
1411  graph = ops.get_default_graph()
1412
1413  # Fan-in: Builds a TPUReplicatedInput node for each input.
1414  flat_replicated_inputs = []
1415  for i in range(0, len(flat_inputs[0])):
1416    replicas = [flat_inputs[replica][i] for replica in range(num_replicas)]
1417    flat_replicated_inputs.append(
1418        tpu_ops.tpu_replicated_input(
1419            replicas, name="input{}".format(i)))
1420  if isinstance(graph, func_graph.FuncGraph):
1421    # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph
1422    # object. If both outside graph and this function have a TPU cluster,
1423    # they will have the same cluster name and it will cause problems (because
1424    # we lower functional ops in Tensorflow 2.0). Append function name to
1425    # 'cluster_name' to avoid cluster name collision.
1426    cluster_name = graph.unique_name("cluster_" + graph.name)
1427  else:
1428    cluster_name = graph.unique_name("cluster")
1429  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
1430  pivot._set_attr(_PIVOT_FOR_CLUSTER,  # pylint: disable=protected-access
1431                  attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
1432  context = TPUReplicateContext(
1433      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
1434  try:
1435    context.Enter()
1436
1437    metadata = tpu_ops.tpu_replicate_metadata(
1438        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
1439
1440    with tpu_function.tpu_shard_context(
1441        num_replicas), ops.control_dependencies([metadata]):
1442
1443      if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder:
1444        for padding_map in padding_maps:
1445          input_shape = flat_replicated_inputs[padding_map.arg_index].shape
1446          flat_replicated_inputs[
1447              padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
1448                  flat_replicated_inputs[padding_map.arg_index],
1449                  padding_map.shape_index,
1450                  flat_replicated_inputs[padding_map.padding_arg_index])
1451          flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)
1452
1453      # Add identity ops so even unused inputs are "consumed" by the
1454      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
1455      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
1456      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
1457      flat_replicated_inputs = [
1458          array_ops.identity(x, name="replicated_input_{}".format(i))
1459          for i, x in enumerate(flat_replicated_inputs)
1460      ]
1461      for i, composite in zip(flat_replicated_inputs, is_composite):
1462        # pylint: disable=protected-access
1463        # Add an attribute to the identity node so that they could be removed in
1464        # encapsulate TPU computation pass if unused. However we don't remove
1465        # inputs when dynamic padding is enabled.
1466        # TODO(rxsang): Use other ways except argument index in padding_map so
1467        # outside compilation can work with dynamic padding correctly.
1468        if not dynamic_shape_inputs or composite:
1469          i.op._set_attr("_tpu_input_identity",
1470                         attr_value_pb2.AttrValue(b=True))
1471        # pylint: enable=protected-access
1472
1473      # Clobber replicated placeholders with Nones.
1474      computation_inputs = [
1475          None if inp is None else replicated for replicated, inp in zip(
1476              flat_replicated_inputs, flat_inputs_with_nones[0])
1477      ]
1478
1479      # Unflatten the computation inputs to match original input structure.
1480      computation_inputs = nest.pack_sequence_as(
1481          structure=inputs[0],
1482          flat_sequence=computation_inputs[:flat_input_arity],
1483          expand_composites=True)
1484
1485      # If there is an infeed queue, adds the dequeued values to the
1486      # computation's inputs.
1487      if infeed_queue is not None:
1488        infeed_queue.set_number_of_shards(num_replicas)
1489        for t in infeed_queue.generate_dequeue_op():
1490          computation_inputs.append(t)
1491
1492      # Only resource variables work inside a TPU computation, so turn on
1493      # resource variables for the computation.
1494      # TODO(phawkins): consider removing this code. It will
1495      # be less confusing to clients if they knowingly choose to use resource
1496      # variables.
1497      # Partitioned variables is not supported (b/112311320).
1498      vscope = variable_scope.get_variable_scope()
1499      saved_use_resource = vscope.use_resource
1500      saved_custom_getter = vscope.custom_getter
1501
1502      def custom_getter(getter, name, *args, **kwargs):
1503        """Variables on TPU have a few restrictions."""
1504        partitioner = kwargs.get("partitioner", None)
1505        if partitioner is not None:
1506          kwargs["partitioner"] = None
1507          logging.warning(
1508              "Partitioned variables are not supported on TPU. Got "
1509              "`partitioner` that is %s for variable %s. "
1510              "Setting `partitioner` to `None`.", partitioner, name)
1511        if saved_custom_getter is None:
1512          return getter(name, *args, **kwargs)
1513        else:
1514          return saved_custom_getter(getter, name, *args, **kwargs)
1515
1516      vscope.set_use_resource(True)
1517      vscope.set_custom_getter(custom_getter)
1518
1519      outputs = computation(*computation_inputs)
1520
1521      vscope.set_use_resource(saved_use_resource)
1522      vscope.set_custom_getter(saved_custom_getter)
1523
1524      outputs = variable_utils.convert_variables_to_tensors(outputs)
1525
1526    need_spmd_partitioning = (
1527        xla_options.use_spmd_for_xla_partitioning and
1528        device_assignment is not None and
1529        device_assignment.num_cores_per_replica > 1)
1530    outputs_is_flat = xla.is_flat(outputs)
1531    if outputs_is_flat:
1532      output_tensors, control_deps, pack_template = _postprocess_flat_outputs(
1533          outputs, need_spmd_partitioning)
1534    else:
1535      output_tensors, control_deps, pack_template = (
1536          _postprocess_non_flat_outputs(outputs, need_spmd_partitioning))
1537
1538    # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid
1539    # import-cycle
1540    if typing.TYPE_CHECKING:
1541      tensor_tracer = Any
1542    else:
1543      # pylint: disable=g-import-not-at-top
1544      from tensorflow.python.tpu import tensor_tracer
1545      # pylint: enable=g-import-not-at-top
1546    if tensor_tracer.TensorTracer.is_enabled():
1547      if tf2.enabled():
1548        logging.warn("TF API ver >= 2.0 detected. "
1549                     "Tensor Tracer v1 is not enabled.")
1550      else:
1551        tt = tensor_tracer.TensorTracer()
1552        output_tensors = tt.trace_tpu(ops.get_default_graph(),
1553                                      output_tensors, control_deps,
1554                                      num_replicas)
1555
1556    context.ExitResult(output_tensors)
1557  finally:
1558    context.report_unsupported_operations()
1559    context.Exit()
1560    host_compute_core = context.HostComputeCore()
1561
1562  if host_compute_core:
1563    attr_value = attr_value_pb2.AttrValue()
1564    attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
1565    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
1566
1567  with ops.control_dependencies([metadata]):
1568    if use_tpu:
1569      compile_status = tpu_ops.tpu_compilation_result()
1570      op = compile_status.op
1571      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
1572      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
1573    else:
1574      compile_status = control_flow_ops.no_op(name="compilation_status")
1575
1576  if not output_tensors:
1577    # Returns a list of NoOps dependent on the replication Op, indexed by
1578    # [replica_num].
1579    return [
1580        compile_status,
1581        [
1582            control_flow_ops.group(control_deps, name="shard_%d" % i)
1583            for i in range(num_replicas)
1584        ]
1585    ]
1586
1587  # Fan-out: Builds a TPUReplicatedOutput node for each output.
1588  replicated_outputs = [[] for i in range(num_replicas)]
1589  for i, t in enumerate(output_tensors):
1590
1591    # None values returned by the computation can't be sent to
1592    # tpu_ops.tpu_replicated_output(), we handle them specially here. We can
1593    # avoid the placeholder 0 routine required on the inputs since outputs are
1594    # replicated per-tensor, not per-replica, so we can skip replication.
1595    if t is None:
1596      for replica in range(num_replicas):
1597        replicated_outputs[replica].append(None)
1598      continue
1599
1600    # Fan-out: Builds a TPUReplicatedOutput node for each output.
1601    ys = tpu_ops.tpu_replicated_output(
1602        t, num_replicas, name="output{}".format(i))
1603
1604    # Wraps the outputs in identity operators so the names of any possible
1605    # `fetch` nodes are preserved by the replication rewrite.
1606    with ops.control_dependencies(control_deps):
1607      for replica in range(num_replicas):
1608        replicated_outputs[replica].append(
1609            array_ops.identity(
1610                ys[replica], name="output_%d_shard_%d" % (i, replica)))
1611
1612  replicated_outputs = [
1613      nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True)
1614      for replica_outs in replicated_outputs
1615  ]
1616
1617  return [compile_status, replicated_outputs]
1618
1619
def _postprocess_flat_outputs(
    outputs: Any,
    need_spmd_partitioning: bool
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
  """Validates flat outputs, adds back device assignments and other attrs.

  A trailing `no_op` is appended to the outputs so that fetching any value
  returned by the computation also triggers the TPU execution. Tensors must
  come before Operations; each returned Tensor is wrapped in an Identity op
  carrying the `_tpu_output_identity` attribute.

  Args:
    outputs: Output from `computation` inside `tpu.rewrite`. `None` is treated
      as an empty tuple.
    need_spmd_partitioning: Whether XLA SPMD partitioning is needed. When it
      is False, tensor conversion happens under `core(0)`, and each Identity
      op is placed on the tensor's own device (or `core(0)` if it has none).

  Returns:
    - Tensors extracted from outputs (`None` entries are preserved).
    - Operations extracted from outputs, including the appended `no_op`.
    - A pack template for use with nest.pack_sequence_as to pack the tensors.

  Raises:
    ValueError: If an output is neither an Operation nor convertible to a
      Tensor, or if any Operation appears before a Tensor.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()

  # For legacy / backwards compatibility reasons we return a list for "flat"
  # output values (even if the user's flat return value was a different type or
  # even just a scalar value) so use nest.flatten to compute a flat list pack
  # template.
  pack_template = nest.flatten(outputs, expand_composites=False)

  # Even though outputs is already "flat", we flatten any composites so their
  # component tensors can be tagged and replicated. The pack_template will be
  # used by the caller to repack the composite tensors.
  outputs = nest.flatten(outputs, expand_composites=True)

  # Append `no_op` here so that fetching any return value of this function
  # will trigger TPUExecute node.
  outputs += (control_flow_ops.no_op(),)

  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
  try:
    if need_spmd_partitioning:
      outputs = [
          o if isinstance(o, ops.Operation) else maybe_convert(o)
          for o in outputs
      ]
    else:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else maybe_convert(o)
            for o in outputs
        ]
  except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise ValueError(
        "TPU function return values must all either be Operations or "
        f"convertible to Tensors. Got error: {e}") from e

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  # The partition above is order-preserving, so equality holds only when every
  # Tensor precedes every Operation.
  if outputs != output_tensors + output_operations:
    raise ValueError(
        "TPU functions must return zero-or more Tensor values followed by "
        "zero or more Operations.")

  # Trim operations off the end of the pack template. output_operations has 1
  # extra element due to the no-op that is added.
  if len(output_operations) > 1:
    pack_template = pack_template[:1 - len(output_operations)]

  # Wraps outputs in Identity ops. Otherwise a replicated input copied
  # straight to an output would bypass the replicate(). This would be bad
  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
  # be rewritten away, leading to a runtime error.
  # TODO(phawkins): extend the rewrite to elide these nodes instead.
  new_output_tensors = []
  for t in output_tensors:
    if t is None:
      new_output_tensors.append(None)
    elif need_spmd_partitioning:
      o = array_ops.identity(t)
      # pylint: disable=protected-access
      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
      new_output_tensors.append(o)
    else:
      with ops.device(t.device if t.device else core(0)):
        o = array_ops.identity(t)
        # pylint: disable=protected-access
        o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
        # pylint: enable=protected-access
        new_output_tensors.append(o)
  return new_output_tensors, output_operations, pack_template
1715
1716
def _postprocess_non_flat_outputs(
    outputs: Any,
    need_spmd_partitioning: bool
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], Any]:
  """Validates non-flat outputs, add backs device assignments and other attrs.

  Args:
    outputs: Output from `computation` inside `tpu.rewrite`.
    need_spmd_partitioning: Whether XLA SPMD partitioning is needed. When it
      is False, Identity ops are placed on the tensor's own device (or
      `core(0)` if it has none).

  Returns:
    - Tensors extracted from outputs (flattened with composites expanded;
      `None` entries are preserved).
    - An empty Operations list because Operations are not allowed in non-flat
      outputs.
    - A pack template for use with nest.pack_sequence_as to pack the tensors.
      This is the original `outputs` structure itself, not a list.

  Raises:
    ValueError: If any output is an Operation, or is not convertible to a
      Tensor.
  """

  # Flatten output items.
  flat_outputs = nest.flatten(outputs, expand_composites=True)

  # Convert all non-None non-Operation outputs to Tensors.
  for i, o in enumerate(flat_outputs):
    if o is None:
      # None outputs are passed through unchanged.
      continue

    if isinstance(o, ops.Operation):
      raise ValueError(
          "tpu.rewrite does not support Operation as return value in non-flat "
          "output structure. You can set returned Operations as control "
          "dependencies of returned Tensors so Operations are triggered when "
          f'Tensors are evaluated. Operation found: "{o.name}"')

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      # Chain the original error so the root cause stays in the traceback.
      raise ValueError(
          "TPU function return values must all either be Operations or "
          f'convertible to Tensors. Got error: "{e}"') from e

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    # NOTE(review): unlike `_postprocess_flat_outputs`, two Identity ops are
    # created here and only the inner one is tagged `_tpu_output_identity`;
    # the returned tensor is the untagged outer one. Confirm whether the
    # rewrite pass relies on this before unifying the two paths.
    if need_spmd_partitioning:
      o = array_ops.identity(o)
      # pylint: disable=protected-access
      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
      flat_outputs[i] = array_ops.identity(o)
    else:
      with ops.device(o.device if o.device else core(0)):
        o = array_ops.identity(o)
        # pylint: disable=protected-access
        o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
        # pylint: enable=protected-access
        flat_outputs[i] = array_ops.identity(o)

  # All flat_outputs are Tensors (or None), and no Operations.
  return flat_outputs, [], outputs
1778
1779
def split_compile_and_shard(
    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
    num_shards: int = 1,
    input_shard_axes: Optional[List[int]] = None,
    outputs_from_all_shards: Union[bool, List[bool]] = True,
    output_shard_axes: Optional[List[int]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    xla_options: Optional[XLAOptions] = None,
    ) -> Tuple[ops.Operation, List[core_types.Tensor]]:
  """Shards `computation` for parallel execution.

  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
  of which has a corresponding split axis (from `input_shard_axes`). Each input
  is split into `num_shards` pieces along the corresponding axis, and
  computation is applied to each shard in parallel.

  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. e.g.,

  x = tf.constant(7)
  def computation():
    return x + 3
  ... = shard(computation, ...)

  If `outputs_from_all_shards` is true, the outputs from all shards of
  `computation` are concatenated back together along their `output_shard_axes`.
  Otherwise, each output is taken from an arbitrary shard. `None` outputs of
  `computation` are returned as `None`.

  Inputs and outputs of the computation must be at least rank-1 Tensors.

  Args:
    computation: A Python function that builds a computation to apply to each
      shard of the input.
    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axes, given by `input_shard_axes`,
      which must have size divisible by `num_shards`.
    num_shards: The number of shards.
    input_shard_axes: A list of dimensions along which to shard `inputs`, or
      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
      there must be one dimension per input.
    outputs_from_all_shards: Boolean or list of boolean. For each output, if
      `True`, outputs from all shards are concatenated along the corresponding
      `output_shard_axes` entry. Otherwise, each output is taken
      from an arbitrary shard. If the argument is a boolean, the argument's
      value is used for each output.
    output_shard_axes: A list of dimensions along which to concatenate the
      outputs of `computation`, or `None`. `None` means "concatenate all outputs
      along dimension 0". If not `None`, there must be one dimension per output.
      Ignored if `outputs_from_all_shards` is False.
    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
      of `computation`.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each shard of the computation uses
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A tuple of (compile op, [output tensors]).
  Raises:
    TypeError: If `inputs` is neither a list nor None.
    ValueError: If num_shards <= 0
    ValueError: If len(input_shard_axes) != len(inputs)
    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
  """
  # TODO(phawkins): consider adding support for broadcasting Tensors passed as
  # inputs.

  if num_shards <= 0:
    raise ValueError(
        f"num_shards must be a positive integer. Received {num_shards}")

  inputs = [] if inputs is None else inputs
  if not isinstance(inputs, list):
    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. "
                    f"Received {type(inputs)}")

  # Converts inputs to Tensors.
  inputs = [ops.convert_to_tensor(x) for x in inputs]

  if input_shard_axes is None:
    input_shard_axes = [0] * len(inputs)
  if len(inputs) != len(input_shard_axes):
    raise ValueError("Length of input_shard_axes must be equal to the number "
                     f"of inputs. Received {len(inputs)} inputs and "
                     f"{len(input_shard_axes)} input_shard_axes.")

  if inputs:
    # Splits the `inputs` along the corresponding `input_shard_axes`, giving
    # lists with layout [input][shard]
    split_inputs = [
        array_ops.split(x, num_shards, axis=axis)
        for (axis, x) in zip(input_shard_axes, inputs)]

    # Transposes the input lists to have layout [shard][input]
    transposed_inputs = [list(i) for i in zip(*split_inputs)]
  else:
    # One distinct empty list per shard. `[[]] * num_shards` would make every
    # shard alias the same list object, so a mutation of one shard's inputs
    # would silently affect all of them.
    transposed_inputs = [[] for _ in range(num_shards)]

  compile_op, outputs = split_compile_and_replicate(
      computation,
      transposed_inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name,
      xla_options=xla_options)

  # There must be at least one shard since num_shards > 0.
  # TODO(b/36647078) remove disable when pylint bug is fixed.
  # pylint: disable=indexing-exception
  if isinstance(outputs[0], ops.Operation):
    # pylint: enable=indexing-exception
    # There were no outputs from the computation and replicate returned a list
    # of NoOps with control dependencies on the computation. Return the first
    # one so it can be used as a control dependency or fetch node.
    # TODO(b/36647078) remove disable when pylint bug is fixed.
    # pylint: disable=indexing-exception
    return compile_op, [outputs[0]]
    # pylint: enable=indexing-exception

  # TODO(b/36647078) remove disable when pylint bug is fixed.
  # pylint: disable=indexing-exception
  num_outputs = len(outputs[0])
  # pylint: enable=indexing-exception

  if output_shard_axes is None:
    output_shard_axes = [0] * num_outputs
  if num_outputs != len(output_shard_axes):
    raise ValueError("Length of output_shard_axes must be equal to the number "
                     f"of outputs. Received {num_outputs} outputs "
                     f"and {len(output_shard_axes)} output_shard_axes.")

  if isinstance(outputs_from_all_shards, bool):
    outputs_from_all_shards = [outputs_from_all_shards] * num_outputs

  if num_outputs != len(outputs_from_all_shards):
    raise ValueError(
        "Length of outputs_from_all_shards must be equal to the number of "
        f"outputs. Received {num_outputs} outputs  and "
        f"{len(outputs_from_all_shards)} outputs_from_all_shards.")

  results = []
  for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
                                   zip(*outputs)):
    if x[0] is None:
      # The computation returned None for this output. Replication passes
      # None through for every replica, so there is nothing to concatenate.
      results.append(None)
    elif all_shards:
      # Concatenate all of the outputs together (use stack for scalars).
      shape = x[0].shape
      is_scalar = shape is not None and (shape.ndims == 0)
      results.append((array_ops.stack(list(x)) if is_scalar
                      else array_ops.concat(list(x), axis=axis)))
    else:
      # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
      results.append(x[0])

  return compile_op, results
1940
1941
@tf_export(v1=["tpu.shard"])
@traceback_utils.filter_traceback
def shard(
    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
    num_shards: int = 1,
    input_shard_axes: Optional[List[int]] = None,
    outputs_from_all_shards: Union[bool, List[bool]] = True,
    output_shard_axes: Optional[List[int]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
  """Runs `computation` in parallel over `num_shards` slices of `inputs`.

  `inputs` is either None (treated as an empty list) or a list of Tensors.
  Each Tensor is paired with a split axis taken from `input_shard_axes`; it is
  cut into `num_shards` pieces along that axis, and `computation` is applied
  to every piece in parallel.

  Tensors that `computation` captures lexically are broadcast to every shard,
  for example:

  x = tf.constant(7)
  def computation():
    return x + 3
  ... = shard(computation, ...)

  Tensors passed through `inputs` are not broadcast (a possible future
  extension, see TODO(phawkins)).

  When `outputs_from_all_shards` is true, each output is rebuilt by
  concatenating the per-shard results along its `output_shard_axes` entry.
  When it is false, each output is taken from an arbitrary shard.

  Every input and output of the computation must be a Tensor of rank 1 or
  higher.

  Args:
    computation: A Python function that builds the computation to apply to
      each shard of the input.
    inputs: A list of input tensors, or None (meaning an empty list). Each
      tensor is split along its entry in `input_shard_axes`, and that
      dimension must be divisible by `num_shards`.
    num_shards: How many shards to create.
    input_shard_axes: The dimensions along which to split `inputs`, one per
      input. `None` means every input is split along dimension 0.
    outputs_from_all_shards: A boolean, or one boolean per output. `True`
      concatenates the outputs of all shards along the matching
      `output_shard_axes` entry. `False` takes the output from an arbitrary
      shard. A single boolean applies to every output.
    output_shard_axes: The dimensions along which to concatenate the outputs
      of `computation`, one per output. `None` means every output is
      concatenated along dimension 0. Ignored when `outputs_from_all_shards`
      is False.
    infeed_queue: If not `None`, the `InfeedQueue` whose values are added to
      the inputs of `computation`.
    device_assignment: If not `None`, a `DeviceAssignment` that maps the
      logical cores of the computation to physical cores in the TPU topology.
      A default assignment is used when `None`. It may be omitted when each
      shard uses a single core and there is either exactly one shard or as
      many shards as there are cores in the TPU system.
    name: (Deprecated) Has no effect.
    xla_options: A `tpu.XLAOptions` instance with options for the XLA
      compiler. `None` selects the defaults.
  Returns:
    A list of output tensors.
  Raises:
    ValueError: If num_shards <= 0
    ValueError: If len(input_shard_axes) != len(inputs)
    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
  """
  # split_compile_and_shard yields (compile_op, outputs); callers of `shard`
  # only receive the sharded outputs.
  _, sharded_outputs = split_compile_and_shard(
      computation,
      inputs=inputs,
      num_shards=num_shards,
      input_shard_axes=input_shard_axes,
      outputs_from_all_shards=outputs_from_all_shards,
      output_shard_axes=output_shard_axes,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name,
      xla_options=xla_options)
  return sharded_outputs
2027
2028
@tf_export(v1=["tpu.batch_parallel"])
@traceback_utils.filter_traceback
def batch_parallel(
    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
    num_shards: int = 1,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
  """Shards `computation` along the batch dimension for parallel execution.

  Convenience wrapper around shard().

  `inputs` must be a list of Tensors or None (equivalent to an empty list).
  Each input is split into `num_shards` pieces along the 0-th dimension, and
  computation is applied to each shard in parallel.

  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. e.g.,

  x = tf.constant(7)
  def computation():
    return x + 3
  ... = batch_parallel(computation, ...)

  The outputs from all shards are concatenated back together along their 0-th
  dimension.

  Inputs and outputs of the computation must be at least rank-1 Tensors.

  Args:
    computation: A Python function that builds a computation to apply to each
      shard of the input.
    inputs: A list of input tensors or None (equivalent to an empty list). The
      0-th dimension of each Tensor must have size divisible by `num_shards`.
    num_shards: The number of shards.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each shard of the computation uses
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A list of output tensors.
  Raises:
    ValueError: If `num_shards <= 0`
  """
  # The default `input_shard_axes` / `output_shard_axes` of `shard` (all
  # dimension 0) and `outputs_from_all_shards=True` give batch-dimension
  # splitting and concatenation.
  return shard(
      computation,
      inputs=inputs,
      num_shards=num_shards,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name,
      xla_options=xla_options)
2090
2091
@tf_export(v1=["tpu.rewrite"])
@traceback_utils.filter_traceback
def rewrite(
    computation: Callable[..., Any],
    inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    xla_options: Optional[XLAOptions] = None) -> Any:
  """Rewrites `computation` for execution on a TPU system.

  This is the single-replica form of `replicate()`: `inputs` is treated as the
  input list of one replica, and that replica's outputs are returned.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      tensors.

      `computation` may return a list of operations and tensors. Tensors must
      come before operations in the returned list.  The return value of
      `rewrite` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s constructed during `computation` will be executed when
      evaluating any of the returned output tensors, not just the ones returned.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
      Each input can be a nested structure containing values that are
      convertible to tensors. Note that passing an N-dimension list of
      compatible values will result in a N-dimension list of scalar tensors
      rather than a single Rank-N tensors. If you need different behavior,
      convert part of inputs to tensors with `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: A tuple containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate into removing these special cases.
  """
  # `replicate` takes one input list per replica; wrap `inputs` so it becomes
  # the sole replica's inputs (None stays None).
  per_replica_inputs = None if inputs is None else [inputs]
  replica_outputs = replicate(
      computation,
      per_replica_inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name,
      xla_options=xla_options)
  # Only one replica was requested, so return its outputs directly.
  # TODO(b/36647078) remove disable when pylint bug is fixed.
  return replica_outputs[0]  # pylint: disable=indexing-exception
2150
2151  # Operations that indicate some error in the user's inference graph.
2152
2153
2154_DENYLISTED_INFERENCE_OPS = set([
2155    "ReadVariableOp",
2156    "AssignVariableOp",
2157    "AssignAddVariableOp",
2158    "AssignSubVariableOp",
2159    "VarHandleOp",
2160    "Variable",
2161    "VariableV2",
2162])
2163
2164
def under_tpu_inference_context() -> bool:
  """Check if it is currently under `_TPUInferenceContext`.

  Starting from the default graph, walks that graph's chain of control-flow
  contexts (via `outer_context`). If no `_TPUInferenceContext` is found and the
  graph is a function graph (`function._FuncGraph` or `func_graph.FuncGraph`),
  the search continues in its outer graph.

  Returns:
    True if any control-flow context on the searched chain is a
    `_TPUInferenceContext`, False otherwise.
  """
  graph = ops.get_default_graph()
  while graph:
    context = graph._get_control_flow_context()  # pylint: disable=protected-access
    while context:
      if isinstance(context, _TPUInferenceContext):
        return True
      context = context.outer_context
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      graph = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      graph = graph.outer_graph
    else:
      return False
  # Reached when a function graph's outer graph is falsy (e.g. None). Without
  # this, the function fell off the end and returned None despite `-> bool`.
  return False
2180
2181
class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
  """Control-flow context wrapping ops built inside `rewrite_for_inference()`.

  Its presence on the context chain is what `under_tpu_inference_context()`
  looks for. When `check_ops` is true, each op added to this context (directly
  or from a nested context) whose type is in `_DENYLISTED_INFERENCE_OPS` is
  rejected with `NotImplementedError`. Ops and values are then forwarded to the
  enclosing context, if there is one.
  """

  def __init__(self, name: Text, check_ops: bool = True):
    super().__init__()
    self._name = name
    self._check_ops = check_ops

  def AddOp(self, op):
    """Registers `op`, created directly in this context."""
    self._AddOpInternal(op)

  def AddInnerOp(self, op):
    """Registers `op`, created in a context nested inside this one."""
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Validates `op` (if enabled) and propagates it outward."""
    denied = self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS
    if denied:
      raise NotImplementedError(
          f"Operation of type {op.type} ({op.name}) is not supported on the "
          "TPU for inference. Execution will fail if this op is used in the "
          "graph. Make sure your variables are using variable_scope.")
    enclosing = self._outer_context
    if enclosing:
      enclosing.AddInnerOp(op)

  def AddValue(self, val):
    """Returns `val` as seen by the enclosing context (or unchanged)."""
    enclosing = self._outer_context
    if not enclosing:
      return val
    return enclosing.AddValue(val)

  @property
  def grad_state(self):
    # This context never carries gradient-loop state.
    return None
2220
2221
def validate_inference_rewrite_for_variables(graph: ops.Graph):
  """Checks that `rewrite_for_inference()` inserted GuaranteeConst ops.

  `rewrite_for_inference()` installs a custom variable getter that wraps the
  value returned for each variable in a `GuaranteeConst` op. That getter only
  sees variables created and accessed through `tf.compat.v1.get_variable()`,
  so if none are used that way the graph ends up with no `GuaranteeConst` ops.
  Call this right after `tpu.rewrite_for_inference()` to detect that case.

  Typical usages:
    tpu.validate_inference_rewrite_for_variables(
        tf.compat.v1.get_default_graph())

    tpu.validate_inference_rewrite_for_variables(sess.graph)

  Args:
    graph: The graph to validate.
  Raises:
    RuntimeError: if `graph` contains no op of type "GuaranteeConst".
  """
  for candidate in graph.get_operations():
    if candidate.type == "GuaranteeConst":
      # One match is enough; the rewrite took effect.
      return
  raise RuntimeError(
      "No GuaranteeConst ops found in the graph after running "
      "tpu.rewrite_for_inference(...). Please check that you are using "
      "tf.get_variable() to create and access variables in your tpu "
      "computation.")
2249
2250
def rewrite_for_inference(
    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None) -> List[core_types.Tensor]:
  """Rewrites `computation` for inference on a TPU system.

     Other than 'rewriting' the computation to run on a TPU, if using variables
     in your computation, it moves the ReadVariableOps outside the TPU
     computation, and adds GuaranteeConst ops just after the ReadVariableOps.
     This mechanism works only if you are using tf.compat.v1.get_variable() to
     create and access variables in your tpu computation. You can validate
     whether this worked, by calling validate_inference_rewrite_for_variables()
     method immediately after this method to check whether GuaranteeConstOps
     were added to the graph.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      tensors. If the function returns m outputs, rewrite will return a list of
      m tensors.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: Forwarded unchanged to `rewrite()`, which documents it as deprecated
      and unused.
  Returns:
    A list of output tensors.
  """

  def guarantee_const_getter(getter, name, *args, **kwargs):
    # Clear any active control dependencies so the GuaranteeConst op is not
    # gated on ops from the surrounding computation.
    with ops.control_dependencies(None):
      return array_ops.guarantee_const(
          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")

  def wrapped_computation(*args, **kwargs):
    """Execute computation under `_TPUInferenceContext`."""
    context = _TPUInferenceContext(
        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
    try:
      context.Enter()

      vscope = variable_scope.get_variable_scope()
      prev_custom_getter = vscope.custom_getter
      prev_caching_device = vscope.caching_device
      vscope.set_custom_getter(guarantee_const_getter)
      # Cache variable reads on the variable op's own device (presumably what
      # keeps the reads outside the TPU computation, per the docstring).
      vscope.set_caching_device(lambda op: op.device)
      try:
        result = computation(*args, **kwargs)
      finally:
        # Restore the scope even if `computation` raises; otherwise the
        # GuaranteeConst getter and caching device leak into later graph
        # construction that uses this variable scope.
        vscope.set_custom_getter(prev_custom_getter)
        vscope.set_caching_device(prev_caching_device)
    finally:
      context.Exit()
    return result

  # pylint: disable=undefined-variable
  return rewrite(
      wrapped_computation,
      inputs=inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)
  # pylint: enable=undefined-variable
2319
2320
def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
  """Excludes disconnected ops listed in _UNCONNECTED_OPS_TO_PRUNE from XLA.

  An op qualifies when its type is in _UNCONNECTED_OPS_TO_PRUNE and none of its
  outputs has a consumer. Such ops are often left behind by graph
  construction rewiring (for instance TF-Hub). They never execute, but they
  would make XLA compilation fail, so the tpu_replicate attribute is cleared
  on them to keep them out of the XLA compile.

  Args:
    prune_graph: A tensorflow graph from which we wish to prune unconnected ops
      as listed in _UNCONNECTED_OPS_TO_PRUNE.
  """
  # Candidates: the top-level graph plus every entry of its function library.
  # NOTE(review): library entries that are not `ops.Graph` instances are
  # skipped below (presumably function definition objects) -- confirm whether
  # function bodies should be scanned as well.
  library = prune_graph._functions  # pylint: disable=protected-access
  candidate_graphs = [prune_graph, *library.values()]
  for graph in candidate_graphs:
    if not isinstance(graph, ops.Graph):
      continue
    for op in graph.get_operations():
      if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
        continue
      if any(output.consumers() for output in op.outputs):
        continue
      logging.info(
          "Pruning OP %s of type %s from XLA Compile due to "
          "it being disconnected.", op.name, op.type)
      op._clear_attr(_TPU_REPLICATE_ATTR)  # pylint: disable=protected-access
2351