1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15
16"""Library of TPU helper functions."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import enum
24import typing
25from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union
26
27from absl import logging
28import numpy as np
29from six.moves import xrange  # pylint: disable=redefined-builtin
30
31from tensorflow.compiler.tf2xla.python import xla as tf2xla
32from tensorflow.core.framework import attr_value_pb2
33from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
34from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
35from tensorflow.python.compiler.xla import xla
36from tensorflow.python.distribute import device_util
37from tensorflow.python.distribute import distribution_strategy_context
38from tensorflow.python.framework import auto_control_deps
39from tensorflow.python.framework import c_api_util
40from tensorflow.python.framework import composite_tensor
41from tensorflow.python.framework import config
42from tensorflow.python.framework import constant_op
43from tensorflow.python.framework import device as pydev
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import func_graph
47from tensorflow.python.framework import function
48from tensorflow.python.framework import ops
49from tensorflow.python.framework import tensor_shape
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import variable_scope
54from tensorflow.python.ops import variables
55from tensorflow.python.tpu import device_assignment as device_assignment_lib
56from tensorflow.python.tpu import tpu_feed
57from tensorflow.python.tpu import tpu_function
58from tensorflow.python.tpu import tpu_name_util
59from tensorflow.python.tpu.ops import tpu_ops
60from tensorflow.python.types import core as core_types
61from tensorflow.python.util import compat
62from tensorflow.python.util import nest
63from tensorflow.python.util import object_identity
64from tensorflow.python.util.tf_export import tf_export
65
66ops.NotDifferentiable("TPUReplicatedInput")
67
68# Operations that indicate some error in the user's graph, e.g. a placeholder
69# that's introduced outside of the infeed.
70_DENYLISTED_OPS = set([
71    "Placeholder",
72])
73
74# XLA doesn't currently support reading of intermediate tensors, thus some ops
75# are not supported.
76_UNSUPPORTED_OPS = set([
77    "AudioSummary",
78    "AudioSummaryV2",
79    "HistogramSummary",
80    "ImageSummary",
81    "MergeSummary",
82    "Print",
83    "ScalarSummary",
84    "TensorSummary",
85    "TensorSummaryV2",
86    ])
87
88# Ops which can be safely pruned from XLA compile if they have no consumers.
89#  These ops should also have no inputs.
90_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
91
92_MAX_WARNING_LINES = 5
93
94_TPU_REPLICATE_ATTR = "_tpu_replicate"
95_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
96_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
97_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
98_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"
99
100
101core = tpu_name_util.core
102
103
104def _tpu_system_device_name(job: Optional[Text]) -> Text:
105  """Returns the device name for the TPU_SYSTEM device of `job`."""
106  if job is None:
107    return "/device:TPU_SYSTEM:0"
108  else:
109    return "/job:%s/device:TPU_SYSTEM:0" % job
110
111
112@tf_export(v1=["tpu.initialize_system"])
113def initialize_system(
114    embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
115    job: Optional[Text] = None,
116    compilation_failure_closes_chips: bool = True
117) -> core_types.Tensor:
118  """Initializes a distributed TPU system for use with TensorFlow.
119
120  Args:
121    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
122      describing the desired configuration of the hardware embedding lookup
123      tables. If embedding_config is None, no hardware embeddings can be used.
124    job: The job (the XXX in TensorFlow device specification /job:XXX) that
125      contains the TPU devices that will be initialized. If job=None it is
126      assumed there is only one job in the TensorFlow flock, and an error will
127      be returned if this assumption does not hold.
128    compilation_failure_closes_chips: Whether to close TPU chips when a
129      compilation failure occurs.
130  Returns:
131    A serialized `TopologyProto` that describes the TPU system. Note:
132      the topology must be evaluated using `Session.run` before it can be used.
133  """
134  config_string = ("" if embedding_config is None else
135                   embedding_config.SerializeToString())
136  with ops.device(_tpu_system_device_name(job)):
137    topology = tpu_ops.configure_distributed_tpu(
138        compilation_failure_closes_chips=compilation_failure_closes_chips)
139
140    if embedding_config is None:
141      return topology
142
143    # This set of control dependencies is needed as this function is expected to
144    # return an op which will return the topology when executed, but we need to
145    # call the embedding initialization op between initializing the TPU and
146    # returning the topology.
147    with ops.control_dependencies([topology]):
148      embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
149    with ops.control_dependencies([embedding_init]):
150      return array_ops.identity(topology, name="tpu_init_identity")
151
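# A minimal usage sketch for `initialize_system`, assuming a TF1-style session
# and an assumed `tpu_grpc_url` (both are illustrative, not part of this API):
#
#   topology = tf.compat.v1.tpu.initialize_system()
#   with tf.compat.v1.Session(tpu_grpc_url) as sess:
#     serialized_topology = sess.run(topology)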
152
153def initialize_system_for_tpu_embedding(
154    embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
155    job: Optional[Text] = None,
156) -> ops.Operation:
157  """Initializes a distributed TPU Embedding system for use with TensorFlow.
158
159  The following two are equivalent:
160  1. initialize_system() with embedding_config.
161  2. initialize_system() without embedding_config, then
162     initialize_system_for_tpu_embedding().
163  initialize_system() should not be called with embedding_config if
164  initialize_system_for_tpu_embedding() is meant to be called later.
165
166  Args:
167    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
168      configuration of the hardware embedding lookup tables.
169    job: The job (the XXX in TensorFlow device specification /job:XXX) that
170      contains the TPU devices that will be initialized. If job=None it is
171      assumed there is only one job in the TensorFlow flock, and an error will
172      be returned if this assumption does not hold.
173
174  Returns:
175    A no-op.
176  """
177  config_string = embedding_config.SerializeToString()
178  with ops.device(_tpu_system_device_name(job)):
179    return tpu_ops.configure_tpu_embedding(config=config_string)
180
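# A short sketch of the two-step pattern described in the docstring above;
# `my_embedding_config` and the session setup are assumptions for illustration:
#
#   topology = initialize_system()  # no embedding_config here
#   embedding_init = initialize_system_for_tpu_embedding(my_embedding_config)
#   sess.run(topology)
#   sess.run(embedding_init)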
181
182@tf_export(v1=["tpu.shutdown_system"])
183def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
184  """Shuts down a running a distributed TPU system.
185
186  Args:
187    job: The job (the XXX in TensorFlow device specification /job:XXX) that
188      contains the TPU devices that will be shutdown. If job=None it is
189      assumed there is only one job in the TensorFlow flock, and an error will
190      be returned if this assumption does not hold.
191  """
192  with ops.device(_tpu_system_device_name(job)):
193    shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
194  return shutdown_distributed_tpu
195
196
197def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
198  """Returns the TPUReplicateContext and its associated graph."""
199  graph = ops.get_default_graph()
200  while graph is not None:
201    # pylint: disable=protected-access
202    context_ = graph._get_control_flow_context()
203    # pylint: enable=protected-access
204    while context_ is not None:
205      if isinstance(context_, TPUReplicateContext):
206        return context_, graph
207      context_ = context_.outer_context
208    graph = getattr(graph, "outer_graph", None)
209  raise ValueError("get_replicated_var_handle() called without "
210                   "TPUReplicateContext. This shouldn't happen. Please file "
211                   "a bug.")
212
213
214def is_tpu_strategy(strategy: Any) -> bool:
215  is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
216  clz = strategy.__class__
217  return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))
218
219
220def _enclosing_tpu_device_assignment(
221) -> Optional[device_assignment_lib.DeviceAssignment]:
222  if not distribution_strategy_context.has_strategy():
223    return None
224  strategy = distribution_strategy_context.get_strategy()
225  if not is_tpu_strategy(strategy):
226    return None
227  return strategy.extended._device_assignment  # pylint: disable=protected-access
228
229
230@auto_control_deps.register_acd_resource_resolver
231def tpu_replicated_input_resolver(
232    op: ops.Operation,
233    resource_reads: object_identity.ObjectIdentitySet,
234    resource_writes: object_identity.ObjectIdentitySet) -> bool:
235  """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
236  # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
237  # control deps on the replicated inputs.
238  if op.type == "TPUReplicatedInput":
239    if resource_reads or resource_writes:
240      resource_reads.clear()
241      resource_writes.clear()
242      return True
243    else:
244      return False
245  # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
246  # with the actual replicated inputs. This allows ACD to correctly add
247  # control deps when there are multiple calls to `run` in a
248  # `tf.function`.
249  def replace_with_unreplicated_resources(resource_inputs):
250    """Replaces handles in `resource_inputs` with their unreplicated inputs."""
251    to_remove = []
252    to_add = []
253    for resource in resource_inputs:
254      if resource.op.type == "TPUReplicatedInput":
255        to_remove.append(resource)
256        to_add.extend(resource.op.inputs)
257    for t in to_remove:
258      resource_inputs.discard(t)
259    resource_inputs.update(to_add)
260    return to_add or to_remove
261
262  return bool(replace_with_unreplicated_resources(resource_reads) or
263              replace_with_unreplicated_resources(resource_writes))
264
265
266class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
267  """A `ControlFlowContext` for nodes inside a TPU computation.
268
269  The primary role of `TPUReplicateContext` is to mark operators inside a
270  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
271  is a unique name.
272
273  We use a `ControlFlowContext` to perform the annotation since it integrates
274  with Tensorflow constructs like ResourceVariables. For example, if a
275  `ResourceVariable` is constructed inside a tpu.replicate() block, the
276  `ResourceVariable` implementation can use
277  `with ops.control_dependencies(None)` to build the variable's definition
278  outside the replicated computation.
279  """
280
281  def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
282    """Builds a new TPUReplicateContext.
283
284    Args:
285      name: a unique name for the context, used to populate the `_tpu_replicate`
286        attribute.
287      num_replicas: an integer that gives the number of replicas for the
288        computation.
289      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
290        inputs will have a control dependency on the pivot node. This ensures
291        that nodes are correctly included in any enclosing control flow
292        contexts.
293    """
294    super(TPUReplicateContext, self).__init__()
295    self._num_replicas = num_replicas
296    self._outer_device_function_stack = None
297    self._oc_dev_fn_stack = None
298    self._outside_compilation_cluster = None
299    self._outside_compilation_v2_context = None
300    self._outside_compilation_counter = 0
301    self._in_gradient_colocation = None
302    self._gradient_colocation_stack = []
303    self._host_compute_core = []
304    self._name = name
305    self._name_as_bytes = compat.as_bytes(name)
306    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
307        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
308    self._unsupported_ops = []
309    self._pivot = pivot
310    self._replicated_vars = {}
311
312  def get_replicated_var_handle(self,
313                                name: Text,
314                                vars_: Union[List[core_types.Tensor],
315                                             List[variables.Variable]],
316                                is_mirrored: bool = False,
317                                is_packed: bool = False) -> core_types.Tensor:
318    """Returns a variable handle for replicated TPU variable 'var'.
319
320    This is a method used by an experimental replicated variable implementation
321    and is not intended as a public API.
322
323    Args:
324      name: The common name of the variable.
325      vars_: The replicated TPU variables or handles.
326      is_mirrored: Whether the variables are mirrored, which guarantees the
327        values in each replica are always the same.
328      is_packed: Whether the replicated variables are packed into one variable.
329
330    Returns:
331      The handle of the TPU replicated input node.
332    """
333    device_assignment = _enclosing_tpu_device_assignment()
334    # We don't need to put device assignment as part of the replicated_vars key
335    # because each TPUReplicateContext will only have one device assignment.
336    handle = self._replicated_vars.get(name)
337    if handle is not None:
338      return handle
339
340    if device_assignment is not None and not is_packed:
341      # Find a variable copy for each replica in the device assignment.
342      # Note that the order of devices for replicas for the variable and the
343      # device assignment might not match.
344      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
345      devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_}
346      replicated_vars = []
347      for replica_id in range(device_assignment.num_replicas):
348        for logical_core in range(device_assignment.num_cores_per_replica):
349          device = device_util.canonicalize(
350              device_assignment.tpu_device(
351                  replica=replica_id, logical_core=logical_core, job=job_name))
352          if device in devices_to_vars:
353            replicated_vars.append(devices_to_vars[device])
354            break
355        else:
356          raise ValueError(
357              "Failed to find a variable on any device in replica {} for "
358              "current device assignment".format(replica_id))
359    else:
360      replicated_vars = vars_
361
362    # Builds a TPUReplicatedInput node for the variable, if one does not already
363    # exist. The TPUReplicatedInput node must belong to the enclosing
364    # control-flow scope of the TPUReplicateContext.
365    # TODO(phawkins): consider changing the contract of the TPU encapsulation
366    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
367    # instead.
368
369    _, graph = _enclosing_tpu_context_and_graph()
370    with graph.as_default():
371      # If replicated_vars are variables, get the handles. Note that this can be
372      # done inside TPUReplicateContext because replicated_vars.handle may
373      # create new ops.
374      if isinstance(replicated_vars[0], variables.Variable):
375        replicated_vars = [v.handle for v in replicated_vars]
376      # pylint: disable=protected-access
377      saved_context = graph._get_control_flow_context()
378      graph._set_control_flow_context(self.outer_context)
379      handle = tpu_ops.tpu_replicated_input(replicated_vars,
380                                            name=name + "/handle",
381                                            is_mirrored_variable=is_mirrored,
382                                            is_packed=is_packed)
383      graph._set_control_flow_context(saved_context)
384      # pylint: enable=protected-access
385    self._replicated_vars[name] = handle
386    return handle
387
388  def report_unsupported_operations(self) -> None:
389    if self._unsupported_ops:
390      op_str = "\n".join("  %s (%s)" % (op.type, op.name)
391                         for op in self._unsupported_ops[:_MAX_WARNING_LINES])
392      logging.warning("%d unsupported operations found: \n%s",
393                      len(self._unsupported_ops), op_str)
394      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
395        logging.warning("... and %d more" %
396                        (len(self._unsupported_ops) - _MAX_WARNING_LINES))
397
398  def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
399    if op is not None:
400      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
401        # If we are in TF 2 functions (control flow V2 functions, or
402        # tf.function()), we need to attach _xla_outside_compilation attribute
403        # directly because we are not in TPUReplicateContext.
404        try:
405          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
406        except ValueError:
407          # The attr was not present: do nothing.
408          return
409        parts = outside_attr.split(".")
410        cluster = parts[0] + "." + gradient_uid
411        self._outside_compilation_v2_context = OutsideCompilationV2Context(
412            cluster)
413        self._outside_compilation_v2_context.Enter()
414        return
415      self._gradient_colocation_stack.append(op)
416      if not self._outside_compilation_cluster:
417        try:
418          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
419          if self._in_gradient_colocation:
420            raise NotImplementedError(
421                "Cannot nest gradient colocation operations outside compilation"
422            )
423          if gradient_uid == "__unsupported__":
424            raise NotImplementedError(
425                "No gradient_uid calling gradient within outside_compilation")
426          # When we take the gradient of an op X in an outside_compilation
427          # cluster C in a forward computation we would like to put the ops
428          # corresponding to the gradient of X into a new outside_compilation
429          # cluster C'. However, if we take the gradient of X twice, the second
430          # one should get yet another new outside_compilation cluster C''.
431          #
432          # The mechanism we adopt is to use a 'root_cluster' which is the
433          # cluster that X was in before we took gradients, and a 'gradient_uid'
434          # which is different for every invocation of gradients, and put the
435          # gradient of X in cluster 'root_cluster.gradient_uid'.
436          #
437          # When taking a gradient of a gradient, some ops will be colocated
438          # with Op in the forward pass (e.g., cluster root_cluster) and some in
439          # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
440          # We need all of the grad-of-grad ops to be in the same cluster to
441          # avoid cyclic dependencies between clusters. We adopt a heuristic
442          # that puts any op clustered with root_cluster.<xxx> in
443          # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
444          self._in_gradient_colocation = op
445          parts = outside_attr.split(".")
446          cluster = parts[0] + "." + gradient_uid
447          self._EnterOutsideCompilationScope(cluster=cluster)
448        except ValueError:
449          # The attr was not present: do nothing.
450          pass
451
452  def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
453    if op is not None:
454      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
455        # Inside a TF2 tf.function or control flow graph and `op` was not
456        # marked to be outside compiled.
457        assert self._outside_compilation_v2_context is None
458        return
459      if self._outside_compilation_v2_context is not None:
460        # Inside a TF2 tf.function or control flow graph and `op` was
461        # marked to be outside compiled.
462        self._outside_compilation_v2_context.Exit()
463        self._outside_compilation_v2_context = None
464        return
465      if not self._gradient_colocation_stack:
466        raise errors.InternalError(
467            op.node_def, op,
468            f"Badly nested gradient colocation: empty stack when popping Op {op.name}"
469        )
470      last_op = self._gradient_colocation_stack.pop()
471      if op is last_op:
472        if op is self._in_gradient_colocation:
473          self._in_gradient_colocation = None
474          self._ExitOutsideCompilationScope()
475      else:
476        raise errors.InternalError(
477            op.node_def, op,
478            f"Badly nested gradient colocation, expected {last_op}, got {op.name}"
479        )
480
481  def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):
482
483    class FakeOp(object):
484      """A helper class to determine the current device.
485
486      Supports only the type and device set/get methods needed to run the
487      graph's _apply_device_functions method.
488      """
489
490      def __init__(self):
491        self._device = ""
492
493      @property
494      def type(self):
495        return "FakeOp"
496
497      @property
498      def device(self):
499        return self._device
500
501      def _set_device(self, device):
502        if isinstance(device, pydev.DeviceSpec):
503          self._device = device.to_string()
504        else:
505          self._device = device
506
507      def _set_device_from_string(self, device_str):
508        self._device = device_str
509
510    if self._outside_compilation_cluster:
511      raise NotImplementedError("Cannot nest outside_compilation clusters")
512    if cluster:
513      self._outside_compilation_cluster = cluster
514    else:
515      self._outside_compilation_cluster = str(self._outside_compilation_counter)
516      self._outside_compilation_counter += 1
517    graph = ops.get_default_graph()
518    fake_op = FakeOp()
519    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
520    device = pydev.DeviceSpec.from_string(fake_op.device)
521    if (device.device_type == "TPU_REPLICATED_CORE" and
522        device.device_index is not None):
523      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
524                                     str(device.device_index))
525    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
526    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access
527
528  def _ExitOutsideCompilationScope(self):
529    if not self._outside_compilation_cluster:
530      raise ValueError(
531          "Attempted to exit outside_compilation scope when not in scope")
532    self._outside_compilation_cluster = None
533    graph = ops.get_default_graph()
534    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access
535
536  def Enter(self) -> None:
537    if not self._outer_device_function_stack:
538      # Capture the device function stack at the time of first entry
539      # since that is the stack that will be used in outside_compilation.
540      graph = ops.get_default_graph()
541      # pylint: disable=protected-access
542      self._outer_device_function_stack = graph._device_function_stack.copy()
543      # pylint: enable=protected-access
544    super(TPUReplicateContext, self).Enter()
545
546  def HostComputeCore(self) -> List[Text]:
547    return self._host_compute_core
548
549  def _RemoveExternalControlEdges(
550      self, op: ops.Operation
551      ) -> Tuple[List[ops.Operation], List[ops.Operation]]:
552    """Remove any external control dependency on this op."""
553    internal_control_inputs = []
554    external_control_inputs = []
555    for x in op.control_inputs:
556      # pylint: disable=protected-access
557      is_internal_op = False
558      ctxt = x._get_control_flow_context()
559      while ctxt is not None:
560        if ctxt == self:
561          is_internal_op = True
562          break
563        ctxt = ctxt._outer_context
564      if is_internal_op:
565        internal_control_inputs.append(x)
566      else:
567        external_control_inputs.append(x)
568      # pylint: enable=protected-access
569    # pylint: disable=protected-access
570    op._remove_all_control_inputs()
571    op._add_control_inputs(internal_control_inputs)
572    # pylint: enable=protected-access
573    return internal_control_inputs, external_control_inputs
574
575  def AddOp(self, op: ops.Operation) -> None:
576    # pylint: disable=protected-access
577    if op.type in _DENYLISTED_OPS:
578      logging.error("Operation of type %s (%s) is not supported on the TPU. "
579                    "Execution will fail if this op is used in the graph. ",
580                    op.type, op.name)
581
582    if op.type in _UNSUPPORTED_OPS:
583      self._unsupported_ops.append(op)
584
585    if any(x.dtype._is_ref_dtype for x in op.inputs):
586      raise NotImplementedError(
587          f"Non-resource Variables are not supported inside TPU computations "
588          f"(operator name: {op.name})")
589
590    # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add
591    # the "_cloned" attribute and we should continue in that case.
592    if (_TPU_REPLICATE_ATTR in op.node_def.attr and
593        "_cloned" not in op.node_def.attr):
594      raise ValueError(f"TPU computations cannot be nested on op ({op})")
595    op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
596                          self._tpu_relicate_attr_buf.buffer)
597    if self._outside_compilation_cluster:
598      op._set_attr(
599          _OUTSIDE_COMPILATION_ATTR,
600          attr_value_pb2.AttrValue(
601              s=compat.as_bytes(self._outside_compilation_cluster)))
602    if self._num_replicas > 1 or not self._outside_compilation_cluster:
603      # Prevent feeding or fetching anything that is being compiled,
604      # and any replicated outside_compilation Op.
605      op.graph.prevent_feeding(op)
606      op.graph.prevent_fetching(op)
607
608    # Remove any control edges from outer control flow contexts. These may cause
609    # mismatched frame errors.
610    (internal_control_inputs,
611     external_control_inputs) = self._RemoveExternalControlEdges(op)
612
613    if not op.inputs:
614      # Add a control edge from the control pivot to this op.
615      if not internal_control_inputs:
616        # pylint: disable=protected-access
617        op._add_control_input(self.GetControlPivot())
618        # pylint: enable=protected-access
619    else:
620      for index in xrange(len(op.inputs)):
621        x = op.inputs[index]
622        real_x = self.AddValue(x)
623        if real_x is not x:
624          op._update_input(index, real_x)  # pylint: disable=protected-access
625
626    if external_control_inputs:
627      # Use an identity to pull control inputs as data inputs. Note that we
628      # ignore ops which don't have outputs. TODO(phawkins): fix that.
629      with ops.control_dependencies(None):
630        self.Enter()
631        external_control_inputs = [
632            array_ops.identity(x.outputs[0]).op
633            for x in external_control_inputs
634            if x.outputs
635        ]
636        self.Exit()
637      # pylint: disable=protected-access
638      op._add_control_inputs(external_control_inputs)
639      # pylint: enable=protected-access
640
641    # Mark op's outputs as seen by this context and any outer contexts.
642    output_names = [x.name for x in op.outputs]
643    context = self
644    while context is not None:
645      # pylint: disable=protected-access
646      context._values.update(output_names)
647      context = context._outer_context
648      # pylint: enable=protected-access
649
650    if self._outer_context:
651      self._outer_context.AddInnerOp(op)
652
653  def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
654    """Add `val` to the current context and its outer context recursively."""
655    if not self._outer_context:
656      return val
657
658    if val.name in self._values:
659      # Use the real value if it comes from outer context.
660      result = self._external_values.get(val.name)
661      return val if result is None else result
662
663    result = val
664    self._values.add(val.name)
665    if self._outer_context:
666      result = self._outer_context.AddValue(val)
667      self._values.add(result.name)
668
669    self._external_values[val.name] = result
670
671    return result
672
673  def AddInnerOp(self, op: ops.Operation):
674    self.AddOp(op)
675    if self._outer_context:
676      self._outer_context.AddInnerOp(op)
677
678  @property
679  def grad_state(self):
680    # Define the gradient loop state associated with the TPUReplicateContext
681    # to be None, as the TPUReplicateContext does not get nested, nor does the
682    # grad_state outside the TPUReplicateContext affect the graph inside, so
683    # the grad_state should be as if this were the top-level gradient state.
684    return None
685
686  @property
687  def back_prop(self):
688    """Forwards to the enclosing while context, if any."""
689    if self.GetWhileContext():
690      return self.GetWhileContext().back_prop
691    return False
692
693  def GetControlPivot(self) -> ops.Operation:
694    return self._pivot
695
696  def RequiresUniqueFunctionRetracing(self):
697    # More context: b/158152827. TPU stack uses the TPUReplicateContext to
698    # create replicated variable handles and cluster TPU computations, thus we
699    # always retrace a tf.function when the wrapped TPUReplicateContext changes.
700    return True
701
702
703class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
704  """The context for outside compilation in Tensorflow 2.0.
705
706  Every op added in this context will be assigned an _xla_outside_compilation
707  attribute.
708  """
709
710  def __init__(self, name: Text):
711    control_flow_ops.ControlFlowContext.__init__(self)
712    self._name = name
713
714  def AddOp(self, op: ops.Operation) -> None:
715    if self._outer_context:
716      self._outer_context.AddOp(op)
717    # pylint: disable=protected-access
718    op._set_attr("_xla_outside_compilation",
719                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
720    # pylint: enable=protected-access
721
722  def AddInnerOp(self, op: ops.Operation) -> None:
723    if self._outer_context:
724      self._outer_context.AddInnerOp(op)
725    # pylint: disable=protected-access
726    op._set_attr("_xla_outside_compilation",
727                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
728    # pylint: enable=protected-access
729
730  def to_control_flow_context_def(self, context_def, export_scope=None):
731    raise NotImplementedError
732
733
734@tf_export(v1=["tpu.outside_compilation"])
735def outside_compilation(
736    computation: Callable[..., Any], *args, **kwargs
737    ) -> Any:
738  """Builds part of a computation outside any current TPU replicate scope.
739
740  `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
741  instead of running on TPU. For example, users can run ops that are not
742  supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
743  ops on CPUs. The usage of outside compilation below will place the ops in
744  `computation_with_string_ops` on CPU.
745
746  Example usage:
747
748  ```python
749  def computation_with_string_ops(x):
750    # string types are not supported on TPUs, so the ops below must
751    # run on CPU instead.
752    output = tf.strings.format('1{}', x)
753    return tf.strings.to_number(output)
754
755  def tpu_computation():
756    # Expected output is 11.
757    output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
758  ```
759
760  Outside compilation should be called inside TPUReplicateContext. That is,
761  `tf.tpu.outside_compilation()` should be called inside a function that is
762  passed to `tpu.split_compile_and_replicate()` -- this is implied when
763  outside compilation is invoked inside a function passed to TPUStrategy
764  `run()`. If invoked outside of TPUReplicateContext,
765  then this simply returns the result of `computation` and is therefore a
766  no-op. Note that outside compilation is different from
767  `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
768  outside compilation is replicated and executed separately for each
769  replica. On the other hand, `merge_call()` requires a `merge_fn`
770  to aggregate the inputs from different replicas and is executed only
771  once.
772
773  For variables placed on a TPU device, which includes variables created
774  inside a TPUStrategy scope, outside compilation logic must not include
775  variable reads or writes. For variables placed on the host, which is the
776  case when variables are created via TPUEstimator, reads and writes are
777  allowed only if the variable is not accessed by any other ops in the TPU
778  computation. Variable reads and writes from the outside compilation
779  cluster are not visible to the TPU computation and vice versa. Therefore,
780  if the outside compilation logic reads or writes such host variables and
781  the TPU computation also accesses them, this may lead to deadlock.
782
783  Internally, `tf.tpu.outside_compilation()` adds outside compilation
784  attributes to all ops in `computation`. During a later graph pass, the ops
785  with the outside compilation attribute are extracted out and replicated
786  into a host-side graph. Inputs to this extracted host-side graph are sent
787  from the TPU computation graph to the host graph via a pair of XlaSendToHost
788  and XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()`
789  may result in tensor transfers between TPU and CPU, leading to a non-trivial
790  performance impact.
791
792  Args:
793    computation: A Python function that builds the computation to
794      place on the host.
795    *args: the positional arguments for the computation.
796    **kwargs: the keyword arguments for the computation.
797
798  Returns:
799    The Tensors returned by computation.
800  """
801  args = [] if args is None else args
802  graph = ops.get_default_graph()
803
804  # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
805  # we need to attach _xla_outside_compilation attribute directly because we are
806  # not in TPUReplicateContext.
807  if isinstance(graph, func_graph.FuncGraph):
808    try:
809      tpu_context, _ = _enclosing_tpu_context_and_graph()
810    except ValueError:
811      logging.warning(
812          "Outside compilation attempted outside TPUReplicateContext "
813          "scope. As no enclosing TPUReplicateContext can be found, "
814          "returning the result of `computation` as is.")
815      return computation(*args, **kwargs)
816
817    # pylint: disable=protected-access
818    outside_compilation_name = str(tpu_context._outside_compilation_counter)
819    tpu_context._outside_compilation_counter = (
820        tpu_context._outside_compilation_counter + 1)
821    # pylint: enable=protected-access
822
823    outside_compilation_context = OutsideCompilationV2Context(
824        outside_compilation_name)
825    outside_compilation_context.Enter()
826    args = [] if args is None else args
827    retval = computation(*args, **kwargs)
828    outside_compilation_context.Exit()
829    return retval
830
831  # If we are in a TPUReplicateContext, signal that we are now
832  # outside_compilation
833  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
834  context = initial_context
835  while context:
836    if isinstance(context, TPUReplicateContext):
837      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
838    context = context.outer_context
839
840  retval = computation(*args, **kwargs)
841
842  # If we are in a TPUReplicateContext, signal that we are no longer
843  # outside_compilation
844  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
845  if initial_context is not final_context:
846    raise NotImplementedError(
847        "Control-flow context cannot be different at start and end of an "
848        "outside_compilation scope")
849  context = initial_context
850  while context:
851    if isinstance(context, TPUReplicateContext):
852      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
853    context = context.outer_context
854
855  return retval
856
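# A hedged sketch of the TPUStrategy pattern described in the docstring above;
# `resolver` and the doubling computation are illustrative assumptions:
#
#   strategy = tf.distribute.TPUStrategy(resolver)
#
#   def host_call(x):
#     tf.print("per-replica value:", x)  # print ops run on the host
#     return x
#
#   @tf.function
#   def step(x):
#     def tpu_fn(x):
#       y = x * 2.0
#       y = tf.compat.v1.tpu.outside_compilation(host_call, y)
#       return y + 1.0
#     return strategy.run(tpu_fn, args=(x,))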
857
858@tf_export(v1=["tpu.PaddingSpec"])
859class PaddingSpec(enum.IntEnum):
860  """Represents the type of padding policies for tpu.replicate."""
861  # By default the policy is set to AUTO; the dynamic input shape dimension
862  # will be padded to the maximum across all the replicas.
863  AUTO = 0
864  # Bucketize the dynamic input shape dimension into a power of 2.
865  POWER_OF_TWO = 1
866
867
868@tf_export("tpu.XLAOptions")
869class XLAOptions(
870    collections.namedtuple("XLAOptions", [
871        "use_spmd_for_xla_partitioning",
872        "enable_xla_dynamic_padder",
873    ])):
874  """XLA compilation options.
875
876  Attributes:
877    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
878      partitioner instead of MPMD partitioner when compiler partitioning is
879      requested.
880    enable_xla_dynamic_padder: Boolean. Whether to enable the XLA dynamic
881      padder infrastructure to handle dynamically shaped inputs inside XLA.
882      True by default. Disabling this may cause correctness issues with
883      dynamically shaped inputs, as XLA will simply assume the inputs have
884      padded shapes. However, users can optionally set it to False to improve
885      device time if masking is already handled on the user side.
886  """
887
888  def __new__(cls,
889              use_spmd_for_xla_partitioning=True,
890              enable_xla_dynamic_padder=True):
891    return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning,
892                                          enable_xla_dynamic_padder)
893
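# A brief sketch of passing `XLAOptions` to `replicate` (defined below); the
# option values and inputs are illustrative only:
#
#   options = XLAOptions(use_spmd_for_xla_partitioning=False,
#                        enable_xla_dynamic_padder=True)
#   outputs = tf.compat.v1.tpu.replicate(
#       computation, inputs=[[x], [y]], xla_options=options)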
894
895@tf_export(v1=["tpu.replicate"])
896def replicate(
897    computation: Callable[..., Any],
898    inputs: Optional[List[List[core_types.Tensor]]] = None,
899    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
900    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
901    name: Optional[Text] = None,
902    maximum_shapes: Optional[Any] = None,
903    padding_spec: Optional[PaddingSpec] = None,
904    xla_options: Optional[XLAOptions] = None) -> List[Any]:
905  """Builds a graph operator that runs a replicated TPU computation.
906
907  Example for the basic usage that `inputs` has static shape:
908
909  ```python
910
911  def computation(x):
912    x = x + 1
913    return tf.math.reduce_mean(x)
914
915  x = tf.convert_to_tensor([1., 2., 3.])
916  y = tf.convert_to_tensor([4., 5., 6.])
917  tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
918  ```
919
920  If the `inputs` have dynamic shapes and you would like to automatically
921  bucketize the inputs to avoid XLA recompilation, see the advanced example
922  below:
923
924  ```python
925
926  def computation(x):
927    x = x + 1
928    return tf.math.reduce_mean(x)
929
930  # Assume input tensors in two replicas `x` and `y` both have dynamic shape
931  # ([None, 2]).
932  tf.compat.v1.tpu.replicate(
933    computation,
934    inputs=[x, y],
935    maximum_shapes=[tf.TensorShape([None, None])],
936    padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
937  ```
938
939  Args:
940    computation: A Python function that builds the computation to replicate.
941    inputs: A list of lists of input tensors or `None` (equivalent to
942      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
943      have the same number of inputs. Each input can be a nested structure
944      containing values that are convertible to tensors. Note that passing an
945      N-dimension list of compatible values will result in an N-dimension list
946      of scalar tensors rather than a single Rank-N tensor. If you need different
947      behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
948    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
949      of arguments as inputs to computation.
950    device_assignment: If not `None`, a `DeviceAssignment` describing the
951      mapping between logical cores in the computation with physical cores in
952      the TPU topology. Uses a default device assignment if `None`. The
953      `DeviceAssignment` may be omitted if each replica of the computation uses
954      only one core, and there is either only one replica, or the number of
955      replicas is equal to the number of cores in the TPU system.
956    name: (Deprecated) Does nothing.
957    maximum_shapes: A nested structure of tf.TensorShape representing the shape
958      to which the respective component of each input element in each replica
959      should be padded. Any unknown dimensions (e.g.
960      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
961      object) will be padded to the maximum size of that dimension over all
962      replicas. The structure of `maximum_shapes` needs to be the same as
963      `inputs[0]`.
964    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
965      padding policy when the `inputs` to `tpu.replicate` is dynamic.
966      One usage is to enable automatic bucketizing on the inputs by setting the
967      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
968      recompilation in the XLA side.
969    xla_options: An instance of `tpu.XLAOptions` which indicates the options
970      passed to XLA compiler. Use `None` for default options.
971  Returns:
972    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
973    structure the same as what computation() returns, with a few exceptions.
974
975    Exceptions include:
976      1) None output: a NoOp would be returned which control-depends on
977         computation.
978      2) Single value output: A tuple containing the value would be returned.
979      3) Operation-only outputs: a NoOp would be returned which
980         control-depends on computation.
981      TODO(b/121383831): Investigate into removing these special cases.
982
983  Raises:
984    ValueError: If all replicas do not have equal numbers of input tensors.
985    ValueError: If the number of inputs per replica does not match
986      the number of formal parameters to `computation`.
987    ValueError: If the static `inputs` dimensions don't match with the values
988      given in `maximum_shapes`.
989    ValueError: If the structure of inputs per replica does not match
990      the structure of `maximum_shapes`.
991  """
992  return split_compile_and_replicate(
993      computation,
994      inputs,
995      infeed_queue,
996      device_assignment,
997      name,
998      maximum_shapes=maximum_shapes,
999      padding_spec=padding_spec,
1000      xla_options=xla_options)[1]
1001
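# A small sketch of consuming the `replicate` output structure documented
# above; two replicas with a scalar-returning `computation` are assumed:
#
#   per_replica = tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
#   # `per_replica` is indexed by replica; a single-value computation comes
#   # back as a one-element tuple per replica.
#   total = tf.math.add_n([outputs[0] for outputs in per_replica])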
1002
1003def _ceil_to_pow_of_n(x, n):
1004  """Ceil input `x` to power of `n`."""
1005  x = math_ops.cast(x, dtypes.float32)
1006  lognx = math_ops.log(x) / math_ops.log(n * 1.0)
1007  lognx = math_ops.ceil(lognx)
1008  result = math_ops.pow(n * 1.0, lognx)
1009  result = math_ops.cast(result, dtypes.int32)
1010  return result
1011
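# Worked example for `_ceil_to_pow_of_n`: with x = 5 and n = 2, log2(5) is
# about 2.32, which ceils to 3, so the result is 2**3 = 8; an exact power such
# as x = 8 maps to itself.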
1012
1013def _pad_all_input(
1014    inputs: Iterable[core_types.Tensor],
1015    padded_shapes: List[Optional[tensor_shape.TensorShape]],
1016    padding_spec: PaddingSpec
1017) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
1018  """Pad all input tensors given padded_shapes.
1019
1020  The real shape tensors will be concatenated with the padded original inputs.
1021
1022  Args:
1023    inputs: The original inputs.
1024    padded_shapes: A list of padded shapes for each input. If an entry is None,
1025      no padding is performed.
1026    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
1027      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
1028      One usage is to enable automatic bucketizing on the inputs by setting the
1029      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
1030      recompilation in the XLA side.
1031
1032  Returns:
1033    The padded inputs and a PaddingMap list which maps the padded input
1034    dimension to the real shape argument index.
1035  """
1036  # maximum_static_shapes[idx][i] indicates the maximum static size of ith
1037  # dimension of the idx input among all the replicas.
1038  maximum_static_shapes = []
1039  # need_padding[idx][i] indicates whether the ith dimension of the idx input
1040  # needs padding.
1041  need_padding = []
1042  input_shape_tensors = []
1043  for core_idx, inputs_per_core in enumerate(inputs):
1044    for idx, input_tensor in enumerate(inputs_per_core):
1045      input_shape = input_tensor.get_shape().as_list()
1046      if core_idx == 0:
1047        input_shape_tensors.append([])
1048        maximum_static_shapes.append(input_shape)
1049        need_padding.append(np.full_like(input_shape, False, dtype=bool))
1050      else:
1051        for i, s in enumerate(input_shape):
1052          if s is None or s != maximum_static_shapes[idx][i]:
1053            need_padding[idx][i] = True
1054        maximum_static_shapes[idx] = max(input_shape,
1055                                         maximum_static_shapes[idx])
1056
1057      # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
1058      real_input_shape = array_ops.shape(input_tensor)
1059      real_input_shape.op._set_attr(  # pylint: disable=protected-access
1060          _POST_DEVICE_REWRITE_ATTR,
1061          attr_value_pb2.AttrValue(b=True))
1062      input_shape_tensors[idx].append(real_input_shape)
1063
1064  maximum_shapes = []
1065  for shapes_per_input in input_shape_tensors:
1066    maximum_shapes.append(
1067        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))
1068
1069  padded_inputs = []
1070  real_shapes = []
1071  padding_maps = []
1072  for core_idx, inputs_per_core in enumerate(inputs):
1073    padded_inputs.append([])
1074    real_shapes.append([])
1075    real_shape_idx = len(inputs_per_core) - 1
1076    for idx, input_tensor in enumerate(inputs_per_core):
1077      input_shape_tensor = input_shape_tensors[idx][core_idx]
1078      input_shape = input_tensor.get_shape().as_list()
1079      padded_shape = padded_shapes[idx]
1080
1081      # If we have no padded_shape, then skip padding.
1082      if any(need_padding[idx]) and padded_shape is not None:
1083        for i, s in enumerate(input_shape):
1084          if need_padding[idx][i]:
1085            if core_idx == 0:
1086              real_shape_idx += 1
1087              padding_map = dynamic_padding.PaddingMap()
1088              padding_map.arg_index = idx
1089              padding_map.shape_index = i
1090              padding_map.padding_arg_index = real_shape_idx
1091              padding_maps.append(padding_map)
1092            real_shapes[core_idx].append(
1093                math_ops.cast(input_shape_tensor[i], dtypes.int32))
1094
1095        paddings = []
1096        for i, s in enumerate(padded_shape.dims):
1097          if need_padding[idx][i]:
1098            # The minimum padded dimension size is 2, as XLA doesn't support
1099            # a dynamic dimension of size 1.
1100            minimum_dynamic_dim_size = 2
1101            if s.value is not None:
1102              # Pad to the given maximum value.
1103              max_dim_size = max(s.value, minimum_dynamic_dim_size)
1104            else:
1105              # If maximum value is not given, then pad to the maximum dimension
1106              # among all the cores.
1107              max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
1108                                              minimum_dynamic_dim_size)
1109              if padding_spec == PaddingSpec.POWER_OF_TWO:
1110                max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
1111            # Pad to the given maximum value.
1112            padding = [0, max_dim_size - input_shape_tensor[i]]
1113          else:
1114            padding = [0, 0]
1115          paddings.append(padding)
1116
1117        if input_tensor.get_shape().is_fully_defined():
1118          # TODO(rxsang): This is a hack to make sure padded_input has dynamic
1119          # shapes, so any tf.size/tf.shape op performed on it won't be constant
1120          # folded. Do we have better ways to do it?
1121          padded_input = control_flow_ops.cond(
1122              array_ops.constant(True),
1123              lambda: array_ops.pad(input_tensor, paddings),  # pylint: disable=cell-var-from-loop
1124              lambda: input_tensor)
1125        else:
1126          padded_input = array_ops.pad(input_tensor, paddings)
1127
1128        # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
1129        padded_input.op._set_attr(  # pylint: disable=protected-access
1130            _POST_DEVICE_REWRITE_ATTR,
1131            attr_value_pb2.AttrValue(b=True))
1132
1133        padded_inputs[core_idx].append(padded_input)
1134      else:
1135        padded_inputs[core_idx].append(input_tensor)
1136
1137  num_replicas = len(padded_inputs)
1138  for i in range(num_replicas):
1139    padded_inputs[i].extend(real_shapes[i])
1140
1141  return padded_inputs, padding_maps
1142
1143
1144def _flatten_and_filter_composite(maybe_composite, non_composite_output,
1145                                  composite_output=None):
1146  """For an input, replaced the input by a tuple if the input is composite.
1147
1148  If `maybe_composite` is not composite, return the parameter
1149  `non_composite_output` otherwise return a tuple which consists of the value of
1150  the parameter `composite_output` the same number of times as there are
1151  components of the composite tensor.
1152
1153  This is useful for computing a mask when flattening nested data with
1154  `expand_composites=True`. For example
1155
1156  ```python
1157  nest.flatten(data, expand_composites=True)
1158  ```
1159
1160  and
1161
1162  ```python
1163  nest.flatten(nest.map_structure(
1164      lambda x: _flatten_and_filter_composite(x, False, True), data))
1165  ```
1166
1167  will have the same length, and the second will be True if the tensor in the
1168  first is derived from expanding a composite tensor.
1169
1170  Args:
1171    maybe_composite: A value to test for being a composite tensor.
1172    non_composite_output: The value to return when `maybe_composite` is not a
1173      composite.
1174    composite_output: the value to fill the output tuple with if
1175      `maybe_composite` is a composite.
1176
1177  Returns:
1178    `non_composite_output` or a tuple with multiple copies of
1179    `composite_output`.
1180  """
1181
1182  if isinstance(maybe_composite, composite_tensor.CompositeTensor):
1183    num_components = len(nest.flatten(maybe_composite, expand_composites=True))
1184    return (composite_output,) * num_components
1185  return non_composite_output
1186
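# A small illustration of `_flatten_and_filter_composite` (values assumed):
# a composite input (e.g. a SparseTensor) that flattens into k component
# tensors yields (True,) * k from
#   _flatten_and_filter_composite(maybe_composite, False, True)
# while a plain tensor yields False, so the resulting mask lines up with
# nest.flatten(..., expand_composites=True).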
1187
1188def split_compile_and_replicate(
1189    computation: Callable[..., Any],
1190    inputs: Optional[List[List[core_types.Tensor]]] = None,
1191    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1192    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1193    name: Optional[Text] = None,
1194    use_tpu: bool = True,
1195    maximum_shapes: Optional[Any] = None,
1196    padding_spec: Optional[PaddingSpec] = None,
1197    xla_options: Optional[XLAOptions] = None,
1198) -> List[List[core_types.Tensor]]:
1199  """Builds graph operators that runs compilation and replicated computation.
1200
1201  This is a lower-level interface than `replicate` that returns separate compile
1202  and execute output tensors. In the generated graph the compile op feeds into
1203  the execute op and no additional compilation is incurred when running the
1204  compile op before the execute op. The compile op returns additional
1205  information about the compilation but does not return the compiled program.
1206
1207  Args:
1208    computation: A Python function that builds the computation to replicate.
1209    inputs: A list of lists of input tensors or `None` (equivalent to
1210      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
1211      have the same number of inputs. Each input can be a nested structure
1212      containing values that are convertible to tensors. Note that passing an
1213      N-dimension list of compatible values will result in a N-dimension list of
1214      scalar tensors rather than a single Rank-N tensors. If you need different
1215      behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
1216    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1217      of arguments as inputs to computation.
1218    device_assignment: If not `None`, a `DeviceAssignment` describing the
1219      mapping between logical cores in the computation with physical cores in
1220      the TPU topology. Uses a default device assignment if `None`. The
1221      `DeviceAssignment` may be omitted if each replica of the computation uses
1222      only one core, and there is either only one replica, or the number of
1223      replicas is equal to the number of cores in the TPU system.
1224    name: (Deprecated) Does nothing.
1225    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
1226      backends. Currently, only supports a default placement (computation is
1227      placed on GPU if one is available, and on CPU if not).
1228    maximum_shapes: A nested structure of tf.TensorShape representing the shape
1229      to which the respective component of each input element in each replica
1230      should be padded. Any unknown dimensions (e.g.
1231      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
1232      object) will be padded to the maximum size of that dimension over all
1233      replicas. The structure of `maximum_shapes` needs to be the same as
1234      `inputs[0]`.
1235    padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
1236      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
1237      One usage is to enable automatic bucketizing on the inputs by setting the
1238      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
1239      recompilation in the XLA side.
1240    xla_options: An instance of `tpu.XLAOptions` which indicates the options
1241      passed to XLA compiler. Use `None` for default options.
1242
1243  Returns:
1244    A list of lists with the first list corresponding to the compile op and the
1245    second a list of output tensors, indexed by `[replica_num][output_num]`.
1246  Raises:
1247    ValueError: If all replicas do not have equal numbers of input tensors.
1248    ValueError: If the number of inputs per replica does not match
1249      the number of formal parameters to `computation`.
1250    ValueError: If the static `inputs` dimensions don't match with the values
1251      given in `maximum_shapes`.
1252    ValueError: If the structure of inputs per replica does not match
1253      the structure of `maximum_shapes`.
1254  """
1255  del name
1256  inputs = [[]] if inputs is None else inputs
1257  xla_options = xla_options or XLAOptions()
1258
1259  metadata_kwargs = {}
1260  if device_assignment is not None:
1261    # Turn the Numpy array into a flattened list so we can pass it as an
1262    # operator attribute.
1263    metadata_kwargs = {
1264        "topology":
1265            device_assignment.topology.serialized(),
1266        "device_assignment":
1267            device_assignment.core_assignment.flatten().tolist()
1268    }
1269    metadata_kwargs["num_cores_per_replica"] = (
1270        device_assignment.num_cores_per_replica)
1271
1272  # This entry is used for enabling automatic outside compilation.
1273  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
1274  if config.get_soft_device_placement():
1275    logging.info("Automatic outside compilation is enabled. "
1276                 "Ops without XLA kernels will be automatically "
1277                 "placed on CPU.")
1278
1279  if not isinstance(inputs, list):
1280    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, "
1281                    f"received {type(inputs)}")
1282  if any(not isinstance(inp, (list, tuple)) for inp in inputs):
1283    raise TypeError(
1284        "tpu.replicate() inputs must be a list of lists/tuples, "
1285        f"received types: {[type(inp) for inp in inputs]}")
1286
1287  num_replicas = len(inputs)
1288
1289  # No replicas? Nothing to do.
1290  if num_replicas == 0:
1291    return []
1292
1293  # Checks all replicas have the same structure.
1294  for i in xrange(1, num_replicas):
1295    nest.assert_same_structure(inputs[0], inputs[i])
1296
1297  # Flatten inputs. This structure may contain None values, which will be
1298  # handled later.
1299  flat_inputs_with_nones = [
1300      nest.flatten(per_replica_input, expand_composites=True)
1301      for per_replica_input in inputs
1302  ]
  # A mask parallel to one replica's flattened inputs that is True for tensors
  # coming from composites.
1305  is_composite = nest.flatten(nest.map_structure(
1306      lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
1307
1308  # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
1309  # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
1310  flat_inputs = []
1311  for inp in flat_inputs_with_nones:
1312    flat_inputs.append([
1313        constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
1314        for x in inp
1315    ])
1316
1317  # Verifies that all replicas have matching numbers and types of inputs
1318  flat_input_types = [x.dtype for x in flat_inputs[0]]
1319  input_arity = len(inputs[0])
1320  flat_input_arity = len(flat_input_types)
1321  for i in range(num_replicas):
1322    if len(inputs[i]) != input_arity:
1323      raise ValueError("Replicas must have the same number of inputs. "
1324                       "Replica 0 had {} inputs, replica {} had {} "
1325                       "inputs.".format(input_arity, i, len(inputs[i])))
1326
1327    types = [x.dtype for x in flat_inputs[i]]
1328    if types != flat_input_types:
1329      raise ValueError("Replicas must have matching input types. Replica 0 had "
1330                       "input types {}, replica {} had input types {}".format(
1331                           flat_input_types, i, types))
1332
1333  arg_error = xla.check_function_argument_count(
1334      computation, input_arity, infeed_queue)
1335  if arg_error is not None:
1336    if infeed_queue is None:
1337      raise TypeError(
1338          "Supplied computation cannot be called with the specified inputs. "
1339          f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, "
1340          f"but the computation needs{arg_error}")
1341    else:
1342      raise TypeError(
1343          "Supplied computation cannot be called with the specified inputs. "
1344          f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} ",
1345          f"and {infeed_queue.number_of_tuple_elements} additional inputs "
1346          f"from infeed, but the computation needs {arg_error}")
1347
1348  dynamic_shape_inputs = False
1349  if maximum_shapes:
1350    if infeed_queue:
1351      raise ValueError(
1352          "Dynamic input shapes are not supported with infeed queues")
1353
1354    # Make sure maximum_shapes has the same structure as inputs.
1355    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
1356
1357    # Flatten padded shapes:
1358    # For composite tensor components, we don't want to pad them. For each
1359    # entry of maximum_shapes that corresponds to a composite tensor, replace it
1360    # by a tuple of Nones of the same length as the number of components of the
1361    # composite tensor. When we flatten a second time, this makes
1362    # flat_maximum_shapes have the same length as flat_inputs[i]. We can then
1363    # avoid padding these tensors. The assumption is that they will be used by
1364    # outside compilation or that the components are statically shaped and will
1365    # be used by tpu compatible ops.
1366    flat_maximum_shapes = nest.flatten(
1367        [_flatten_and_filter_composite(x, y)
1368         for x, y in zip(nest.flatten(inputs[0]),
1369                         nest.flatten(maximum_shapes))])
1370    flat_maximum_shapes = [
1371        tensor_shape.TensorShape(s) if s is not None else None
1372        for s in flat_maximum_shapes
1373    ]
1374    nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
1375                               check_types=False)
1376
1377    unpadded_inputs = flat_inputs
1378    flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
1379                                               flat_maximum_shapes,
1380                                               padding_spec)
1381    if padding_maps:
1382      dynamic_shape_inputs = True
1383      logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0])
1384
1385  metadata_kwargs["step_marker_location"] = getattr(
1386      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
1387  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
1388      xla_options.use_spmd_for_xla_partitioning
1389
1390  graph = ops.get_default_graph()
1391
1392  # Fan-in: Builds a TPUReplicatedInput node for each input.
1393  flat_replicated_inputs = []
1394  for i in range(0, len(flat_inputs[0])):
1395    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
1396    flat_replicated_inputs.append(
1397        tpu_ops.tpu_replicated_input(
1398            replicas, name="input{}".format(i), index=i))
1399  if isinstance(graph, func_graph.FuncGraph):
    # When we are in a TensorFlow 2.0 function, 'graph' will be a FuncGraph
    # object. If both the outside graph and this function have a TPU cluster,
    # they will have the same cluster name, which causes problems (because we
    # lower functional ops in TensorFlow 2.0). Append the function name to
    # 'cluster_name' to avoid cluster name collisions.
1405    cluster_name = graph.unique_name("cluster_" + graph.name)
1406  else:
1407    cluster_name = graph.unique_name("cluster")
1408  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
1409  pivot._set_attr(_PIVOT_FOR_CLUSTER,  # pylint: disable=protected-access
1410                  attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
1411  context = TPUReplicateContext(
1412      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
1413  try:
1414    context.Enter()
1415
1416    metadata = tpu_ops.tpu_replicate_metadata(
1417        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
1418
1419    with tpu_function.tpu_shard_context(
1420        num_replicas), ops.control_dependencies([metadata]):
1421
1422      if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder:
1423        for padding_map in padding_maps:
1424          input_shape = flat_replicated_inputs[padding_map.arg_index].shape
1425          flat_replicated_inputs[
1426              padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
1427                  flat_replicated_inputs[padding_map.arg_index],
1428                  padding_map.shape_index,
1429                  flat_replicated_inputs[padding_map.padding_arg_index])
1430          flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)
1431
1432      # Add identity ops so even unused inputs are "consumed" by the
1433      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
1434      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
1435      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
1436      flat_replicated_inputs = [
1437          array_ops.identity(x, name="replicated_input_{}".format(i))
1438          for i, x in enumerate(flat_replicated_inputs)
1439      ]
1440      for i, composite in zip(flat_replicated_inputs, is_composite):
1441        # pylint: disable=protected-access
        # Add an attribute to the identity node so that it can be removed in
        # the encapsulate TPU computation pass if unused. However, we don't
        # remove inputs when dynamic padding is enabled.
1445        # TODO(rxsang): Use other ways except argument index in padding_map so
1446        # outside compilation can work with dynamic padding correctly.
1447        if not dynamic_shape_inputs or composite:
1448          i.op._set_attr("_tpu_input_identity",
1449                         attr_value_pb2.AttrValue(b=True))
1450        # pylint: enable=protected-access
1451
1452      # Clobber replicated placeholders with Nones.
1453      computation_inputs = [
1454          None if inp is None else replicated for replicated, inp in zip(
1455              flat_replicated_inputs, flat_inputs_with_nones[0])
1456      ]
1457
1458      # Unflatten the computation inputs to match original input structure.
1459      computation_inputs = nest.pack_sequence_as(
1460          structure=inputs[0],
1461          flat_sequence=computation_inputs[:flat_input_arity],
1462          expand_composites=True)
1463
1464      # If there is an infeed queue, adds the dequeued values to the
1465      # computation's inputs.
1466      if infeed_queue is not None:
1467        infeed_queue.set_number_of_shards(num_replicas)
1468        for t in infeed_queue.generate_dequeue_op():
1469          computation_inputs.append(t)
1470
1471      # Only resource variables work inside a TPU computation, so turn on
1472      # resource variables for the computation.
1473      # TODO(phawkins): consider removing this code. It will
1474      # be less confusing to clients if they knowingly choose to use resource
1475      # variables.
      # Partitioned variables are not supported (b/112311320).
1477      vscope = variable_scope.get_variable_scope()
1478      saved_use_resource = vscope.use_resource
1479      saved_custom_getter = vscope.custom_getter
1480
1481      def custom_getter(getter, name, *args, **kwargs):
1482        """Variables on TPU have a few restrictions."""
1483        partitioner = kwargs.get("partitioner", None)
1484        if partitioner is not None:
1485          kwargs["partitioner"] = None
1486          logging.warning(
1487              "Partitioned variables are not supported on TPU. Got "
1488              "`partitioner` that is %s for variable %s. "
1489              "Setting `partitioner` to `None`.", partitioner, name)
1490        if saved_custom_getter is None:
1491          return getter(name, *args, **kwargs)
1492        else:
1493          return saved_custom_getter(getter, name, *args, **kwargs)
1494
1495      vscope.set_use_resource(True)
1496      vscope.set_custom_getter(custom_getter)
1497
1498      outputs = computation(*computation_inputs)
1499
1500      vscope.set_use_resource(saved_use_resource)
1501      vscope.set_custom_getter(saved_custom_getter)
1502
1503    outputs_is_flat = xla.is_flat(outputs)
1504    if outputs_is_flat:
1505      output_tensors, control_deps, pack_template = _postprocess_flat_outputs(
1506          outputs)
1507    else:
1508      output_tensors, control_deps, pack_template = (
1509          _postprocess_non_flat_outputs(outputs))
1510
    # tensor_tracer imports tpu.py, so we import tensor_tracer locally here to
    # avoid an import cycle.
1513    if typing.TYPE_CHECKING:
1514      tensor_tracer = Any
1515    else:
1516      # pylint: disable=g-import-not-at-top
1517      from tensorflow.python.tpu import tensor_tracer
1518      # pylint: enable=g-import-not-at-top
1519    if tensor_tracer.TensorTracer.is_enabled():
1520      tt = tensor_tracer.TensorTracer()
1521      output_tensors = tt.trace_tpu(ops.get_default_graph(),
1522                                    output_tensors, control_deps,
1523                                    num_replicas)
1524
1525    context.ExitResult(output_tensors)
1526  finally:
1527    context.report_unsupported_operations()
1528    context.Exit()
1529    host_compute_core = context.HostComputeCore()
1530
1531  if host_compute_core:
1532    attr_value = attr_value_pb2.AttrValue()
1533    attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
1534    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
1535
1536  with ops.control_dependencies([metadata]):
1537    if use_tpu:
1538      compile_status = tpu_ops.tpu_compilation_result()
1539      op = compile_status.op
1540      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
1541      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
1542    else:
1543      compile_status = control_flow_ops.no_op(name="compilation_status")
1544
1545  if not output_tensors:
    # Return the compile op and a list of NoOps that depend on the replication
    # Op, indexed by [replica_num].
1548    return [
1549        compile_status,
1550        [
1551            control_flow_ops.group(control_deps, name="shard_%d" % i)
1552            for i in range(num_replicas)
1553        ]
1554    ]
1555
1556  # Fan-out: Builds a TPUReplicatedOutput node for each output.
1557  replicated_outputs = [[] for i in range(num_replicas)]
1558  for i, t in enumerate(output_tensors):
1559
    # None values returned by the computation can't be sent to
    # tpu_ops.tpu_replicated_output(), so we handle them specially here. We can
    # avoid the placeholder-0 workaround used on the inputs because outputs are
    # replicated per-tensor, not per-replica, so replication can simply be
    # skipped.
1564    if t is None:
1565      for replica in range(num_replicas):
1566        replicated_outputs[replica].append(None)
1567      continue
1568
1570    ys = tpu_ops.tpu_replicated_output(
1571        t, num_replicas, name="output{}".format(i))
1572
1573    # Wraps the outputs in identity operators so the names of any possible
1574    # `fetch` nodes are preserved by the replication rewrite.
1575    with ops.control_dependencies(control_deps):
1576      for replica in range(num_replicas):
1577        replicated_outputs[replica].append(
1578            array_ops.identity(
1579                ys[replica], name="output_%d_shard_%d" % (i, replica)))
1580
1581  replicated_outputs = [
1582      nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True)
1583      for replica_outs in replicated_outputs
1584  ]
1585
1586  return [compile_status, replicated_outputs]
1587
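# Example (an illustrative, non-executable sketch; it assumes user-side code
# with `import tensorflow.compat.v1 as tf`, an initialized TPU system, and that
# this function is reached via the `tensorflow.python.tpu.tpu` module):
#
#   def computation(x):
#     return x * 2.0
#
#   inputs = [[tf.ones([2])], [tf.ones([2])]]  # two replicas, one input each
#   compile_op, outputs = split_compile_and_replicate(computation, inputs)
#   # Running `compile_op` first surfaces compilation errors without executing
#   # the program; `outputs[r][0]` is replica r's doubled tensor.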
1588
1589def _postprocess_flat_outputs(
1590    outputs: Any
1591) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
1592  """Validates non-flat outputs, add backs device assignments and other attrs.
1593
1594  Args:
1595    outputs: Output from `computation` inside `tpu.rewrite`.
1596
1597  Returns:
1598    - Tensors extracted from outputs.
1599    - Operations extracted from outputs.
1600    - A pack template for use with nest.pack_sequence_as to pack the tensors.
1601  """
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs, and for consistency it was nice to convert even a
  # single element into a tuple. But now that we support arbitrary output
  # structures, this is no longer necessary.
1606  # TODO(b/121383831): Migrate all legacy use cases and delete this special
1607  # case.
1608  # If the computation returns `None`, make it an empty tuple.
1609  if outputs is None:
1610    outputs = tuple()
1611
1612  # For legacy / backwards compatibility reasons we return a list for "flat"
1613  # output values (even if the user's flat return value was a different type or
1614  # even just a scalar value) so use nest.flatten to compute a flat list pack
1615  # template.
1616  pack_template = nest.flatten(outputs, expand_composites=False)
1617
1618  # Even though outputs is already "flat", we flatten any composites so their
1619  # component tensors can be tagged and replicated. The pack_template will be
1620  # used by the caller to repack the composite tensors.
1621  outputs = nest.flatten(outputs, expand_composites=True)
1622
  # Append a `no_op` here so that fetching any return value of this function
  # will trigger the TPUExecute node.
1625  outputs += (control_flow_ops.no_op(),)
1626
1627  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
1628  try:
1629    with ops.device(core(0)):
1630      outputs = [
1631          o if isinstance(o, ops.Operation) else maybe_convert(o)
1632          for o in outputs
1633      ]
1634  except Exception as e:
1635    raise ValueError(
1636        "TPU function return values must all either be Operations or "
1637        f"convertible to Tensors. Got error: {e}")
1638
1639  # Separates the returned Operations and Tensors.
1640  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
1641  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
1642
1643  if outputs != output_tensors + output_operations:
1644    raise ValueError(
1645        "TPU functions must return zero-or more Tensor values followed by "
1646        "zero or more Operations.")
1647
1648  # Trim operations off the end of the pack template. output_operations has 1
1649  # extra element due to the no-op that is added.
1650  if len(output_operations) > 1:
1651    pack_template = pack_template[:1 - len(output_operations)]
1652
1653  # Wraps outputs in Identity ops. Otherwise a replicated input copied
1654  # straight to an output would bypass the replicate(). This would be bad
1655  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1656  # be rewritten away, leading to a runtime error.
1657  # TODO(phawkins): extend the rewrite to elide these nodes instead.
1658  new_output_tensors = []
1659  for t in output_tensors:
    if t is None:
      new_output_tensors.append(None)
      continue
1662    with ops.device(t.device if t.device else core(0)):
1663      o = array_ops.identity(t)
1664      # pylint: disable=protected-access
1665      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1666      # pylint: enable=protected-access
1667      new_output_tensors.append(o)
1668  return new_output_tensors, output_operations, pack_template
1669
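# Illustrative note on the flat-output contract handled above (a sketch with
# hypothetical user code; `v` is assumed to be a tf.Variable captured by the
# computation):
#
#   def computation(x):
#     y = x + 1.0
#     update_op = tf.group(v.assign_add(1))  # an Operation
#     return y, update_op                    # Tensors first, then Operations
#
# Returning `update_op, y` instead would trigger the "zero or more Tensor
# values followed by zero or more Operations" ValueError above.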
1670
1671def _postprocess_non_flat_outputs(
1672    outputs: Any
1673) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
1674  """Validates non-flat outputs, add backs device assignments and other attrs.
1675
1676  Args:
1677    outputs: Output from `computation` inside `tpu.rewrite`.
1678
1679  Returns:
1680    - Tensors extracted from outputs.
1681    - An empty Operations list because Operations are not allowed in non-flat
1682      outputs.
1683    - A pack template for use with nest.pack_sequence_as to pack the tensors.
1684  """
1685
1686  # Flatten output items.
1687  flat_outputs = nest.flatten(outputs, expand_composites=True)
1688
1689  # Convert all non-None non-Operation outputs to Tensors.
1690  for i, o in enumerate(flat_outputs):
1691    if o is None:
1692      flat_outputs[i] = None
1693      continue
1694
1695    if isinstance(o, ops.Operation):
1696      raise ValueError(
1697          "tpu.rewrite does not support Operation as return value in non-flat "
1698          "output structure. You can set returned Operations as control "
1699          "dependencies of returned Tensors so Operations are triggered when "
1700          f'Tensors are evaluated. Operation found: "{o.name}"')
1701
1702    try:
1703      o = ops.convert_to_tensor(o)
1704    except Exception as e:
1705      raise ValueError(
1706          "TPU function return values must all either be Operations or "
1707          f'convertible to Tensors. Got error: "{e}"')
1708
1709    # Wraps outputs in Identity ops. Otherwise a replicated input copied
1710    # straight to an output would bypass the replicate(). This would be bad
1711    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1712    # be rewritten away, leading to a runtime error.
1713    # TODO(phawkins): extend the rewrite to elide these nodes instead.
1714    with ops.device(o.device if o.device else core(0)):
1715      o = array_ops.identity(o)
1716      # pylint: disable=protected-access
1717      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1718      # pylint: enable=protected-access
1719      flat_outputs[i] = array_ops.identity(o)
1720
  # All flat_outputs are Tensors; there are no Operations.
1722  return flat_outputs, [], outputs
1723
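# Illustrative note on the non-flat output contract handled above (a sketch;
# `v` and the structure below are hypothetical):
#
#   def computation(x):
#     update_op = tf.group(v.assign_add(1))
#     loss = tf.reduce_mean(x)
#     with tf.control_dependencies([update_op]):
#       loss = tf.identity(loss)
#     return {"loss": loss}  # nested structure: Tensors only, no Operations
#
# Returning an Operation inside the structure raises the ValueError above;
# attach it as a control dependency of a returned Tensor instead.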
1724
1725def split_compile_and_shard(
1726    computation: Callable[..., Any],
1727    inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
1728    num_shards: int = 1,
1729    input_shard_axes: Optional[List[int]] = None,
1730    outputs_from_all_shards: Union[bool, List[bool]] = True,
1731    output_shard_axes: Optional[List[int]] = None,
1732    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1733    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1734    name: Optional[Text] = None,
1735    xla_options: Optional[XLAOptions] = None,
1736    ) -> Tuple[ops.Operation, List[core_types.Tensor]]:
1737  """Shards `computation` for parallel execution.
1738
1739  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1740  of which has a corresponding split axis (from `input_shard_axes`). Each input
1741  is split into `num_shards` pieces along the corresponding axis, and
1742  computation is applied to each shard in parallel.
1743
1744  Tensors are broadcast to all shards if they are lexically captured by
1745  `computation`. e.g.,
1746
1747  x = tf.constant(7)
1748  def computation():
1749    return x + 3
1750  ... = shard(computation, ...)
1751
1752  If `outputs_from_all_shards` is true, the outputs from all shards of
1753  `computation` are concatenated back together along their `output_shard_axes`.
1754  Otherwise, each output is taken from an arbitrary shard.
1755
1756  Inputs and outputs of the computation must be at least rank-1 Tensors.
1757
1758  Args:
1759    computation: A Python function that builds a computation to apply to each
1760      shard of the input.
1761    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axis, given by `input_shard_axes`,
      and its size along that axis must be divisible by `num_shards`.
1764    num_shards: The number of shards.
1765    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1766      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1767      there must be one dimension per input.
1768    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1769      `True`, outputs from all shards are concatenated along the corresponding
1770      `output_shard_axes` entry. Otherwise, each output is taken
1771      from an arbitrary shard. If the argument is a boolean, the argument's
1772      value is used for each output.
1773    output_shard_axes: A list of dimensions along which to concatenate the
1774      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1775      along dimension 0". If not `None`, there must be one dimension per output.
1776      Ignored if `outputs_from_all_shards` is False.
1777    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1778      of `computation`.
1779    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
1781      the TPU topology. Uses a default device assignment if `None`. The
1782      `DeviceAssignment` may be omitted if each shard of the computation uses
1783      only one core, and there is either only one shard, or the number of shards
1784      is equal to the number of cores in the TPU system.
1785    name: (Deprecated) Does nothing.
1786    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.
1788  Returns:
1789    A tuple of (compile op, [output tensors]).
1790  Raises:
1791    ValueError: If num_shards <= 0
1792    ValueError: If len(input_shard_axes) != len(inputs)
1793    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1794  """
1795  # TODO(phawkins): consider adding support for broadcasting Tensors passed as
1796  # inputs.
1797
1798  if num_shards <= 0:
1799    raise ValueError(
1800        f"num_shards must be a positive integer. Received {num_shards}")
1801
1802  inputs = [] if inputs is None else inputs
1803  if not isinstance(inputs, list):
1804    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. "
1805                    f"Received {type(inputs)}")
1806
1807  # Converts inputs to Tensors.
1808  inputs = [ops.convert_to_tensor(x) for x in inputs]
1809
1810  if input_shard_axes is None:
1811    input_shard_axes = [0] * len(inputs)
1812  if len(inputs) != len(input_shard_axes):
1813    raise ValueError("Length of input_shard_axes must be equal to the number "
1814                     f"of inputs. Received {len(inputs)} inputs and "
1815                     f"{len(input_shard_axes)} input_shard_axes.")
1816
1817  if inputs:
1818    # Splits the `inputs` along the corresponding `input_shard_axes`, giving
1819    # lists with layout [input][shard]
1820    split_inputs = [
1821        array_ops.split(x, num_shards, axis=axis)
1822        for (axis, x) in zip(input_shard_axes, inputs)]
1823
1824    # Transposes the input lists to have layout [shard][input]
1825    transposed_inputs = [list(i) for i in zip(*split_inputs)]
1826  else:
1827    transposed_inputs = [[]] * num_shards
1828
1829  compile_op, outputs = split_compile_and_replicate(
1830      computation,
1831      transposed_inputs,
1832      infeed_queue=infeed_queue,
1833      device_assignment=device_assignment,
1834      name=name,
1835      xla_options=xla_options)
1836
1837  # There must be at least one shard since num_shards > 0.
1838  # TODO(b/36647078) remove disable when pylint bug is fixed.
1839  # pylint: disable=indexing-exception
1840  if isinstance(outputs[0], ops.Operation):
1841    # pylint: enable=indexing-exception
1842    # There were no outputs from the computation and replicate returned a list
1843    # of NoOps with control dependencies on the computation. Return the first
1844    # one so it can be used as a control dependency or fetch node.
1845    # TODO(b/36647078) remove disable when pylint bug is fixed.
1846    # pylint: disable=indexing-exception
1847    return compile_op, [outputs[0]]
1848    # pylint: enable=indexing-exception
1849
1850  # TODO(b/36647078) remove disable when pylint bug is fixed.
1851  # pylint: disable=indexing-exception
1852  num_outputs = len(outputs[0])
1853  # pylint: enable=indexing-exception
1854
1855  if output_shard_axes is None:
1856    output_shard_axes = [0] * num_outputs
1857  if num_outputs != len(output_shard_axes):
1858    raise ValueError("Length of output_shard_axes must be equal to the number "
1859                     f"of outputs. Received {num_outputs} outputs "
1860                     f"and {len(output_shard_axes)} output_shard_axes.")
1861
1862  if isinstance(outputs_from_all_shards, bool):
1863    outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
1864
1865  if num_outputs != len(outputs_from_all_shards):
1866    raise ValueError(
1867        "Length of outputs_from_all_shards must be equal to the number of "
1868        f"outputs. Received {num_outputs} outputs  and "
1869        f"{len(outputs_from_all_shards)} outputs_from_all_shards.")
1870
1871  results = []
1872  for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
1873                                   zip(*outputs)):
1874    if all_shards:
1875      # Concatenate all of the outputs together (use stack for scalars).
1876      shape = x[0].shape
1877      is_scalar = shape is not None and (shape.ndims == 0)
1878      results.append((array_ops.stack(list(x)) if is_scalar
1879                      else array_ops.concat(list(x), axis=axis)))
1880    else:
1881      # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
1882      results.append(x[0])
1883
1884  return compile_op, results
1885
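# Example (an illustrative sketch; user-side code assuming
# `import tensorflow.compat.v1 as tf`, an initialized TPU system, and access to
# this function via the `tensorflow.python.tpu.tpu` module):
#
#   def computation(x):
#     return tf.reduce_sum(x, axis=1, keepdims=True)
#
#   compile_op, outputs = split_compile_and_shard(
#       computation, inputs=[tf.ones([8, 4])], num_shards=2)
#   # Each shard sees a [4, 4] slice; outputs[0] has shape [8, 1] because the
#   # per-shard results are concatenated back along dimension 0.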
1886
1887@tf_export(v1=["tpu.shard"])
1888def shard(
1889    computation: Callable[..., Any],
1890    inputs: Optional[List[core_types.Tensor]] = None,
1891    num_shards: int = 1,
1892    input_shard_axes: Optional[List[int]] = None,
1893    outputs_from_all_shards: Union[bool, List[bool]] = True,
1894    output_shard_axes: Optional[List[int]] = None,
1895    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1896    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1897    name: Optional[Text] = None,
1898    xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
1899  """Shards `computation` for parallel execution.
1900
1901  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1902  of which has a corresponding split axis (from `input_shard_axes`). Each input
1903  is split into `num_shards` pieces along the corresponding axis, and
1904  computation is applied to each shard in parallel.
1905
1906  Tensors are broadcast to all shards if they are lexically captured by
1907  `computation`. e.g.,
1908
1909  x = tf.constant(7)
1910  def computation():
1911    return x + 3
1912  ... = shard(computation, ...)
1913
1914  TODO(phawkins): consider adding support for broadcasting Tensors passed
1915  as inputs.
1916
1917  If `outputs_from_all_shards` is true, the outputs from all shards of
1918  `computation` are concatenated back together along their `output_shard_axes`.
1919  Otherwise, each output is taken from an arbitrary shard.
1920
1921  Inputs and outputs of the computation must be at least rank-1 Tensors.
1922
1923  Args:
1924    computation: A Python function that builds a computation to apply to each
1925      shard of the input.
1926    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axis, given by `input_shard_axes`,
      and its size along that axis must be divisible by `num_shards`.
1929    num_shards: The number of shards.
1930    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1931      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1932      there must be one dimension per input.
1933    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1934      `True`, outputs from all shards are concatenated along the corresponding
1935      `output_shard_axes` entry. Otherwise, each output is taken
1936      from an arbitrary shard. If the argument is a boolean, the argument's
1937      value is used for each output.
1938    output_shard_axes: A list of dimensions along which to concatenate the
1939      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1940      along dimension 0". If not `None`, there must be one dimension per output.
1941      Ignored if `outputs_from_all_shards` is False.
1942    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1943      of `computation`.
1944    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
1946      the TPU topology. Uses a default device assignment if `None`. The
1947      `DeviceAssignment` may be omitted if each shard of the computation uses
1948      only one core, and there is either only one shard, or the number of shards
1949      is equal to the number of cores in the TPU system.
1950    name: (Deprecated) Does nothing.
1951    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.
1953  Returns:
1954    A list of output tensors.
1955  Raises:
1956    ValueError: If num_shards <= 0
1957    ValueError: If len(input_shard_axes) != len(inputs)
1958    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1959  """
1960  return split_compile_and_shard(
1961      computation,
1962      inputs=inputs,
1963      num_shards=num_shards,
1964      input_shard_axes=input_shard_axes,
1965      outputs_from_all_shards=outputs_from_all_shards,
1966      output_shard_axes=output_shard_axes,
1967      infeed_queue=infeed_queue,
1968      device_assignment=device_assignment,
1969      name=name,
1970      xla_options=xla_options)[1]
1971
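# Example (an illustrative sketch of `tf.compat.v1.tpu.shard`; the computation
# and shapes below are made up, and an initialized TPU system is assumed):
#
#   def computation(x):
#     return x * 2.0
#
#   y = tf.compat.v1.tpu.shard(
#       computation,
#       inputs=[tf.ones([8, 16])],  # split along dim 0 into two [4, 16] shards
#       num_shards=2)[0]            # concatenated back to shape [8, 16]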
1972
1973@tf_export(v1=["tpu.batch_parallel"])
1974def batch_parallel(
1975    computation: Callable[..., Any],
1976    inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
1977    num_shards: int = 1,
1978    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1979    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1980    name: Optional[Text] = None,
1981    xla_options: Optional[XLAOptions] = None):
1982  """Shards `computation` along the batch dimension for parallel execution.
1983
1984  Convenience wrapper around shard().
1985
1986  `inputs` must be a list of Tensors or None (equivalent to an empty list).
1987  Each input is split into `num_shards` pieces along the 0-th dimension, and
1988  computation is applied to each shard in parallel.
1989
1990  Tensors are broadcast to all shards if they are lexically captured by
1991  `computation`. e.g.,
1992
1993  x = tf.constant(7)
1994  def computation():
1995    return x + 3
  ... = batch_parallel(computation, ...)
1997
1998  The outputs from all shards are concatenated back together along their 0-th
1999  dimension.
2000
2001  Inputs and outputs of the computation must be at least rank-1 Tensors.
2002
2003  Args:
2004    computation: A Python function that builds a computation to apply to each
2005      shard of the input.
2006    inputs: A list of input tensors or None (equivalent to an empty list). The
2007      0-th dimension of each Tensor must have size divisible by `num_shards`.
2008    num_shards: The number of shards.
2009    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
2010      of arguments as inputs to `computation`.
2011    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
2013      the TPU topology. Uses a default device assignment if `None`. The
2014      `DeviceAssignment` may be omitted if each shard of the computation uses
2015      only one core, and there is either only one shard, or the number of shards
2016      is equal to the number of cores in the TPU system.
2017    name: (Deprecated) Does nothing.
2018    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.
2020  Returns:
2021    A list of output tensors.
2022  Raises:
2023    ValueError: If `num_shards <= 0`
2024  """
2025  return shard(
2026      computation,
2027      inputs,
2028      num_shards=num_shards,
2029      infeed_queue=infeed_queue,
2030      device_assignment=device_assignment,
2031      name=name,
2032      xla_options=xla_options)
2033
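# Example (an illustrative sketch of `tf.compat.v1.tpu.batch_parallel`; the
# model function and batch size below are hypothetical):
#
#   def model(images):
#     return tf.reduce_mean(images, axis=[1, 2])  # per-image mean pixel value
#
#   [pooled] = tf.compat.v1.tpu.batch_parallel(
#       model,
#       inputs=[tf.ones([64, 28, 28])],  # batch of 64 split across 8 shards
#       num_shards=8)
#   # `pooled` has shape [64]: shard outputs re-concatenated along dimension 0.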
2034
2035@tf_export(v1=["tpu.rewrite"])
2036def rewrite(
2037    computation: Callable[..., Any],
2038    inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
2039    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
2040    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
2041    name: Optional[Text] = None,
2042    xla_options: Optional[XLAOptions] = None) -> Any:
2043  """Rewrites `computation` for execution on a TPU system.
2044
2045  Args:
2046    computation: A Python function that builds a computation to apply to the
2047      input. If the function takes n inputs, 'inputs' should be a list of n
2048      tensors.
2049
2050      `computation` may return a list of operations and tensors. Tensors must
2051      come before operations in the returned list.  The return value of
2052      `rewrite` is a list of tensors corresponding to the tensors from the
2053      output of `computation`.
2054
      All `Operation`s constructed during `computation` will be executed when
      evaluating any of the returned output tensors, not just the `Operation`s
      that were returned.
2057    inputs: A list of input tensors or `None` (equivalent to an empty list).
2058      Each input can be a nested structure containing values that are
2059      convertible to tensors. Note that passing an N-dimension list of
2060      compatible values will result in a N-dimension list of scalar tensors
2061      rather than a single Rank-N tensors. If you need different behavior,
2062      convert part of inputs to tensors with `tf.convert_to_tensor`.
2063    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
2064      of arguments as inputs to `computation`.
2065    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
2067      the TPU topology. May be omitted for a single-core computation, in which
2068      case the core attached to task 0, TPU device 0 is used.
2069    name: (Deprecated) Does nothing.
2070    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.
2072  Returns:
2073    Same data structure as if computation(*inputs) is called directly with some
2074    exceptions for correctness. Exceptions include:
2075      1) None output: a NoOp would be returned which control-depends on
2076         computation.
2077      2) Single value output: A tuple containing the value would be returned.
2078      3) Operation-only outputs: a NoOp would be returned which
2079         control-depends on computation.
2080      TODO(b/121383831): Investigate into removing these special cases.
2081  """
2082  # TODO(b/36647078) remove disable when pylint bug is fixed.
2083  # pylint: disable=indexing-exception
2084  return replicate(
2085      computation,
2086      None if inputs is None else [inputs],
2087      infeed_queue=infeed_queue,
2088      device_assignment=device_assignment,
2089      name=name,
2090      xla_options=xla_options)[0]
2091  # pylint: enable=indexing-exception
2092
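# Example (an illustrative sketch of `tf.compat.v1.tpu.rewrite`; the model
# function and shapes are hypothetical, and TPU system initialization/shutdown
# around the session run is assumed):
#
#   def model(features):
#     logits = tf.compat.v1.layers.dense(features, 10)
#     return tf.nn.softmax(logits)
#
#   [probs] = tf.compat.v1.tpu.rewrite(model, [tf.ones([8, 32])])
#   # `probs` can then be fetched with a tf.compat.v1.Session on the TPU worker.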

# Operations that indicate some error in the user's inference graph.
_DENYLISTED_INFERENCE_OPS = set([
2097    "ReadVariableOp",
2098    "AssignVariableOp",
2099    "AssignAddVariableOp",
2100    "AssignSubVariableOp",
2101    "VarHandleOp",
2102    "Variable",
2103    "VariableV2",
2104])
2105
2106
2107def under_tpu_inference_context() -> bool:
2108  """Check if it is currently under `_TPUInferenceContext`."""
2109  graph = ops.get_default_graph()
2110  while graph:
2111    context = graph._get_control_flow_context()  # pylint: disable=protected-access
2112    while context:
2113      if isinstance(context, _TPUInferenceContext):
2114        return True
2115      context = context.outer_context
2116    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
2117      graph = graph._outer_graph  # pylint: disable=protected-access
2118    elif isinstance(graph, func_graph.FuncGraph):
2119      graph = graph.outer_graph
2120    else:
2121      return False
2122
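# Illustrative note (a sketch; `my_model` is hypothetical): inside a computation
# passed to rewrite_for_inference() below, this predicate returns True, which
# lets library code branch on TPU-inference behavior, e.g.:
#
#   def my_model(x):
#     if under_tpu_inference_context():
#       pass  # e.g. skip ops that are denylisted for TPU inference
#     return x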
2123
2124class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
2125  """A `ControlFlowContext` for nodes inside a TPU inference computation.
2126
2127  The primary role of `_TPUInferenceContext` is to indicate the mode of
2128  operation and possibly sanity check operators inside a
2129  tpu.rewrite_for_inference() computation.
2130  """
2131
2132  def __init__(self, name: Text, check_ops: bool = True):
2133    super(_TPUInferenceContext, self).__init__()
2134    self._name = name
2135    self._check_ops = check_ops
2136
2137  def AddOp(self, op):
2138    self._AddOpInternal(op)
2139
2140  def _AddOpInternal(self, op):
2141    # pylint: disable=protected-access
2142    if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS:
2143      raise NotImplementedError(
2144          f"Operation of type {op.type} ({op.name}) is not supported on the "
2145          "TPU for inference. Execution will fail if this op is used in the "
2146          "graph. Make sure your variables are using variable_scope.")
2147    if self._outer_context:
2148      self._outer_context.AddInnerOp(op)
2149
2150  def AddValue(self, val):
2151    result = val
2152    if self._outer_context:
2153      result = self._outer_context.AddValue(val)
2154    return result
2155
2156  def AddInnerOp(self, op):
2157    self._AddOpInternal(op)
2158
2159  @property
2160  def grad_state(self):
2161    return None
2162
2163
2164def validate_inference_rewrite_for_variables(graph: ops.Graph):
2165  """Validates whether rewrite_for_inference() 'worked' for variables.
2166
2167     The rewrite_for_inference() method is supposed to append GuaranteeConstOps
2168     after ReadVariableOps, but this mechanism works only if you are using
2169     tf.compat.v1.get_variable() to create and access variables in your tpu
2170     computation. This validation method can be called immediately after calling
     tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
2172     to the graph.
2173
2174     Typical usages:
2175       tpu.validate_inference_rewrite_for_variables(
2176           tf.compat.v1.get_default_graph())
2177
2178       tpu.validate_inference_rewrite_for_variables(sess.graph)
2179
2180  Args:
2181    graph: The graph which needs to be validated.
2182  Raises:
2183    RuntimeError: if validation failed.
2184  """
2185  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
2186    raise RuntimeError(
2187        "No GuaranteeConst ops found in the graph after running "
2188        "tpu.rewrite_for_inference(...). Please check that you are using "
2189        "tf.get_variable() to create and access variables in your tpu "
2190        "computation.")
2191
2192
2193def rewrite_for_inference(
2194    computation: Callable[..., Any],
2195    inputs: Optional[List[core_types.Tensor]] = None,
2196    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
2197    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
2198    name: Optional[Text] = None) -> List[core_types.Tensor]:
2199  """Rewrites `computation` for inference on a TPU system.
2200
     In addition to rewriting the computation to run on a TPU, if your
     computation uses variables, this moves the ReadVariableOps outside the TPU
     computation and adds GuaranteeConst ops just after them. This mechanism
     works only if you are using tf.compat.v1.get_variable() to create and
     access variables in your tpu computation. You can validate whether this
     worked by calling the validate_inference_rewrite_for_variables() method
     immediately after this method to check whether GuaranteeConstOps were
     added to the graph.
2209
2210  Args:
2211    computation: A Python function that builds a computation to apply to the
2212      input. If the function takes n inputs, 'inputs' should be a list of n
2213      tensors. If the function returns m outputs, rewrite will return a list of
2214      m tensors.
2215    inputs: A list of input tensors or `None` (equivalent to an empty list).
2216    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
2217      of arguments as inputs to `computation`.
2218    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
2220      the TPU topology. May be omitted for a single-core computation, in which
2221      case the core attached to task 0, TPU device 0 is used.
2222    name: The name of the operator.
2223  Returns:
2224    A list of output tensors.
2225  """
2226
2227  def guarantee_const_getter(getter, name, *args, **kwargs):
2228    with ops.control_dependencies(None):
2229      return array_ops.guarantee_const(
2230          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
2231
2232  def wrapped_computation(*args, **kwargs):
2233    """Execute computation under `_TPUInferenceContext`."""
2234    context = _TPUInferenceContext(
2235        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
2236    try:
2237      context.Enter()
2238
2239      vscope = variable_scope.get_variable_scope()
2240      prev_custom_getter = vscope.custom_getter
2241      prev_caching_device = vscope.caching_device
2242      vscope.set_custom_getter(guarantee_const_getter)
2243      vscope.set_caching_device(lambda op: op.device)
2244
2245      result = computation(*args, **kwargs)
2246
2247      vscope.set_custom_getter(prev_custom_getter)
2248      vscope.set_caching_device(prev_caching_device)
2249    finally:
2250      context.Exit()
2251    return result
2252
2253  # pylint: disable=undefined-variable
2254  return rewrite(
2255      wrapped_computation,
2256      inputs=inputs,
2257      infeed_queue=infeed_queue,
2258      device_assignment=device_assignment,
2259      name=name)
2260  # pylint: enable=undefined-variable
2261
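# Example (an illustrative sketch; the serving function is hypothetical and
# variables are assumed to be created with tf.compat.v1.get_variable()):
#
#   def serving_fn(features):
#     w = tf.compat.v1.get_variable("w", [32, 10])
#     return tf.matmul(features, w)
#
#   [logits] = rewrite_for_inference(serving_fn, [tf.ones([8, 32])])
#   validate_inference_rewrite_for_variables(
#       tf.compat.v1.get_default_graph())  # raises if no GuaranteeConst ops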
2262
2263def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
2264  """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
2265
2266  Args:
2267    prune_graph: A tensorflow graph from which we wish to prune unconnected ops
2268      as listed in _UNCONNECTED_OPS_TO_PRUNE.  In general, these ops should have
2269      no inputs and no consumers. These can often be left behind due to graph
2270      construction rewiring (for instance TF-Hub). While they never execute,
      they will cause XLA compilation to fail, so we exclude them from XLA
      compilation by removing the tpu_replicate attribute.
2273  """
2274  # Scan over the top level graph and all function graphs.
  for graph in [prune_graph] + list(
      prune_graph._functions.values()):  # pylint: disable=protected-access
2278    if not isinstance(graph, ops.Graph):
2279      continue
2280    for op in graph.get_operations():
2281      if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
2282        continue
2283      outputs_consumed = False
2284      for output in op.outputs:
2285        if output.consumers():
2286          outputs_consumed = True
2287          break
2288      if not outputs_consumed:
2289        logging.info(
2290            "Pruning OP %s of type %s from XLA Compile due to "
2291            "it being disconnected.", op.name, op.type)
2292        op._clear_attr(_TPU_REPLICATE_ATTR)  # pylint: disable=protected-access
2293
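# Example (an illustrative sketch; typically called on the default graph after
# graph construction and before TPU compilation is triggered):
#
#   prune_unconnected_ops_from_xla(tf.compat.v1.get_default_graph())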